commit 1773929431 (parent ee853333a8)

    update

app/backtest/decision_env.py  (new file, 180 lines)
@@ -0,0 +1,180 @@
"""Reinforcement-learning style environment wrapping the backtest engine."""
from __future__ import annotations

from dataclasses import dataclass, replace
from typing import Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple

import math

from .engine import BacktestEngine, BacktestResult, BtConfig
from app.agents.game import Decision
from app.agents.registry import weight_map
from app.utils.logging import get_logger

LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "decision_env"}


@dataclass(frozen=True)
class ParameterSpec:
    """Defines how a scalar action dimension maps to strategy parameters."""

    name: str
    target: str
    minimum: float = 0.0
    maximum: float = 1.0

    def clamp(self, value: float) -> float:
        clipped = max(0.0, min(1.0, float(value)))
        return self.minimum + clipped * (self.maximum - self.minimum)


@dataclass
class EpisodeMetrics:
    total_return: float
    max_drawdown: float
    volatility: float
    nav_series: List[Dict[str, float]]
    trades: List[Dict[str, object]]

    @property
    def sharpe_like(self) -> float:
        if self.volatility <= 1e-9:
            return 0.0
        return self.total_return / self.volatility


class DecisionEnv:
    """Thin RL-friendly wrapper that evaluates parameter actions via backtest."""

    def __init__(
        self,
        *,
        bt_config: BtConfig,
        parameter_specs: Sequence[ParameterSpec],
        baseline_weights: Mapping[str, float],
        reward_fn: Optional[Callable[[EpisodeMetrics], float]] = None,
        disable_departments: bool = False,
    ) -> None:
        self._template_cfg = bt_config
        self._specs = list(parameter_specs)
        self._baseline_weights = dict(baseline_weights)
        self._reward_fn = reward_fn or self._default_reward
        self._last_metrics: Optional[EpisodeMetrics] = None
        self._last_action: Optional[Tuple[float, ...]] = None
        self._episode = 0
        self._disable_departments = bool(disable_departments)

    @property
    def action_dim(self) -> int:
        return len(self._specs)

    def reset(self) -> Dict[str, float]:
        self._episode += 1
        self._last_metrics = None
        self._last_action = None
        return {
            "episode": float(self._episode),
            "baseline_return": 0.0,
        }

    def step(self, action: Sequence[float]) -> Tuple[Dict[str, float], float, bool, Dict[str, object]]:
        if len(action) != self.action_dim:
            raise ValueError(f"expected action length {self.action_dim}, got {len(action)}")
        action_array = [float(val) for val in action]
        self._last_action = tuple(action_array)

        weights = self._build_weights(action_array)
        LOGGER.info("episode=%s action=%s weights=%s", self._episode, action_array, weights, extra=LOG_EXTRA)

        cfg = replace(self._template_cfg)
        engine = BacktestEngine(cfg)
        engine.weights = weight_map(weights)
        if self._disable_departments:
            engine.department_manager = None

        try:
            result = engine.run()
        except Exception as exc:  # noqa: BLE001
            LOGGER.exception("backtest failed under action", extra={**LOG_EXTRA, "error": str(exc)})
            info = {"error": str(exc)}
            return {"failure": 1.0}, -1.0, True, info

        metrics = self._compute_metrics(result)
        reward = float(self._reward_fn(metrics))
        self._last_metrics = metrics

        observation = {
            "total_return": metrics.total_return,
            "max_drawdown": metrics.max_drawdown,
            "volatility": metrics.volatility,
            "sharpe_like": metrics.sharpe_like,
        }
        info = {
            "nav_series": metrics.nav_series,
            "trades": metrics.trades,
            "weights": weights,
        }
        return observation, reward, True, info

    def _build_weights(self, action: Sequence[float]) -> Dict[str, float]:
        weights = dict(self._baseline_weights)
        for idx, spec in enumerate(self._specs):
            value = spec.clamp(action[idx])
            if spec.target.startswith("agent_weights."):
                agent_name = spec.target.split(".", 1)[1]
                weights[agent_name] = value
            else:
                LOGGER.debug("unsupported parameter target: %s", spec.target, extra=LOG_EXTRA)
        return weights

    def _compute_metrics(self, result: BacktestResult) -> EpisodeMetrics:
        nav_series = result.nav_series or []
        if not nav_series:
            return EpisodeMetrics(0.0, 0.0, 0.0, [], result.trades)

        nav_values = [row.get("nav", 0.0) for row in nav_series]
        # Fall back to 1.0 when the first NAV is missing or zero so the return
        # calculation below cannot divide by zero.
        base_nav = nav_values[0] if nav_values and nav_values[0] else 1.0

        returns = [(nav / base_nav) - 1.0 for nav in nav_values]
        total_return = returns[-1]

        peak = nav_values[0]
        max_drawdown = 0.0
        for nav in nav_values:
            if nav > peak:
                peak = nav
            drawdown = (peak - nav) / peak if peak else 0.0
            max_drawdown = max(max_drawdown, drawdown)

        diffs = [nav_values[idx] - nav_values[idx - 1] for idx in range(1, len(nav_values))]
        if diffs:
            mean_diff = sum(diffs) / len(diffs)
            variance = sum((diff - mean_diff) ** 2 for diff in diffs) / len(diffs)
            volatility = math.sqrt(variance) / base_nav
        else:
            volatility = 0.0

        return EpisodeMetrics(
            total_return=float(total_return),
            max_drawdown=float(max_drawdown),
            volatility=volatility,
            nav_series=nav_series,
            trades=result.trades,
        )

    @staticmethod
    def _default_reward(metrics: EpisodeMetrics) -> float:
        penalty = 0.5 * metrics.max_drawdown
        return metrics.total_return - penalty

    @property
    def last_metrics(self) -> Optional[EpisodeMetrics]:
        return self._last_metrics

    @property
    def last_action(self) -> Optional[Tuple[float, ...]]:
        return self._last_action
@@ -423,6 +423,18 @@ SCHEMA_STATEMENTS: Iterable[str] = (
        notes TEXT,
        metadata TEXT
    );
    """,
    """
    CREATE TABLE IF NOT EXISTS tuning_results (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        experiment_id TEXT,
        strategy TEXT,
        action TEXT,
        weights TEXT,
        reward REAL,
        metrics TEXT,
        created_at TEXT DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
    );
    """
)

@@ -456,6 +468,7 @@ REQUIRED_TABLES = (
    "portfolio_positions",
    "portfolio_trades",
    "portfolio_snapshots",
    "tuning_results",
)

@@ -13,6 +13,8 @@ if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import json
from datetime import datetime
import uuid

import pandas as pd
import plotly.express as px

@@ -24,6 +26,7 @@ import streamlit as st
from app.agents.base import AgentContext
from app.agents.game import Decision
from app.backtest.engine import BtConfig, run_backtest
from app.backtest.decision_env import DecisionEnv, ParameterSpec
from app.data.schema import initialize_database
from app.ingest.checker import run_boot_check
from app.ingest.tushare import FetchJob, run_ingestion

@@ -53,6 +56,8 @@ from app.utils.portfolio import (
    list_positions,
    list_recent_trades,
)
from app.agents.registry import default_agents
from app.utils.tuning import log_tuning_result

LOGGER = get_logger(__name__)

@@ -623,6 +628,7 @@ def render_backtest() -> None:
    st.header("回测与复盘")
    st.write("在此运行回测、展示净值曲线与代理贡献。")

    cfg = get_config()
    default_start, default_end = _default_backtest_range(window_days=60)
    LOGGER.debug(
        "回测默认参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s",

@@ -746,6 +752,347 @@ def render_backtest() -> None:
        status_box.update(label="回测执行失败", state="error")
        st.error(f"回测执行失败:{exc}")

    with st.expander("离线调参实验 (DecisionEnv)", expanded=False):
        st.caption(
            "使用 DecisionEnv 对代理权重做离线调参。请选择需要优化的代理并设定权重范围,"
            "系统会运行一次回测并返回收益、回撤等指标。若 LLM 网络不可用,将返回失败标记。"
        )

        disable_departments = st.checkbox(
            "禁用部门 LLM(仅规则代理,适合离线快速评估)",
            value=True,
            help="关闭部门调用后不依赖外部 LLM 网络,仅根据规则代理权重模拟。",
        )

        default_experiment_id = f"streamlit_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        experiment_id = st.text_input(
            "实验 ID",
            value=default_experiment_id,
            help="用于在 tuning_results 表中区分不同实验。",
        )
        strategy_label = st.text_input(
            "策略说明",
            value="DecisionEnv",
            help="可选:为本次调参记录一个策略名称或备注。",
        )

        agent_objects = default_agents()
        agent_names = [agent.name for agent in agent_objects]
        if not agent_names:
            st.info("暂无可调整的代理。")
        else:
            selected_agents = st.multiselect(
                "选择调参的代理权重",
                agent_names,
                default=agent_names[:2],
                key="decision_env_agents",
            )

            specs: List[ParameterSpec] = []
            action_values: List[float] = []
            range_valid = True
            for idx, agent_name in enumerate(selected_agents):
                col_min, col_max, col_action = st.columns([1, 1, 2])
                min_key = f"decision_env_min_{agent_name}"
                max_key = f"decision_env_max_{agent_name}"
                action_key = f"decision_env_action_{agent_name}"
                default_min = 0.0
                default_max = 1.0
                min_val = col_min.number_input(
                    f"{agent_name} 最小权重",
                    min_value=0.0,
                    max_value=1.0,
                    value=default_min,
                    step=0.05,
                    key=min_key,
                )
                max_val = col_max.number_input(
                    f"{agent_name} 最大权重",
                    min_value=0.0,
                    max_value=1.0,
                    value=default_max,
                    step=0.05,
                    key=max_key,
                )
                if max_val <= min_val:
                    range_valid = False
                action_val = col_action.slider(
                    f"{agent_name} 动作 (0-1)",
                    min_value=0.0,
                    max_value=1.0,
                    value=0.5,
                    step=0.01,
                    key=action_key,
                )
                specs.append(
                    ParameterSpec(
                        name=f"weight_{agent_name}",
                        target=f"agent_weights.{agent_name}",
                        minimum=min_val,
                        maximum=max_val,
                    )
                )
                action_values.append(action_val)

            run_decision_env = st.button("执行单次调参", key="run_decision_env_button")
            if run_decision_env:
                if not selected_agents:
                    st.warning("请至少选择一个代理进行调参。")
                elif not range_valid:
                    st.error("请确保所有代理的最大权重大于最小权重。")
                else:
                    baseline_weights = cfg.agent_weights.as_dict()
                    for agent in agent_objects:
                        baseline_weights.setdefault(agent.name, 1.0)

                    universe_env = [code.strip() for code in universe_text.split(',') if code.strip()]
                    if not universe_env:
                        st.error("请先指定至少一个股票代码。")
                    else:
                        bt_cfg_env = BtConfig(
                            id="decision_env_streamlit",
                            name="DecisionEnv Streamlit",
                            start_date=start_date,
                            end_date=end_date,
                            universe=universe_env,
                            params={
                                "target": target,
                                "stop": stop,
                                "hold_days": int(hold_days),
                            },
                            method=cfg.decision_method,
                        )
                        env = DecisionEnv(
                            bt_config=bt_cfg_env,
                            parameter_specs=specs,
                            baseline_weights=baseline_weights,
                            disable_departments=disable_departments,
                        )
                        env.reset()
                        with st.spinner("正在执行离线调参……"):
                            try:
                                observation, reward, done, info = env.step(action_values)
                            except Exception as exc:  # noqa: BLE001
                                LOGGER.exception("DecisionEnv 调用失败", extra=LOG_EXTRA)
                                st.error(f"离线调参失败:{exc}")
                            else:
                                if observation.get("failure"):
                                    st.error("调参失败:回测执行未完成,可能是 LLM 网络不可用或参数异常。")
                                    st.json(observation)
                                else:
                                    st.success("离线调参完成")
                                    col_metrics = st.columns(4)
                                    col_metrics[0].metric("总收益", f"{observation.get('total_return', 0.0):+.2%}")
                                    col_metrics[1].metric("最大回撤", f"{observation.get('max_drawdown', 0.0):+.2%}")
                                    col_metrics[2].metric("波动率", f"{observation.get('volatility', 0.0):+.2%}")
                                    col_metrics[3].metric("奖励", f"{reward:+.4f}")

                                    st.write("调参后权重:")
                                    weights_dict = info.get("weights", {})
                                    st.json(weights_dict)
                                    action_payload = {
                                        name: value
                                        for name, value in zip(selected_agents, action_values)
                                    }
                                    metrics_payload = dict(observation)
                                    metrics_payload["reward"] = reward
                                    try:
                                        log_tuning_result(
                                            experiment_id=experiment_id or str(uuid.uuid4()),
                                            strategy=strategy_label or "DecisionEnv",
                                            action=action_payload,
                                            reward=reward,
                                            metrics=metrics_payload,
                                            weights=weights_dict,
                                        )
                                        st.caption("调参结果已写入 tuning_results 表。")
                                    except Exception:  # noqa: BLE001
                                        LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA)

                                    if weights_dict:
                                        if st.button(
                                            "保存这些权重为默认配置",
                                            key="save_decision_env_weights_single",
                                        ):
                                            cfg.agent_weights.update_from_dict(weights_dict)
                                            save_config(cfg)
                                            st.success("代理权重已写入 config.json")

                                    nav_series = info.get("nav_series")
                                    if nav_series:
                                        try:
                                            nav_df = pd.DataFrame(nav_series)
                                            if {"trade_date", "nav"}.issubset(nav_df.columns):
                                                nav_df = nav_df.sort_values("trade_date")
                                                nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"])
                                                st.line_chart(nav_df.set_index("trade_date")["nav"], height=220)
                                        except Exception:  # noqa: BLE001
                                            LOGGER.debug("净值曲线绘制失败", extra=LOG_EXTRA)
                                    trades = info.get("trades")
                                    if trades:
                                        st.write("成交记录:")
                                        st.dataframe(pd.DataFrame(trades), hide_index=True, width='stretch')

            st.divider()
            st.caption("批量调参:在下方输入多组动作,每行表示一组 0-1 之间的值,用逗号分隔。")
            default_grid = "\n".join(
                [
                    ",".join(["0.2" for _ in specs]),
                    ",".join(["0.5" for _ in specs]),
                    ",".join(["0.8" for _ in specs]),
                ]
            ) if specs else ""
            action_grid_raw = st.text_area(
                "动作列表",
                value=default_grid,
                height=120,
                key="decision_env_batch_actions",
            )
            run_batch = st.button("批量执行调参", key="run_decision_env_batch")
            if run_batch:
                if not selected_agents:
                    st.warning("请先选择调参代理。")
                elif not range_valid:
                    st.error("请确保所有代理的最大权重大于最小权重。")
                else:
                    lines = [line.strip() for line in action_grid_raw.splitlines() if line.strip()]
                    if not lines:
                        st.warning("请在文本框中输入至少一组动作。")
                    else:
                        parsed_actions: List[List[float]] = []
                        for line in lines:
                            try:
                                values = [float(val.strip()) for val in line.split(',') if val.strip()]
                            except ValueError:
                                st.error(f"无法解析动作行:{line}")
                                parsed_actions = []
                                break
                            if len(values) != len(specs):
                                st.error(f"动作维度不匹配(期望 {len(specs)} 个值):{line}")
                                parsed_actions = []
                                break
                            parsed_actions.append(values)
                        if parsed_actions:
                            baseline_weights = cfg.agent_weights.as_dict()
                            for agent in agent_objects:
                                baseline_weights.setdefault(agent.name, 1.0)

                            universe_env = [code.strip() for code in universe_text.split(',') if code.strip()]
                            if not universe_env:
                                st.error("请先指定至少一个股票代码。")
                            else:
                                bt_cfg_env = BtConfig(
                                    id="decision_env_streamlit_batch",
                                    name="DecisionEnv Batch",
                                    start_date=start_date,
                                    end_date=end_date,
                                    universe=universe_env,
                                    params={
                                        "target": target,
                                        "stop": stop,
                                        "hold_days": int(hold_days),
                                    },
                                    method=cfg.decision_method,
                                )
                                env = DecisionEnv(
                                    bt_config=bt_cfg_env,
                                    parameter_specs=specs,
                                    baseline_weights=baseline_weights,
                                    disable_departments=disable_departments,
                                )
                                results: List[Dict[str, object]] = []
                                with st.spinner("正在批量执行调参……"):
                                    for idx, action_vals in enumerate(parsed_actions, start=1):
                                        env.reset()
                                        try:
                                            observation, reward, done, info = env.step(action_vals)
                                        except Exception as exc:  # noqa: BLE001
                                            LOGGER.exception("批量调参失败", extra=LOG_EXTRA)
                                            results.append(
                                                {
                                                    "序号": idx,
                                                    "动作": action_vals,
                                                    "状态": "error",
                                                    "错误": str(exc),
                                                }
                                            )
                                            continue
                                        if observation.get("failure"):
                                            results.append(
                                                {
                                                    "序号": idx,
                                                    "动作": action_vals,
                                                    "状态": "failure",
                                                    "奖励": -1.0,
                                                }
                                            )
                                        else:
                                            action_payload = {
                                                name: value
                                                for name, value in zip(selected_agents, action_vals)
                                            }
                                            metrics_payload = dict(observation)
                                            metrics_payload["reward"] = reward
                                            weights_payload = info.get("weights", {})
                                            try:
                                                log_tuning_result(
                                                    experiment_id=experiment_id or str(uuid.uuid4()),
                                                    strategy=strategy_label or "DecisionEnv",
                                                    action=action_payload,
                                                    reward=reward,
                                                    metrics=metrics_payload,
                                                    weights=weights_payload,
                                                )
                                            except Exception:  # noqa: BLE001
                                                LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA)
                                            results.append(
                                                {
                                                    "序号": idx,
                                                    "动作": action_vals,
                                                    "状态": "ok",
                                                    "总收益": observation.get("total_return", 0.0),
                                                    "最大回撤": observation.get("max_drawdown", 0.0),
                                                    "波动率": observation.get("volatility", 0.0),
                                                    "奖励": reward,
                                                    "权重": weights_payload,
                                                }
                                            )
                                if results:
                                    st.write("批量调参结果:")
                                    results_df = pd.DataFrame(results)
                                    st.dataframe(results_df, hide_index=True, width='stretch')
                                    selectable = [
                                        row
                                        for row in results
                                        if row.get("状态") == "ok" and row.get("权重")
                                    ]
                                    if selectable:
                                        option_labels = [
                                            f"序号 {row['序号']} | 奖励 {row.get('奖励', 0.0):+.4f}"
                                            for row in selectable
                                        ]
                                        selected_label = st.selectbox(
                                            "选择要保存的记录",
                                            option_labels,
                                            key="decision_env_batch_select",
                                        )
                                        selected_row = None
                                        for label, row in zip(option_labels, selectable):
                                            if label == selected_label:
                                                selected_row = row
                                                break
                                        if selected_row and st.button(
                                            "保存所选权重为默认配置",
                                            key="save_decision_env_weights_batch",
                                        ):
                                            cfg.agent_weights.update_from_dict(selected_row.get("权重", {}))
                                            save_config(cfg)
                                            st.success(
                                                f"已将序号 {selected_row['序号']} 的权重写入 config.json"
                                            )
                                    else:
                                        st.caption("暂无成功的结果可供保存。")


def render_settings() -> None:
    LOGGER.info("渲染设置页面", extra=LOG_EXTRA)

@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
import json
import os
from pathlib import Path
from typing import Dict, Iterable, List, Optional
from typing import Dict, Iterable, List, Mapping, Optional

@@ -48,6 +48,32 @@ class AgentWeights:
            "A_macro": self.macro,
        }

    def update_from_dict(self, data: Mapping[str, float]) -> None:
        mapping = {
            "A_mom": "momentum",
            "momentum": "momentum",
            "A_val": "value",
            "value": "value",
            "A_news": "news",
            "news": "news",
            "A_liq": "liquidity",
            "liquidity": "liquidity",
            "A_macro": "macro",
            "macro": "macro",
        }
        for key, attr in mapping.items():
            if key in data and data[key] is not None:
                try:
                    setattr(self, attr, float(data[key]))
                except (TypeError, ValueError):
                    continue

    @classmethod
    def from_dict(cls, data: Mapping[str, float]) -> "AgentWeights":
        inst = cls()
        inst.update_from_dict(data)
        return inst

DEFAULT_LLM_MODEL_OPTIONS: Dict[str, Dict[str, object]] = {
    "ollama": {
        "models": ["llama3", "phi3", "qwen2"],

@@ -357,6 +383,10 @@ def _load_from_file(cfg: AppConfig) -> None:
    if "decision_method" in payload:
        cfg.decision_method = str(payload.get("decision_method") or cfg.decision_method)

    weights_payload = payload.get("agent_weights")
    if isinstance(weights_payload, dict):
        cfg.agent_weights.update_from_dict(weights_payload)

    legacy_profiles: Dict[str, Dict[str, object]] = {}
    legacy_routes: Dict[str, Dict[str, object]] = {}

@@ -523,6 +553,7 @@ def save_config(cfg: AppConfig | None = None) -> None:
        "tushare_token": cfg.tushare_token,
        "force_refresh": cfg.force_refresh,
        "decision_method": cfg.decision_method,
        "agent_weights": cfg.agent_weights.as_dict(),
        "llm": {
            "strategy": cfg.llm.strategy if cfg.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
            "majority_threshold": cfg.llm.majority_threshold,

app/utils/tuning.py  (new file, 42 lines)
@@ -0,0 +1,42 @@
"""Helpers for logging decision tuning experiments."""
from __future__ import annotations

import json
from typing import Any, Dict, Optional

from .db import db_session
from .logging import get_logger

LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "tuning"}


def log_tuning_result(
    *,
    experiment_id: str,
    strategy: str,
    action: Dict[str, Any],
    reward: float,
    metrics: Dict[str, Any],
    weights: Optional[Dict[str, float]] = None,
) -> None:
    """Persist a tuning result into the SQLite table."""

    try:
        with db_session() as conn:
            conn.execute(
                """
                INSERT INTO tuning_results (experiment_id, strategy, action, weights, reward, metrics)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (
                    experiment_id,
                    strategy,
                    json.dumps(action, ensure_ascii=False),
                    json.dumps(weights or {}, ensure_ascii=False),
                    float(reward),
                    json.dumps(metrics, ensure_ascii=False),
                ),
            )
    except Exception:  # noqa: BLE001
        LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA)

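Editor's note: a minimal sketch of reading these rows back for comparison (not part of this commit). It assumes the connection yielded by `db_session` behaves like a standard `sqlite3` connection, which is how `log_tuning_result` above uses it:

```python
"""Sketch: rank logged tuning runs by reward (illustrative only)."""
import json

from app.utils.db import db_session

with db_session() as conn:
    rows = conn.execute(
        "SELECT experiment_id, strategy, action, weights, reward, created_at "
        "FROM tuning_results ORDER BY reward DESC LIMIT 10"
    ).fetchall()

for experiment_id, strategy, action, weights, reward, created_at in rows:
    # action and weights are stored as JSON strings by log_tuning_result.
    print(created_at, experiment_id, strategy, f"{reward:+.4f}", json.loads(weights))
```
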
@ -36,8 +36,11 @@
|
||||
- Streamlit 侧边栏监听 `llm.metrics` 的实时事件,并以 ~0.75 秒节流频率刷新“系统监控”,既保证日志到达后快速更新,也避免刷屏造成 UI 闪烁。
|
||||
- 新增投资管理数据层:SQLite 中创建 `investment_pool`、`portfolio_positions`、`portfolio_trades`、`portfolio_snapshots` 四张表;`app/utils/portfolio.py` 提供访问接口,今日计划页可实时展示候选池、持仓与成交。
|
||||
- 回测引擎 `record_agent_state()` 现同步写入 `investment_pool`,将每日全局决策的置信度、部门标签与目标权重落库,作为后续提示参数调优与候选池管理的基础数据。
|
||||
- `app/backtest/decision_env.py` 引入 `DecisionEnv`,用单步 RL/Gym 风格接口封装回测:动作 → 权重映射 → 回测 → 奖励(收益 - 0.5×回撤),同时输出 NAV、交易与行动权重,方便与 Bandit/PPO 等算法对接。
|
||||
- Streamlit “回测与复盘” 页新增离线调参模块,可即点即用 DecisionEnv 对代理权重进行实验,并可视化收益、回撤、成交与权重结果,支持一键写入 `config.json` 成为新的默认权重。
|
||||
- 所有离线调参实验(单次/批量)都会存入 SQLite `tuning_results`,包含实验 ID、动作、奖励、指标与权重,便于后续分析与对比。
|
||||
|
||||
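A minimal sketch of driving `DecisionEnv` from an external optimiser, using plain random search rather than a real Bandit/PPO learner. The backtest window, universe, and spec ranges are illustrative placeholders; only the `DecisionEnv`/`ParameterSpec`/`BtConfig` calls mirror the interfaces added in this commit:

```python
"""Random-search sketch over DecisionEnv actions (illustrative, not part of this commit)."""
import random
from datetime import date, timedelta

from app.backtest.decision_env import DecisionEnv, ParameterSpec
from app.backtest.engine import BtConfig
from app.utils.config import get_config

app_cfg = get_config()
specs = [
    ParameterSpec(name="momentum_weight", target="agent_weights.A_mom", minimum=0.1, maximum=0.6),
    ParameterSpec(name="value_weight", target="agent_weights.A_val", minimum=0.1, maximum=0.4),
]
bt_cfg = BtConfig(
    id="decision_env_random_search",               # placeholder id/name
    name="Random search demo",
    start_date=date.today() - timedelta(days=60),  # placeholder 60-day window
    end_date=date.today(),
    universe=["000001.SZ"],                        # placeholder universe
    params={},
    method=app_cfg.decision_method,
)
env = DecisionEnv(
    bt_config=bt_cfg,
    parameter_specs=specs,
    baseline_weights=app_cfg.agent_weights.as_dict(),
    disable_departments=True,                      # skip department LLM calls for offline runs
)

best = None
for _ in range(20):                                # each step runs one full backtest episode
    env.reset()
    action = [random.random() for _ in range(env.action_dim)]
    observation, reward, done, info = env.step(action)
    if observation.get("failure"):                 # failed backtests come back with reward -1.0
        continue
    if best is None or reward > best[0]:
        best = (reward, info["weights"])           # reward = total_return - 0.5 * max_drawdown

print("best reward / weights:", best)
```
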
## Next-phase roadmap
- Wrap `BacktestEngine` in a `DecisionEnv` so that one strategy configuration runs a full backtest cycle and reports reward, constraint violations, and related metrics.
- Hook in Bandit/Bayesian optimisation to search offline over prompt versions, department weights, and temperature ranges, using the new snapshot/positions data to weigh risk against return.
- Build the position/trade write-back flow (backtest and live) so RL training can reconstruct the equity curve, capital usage, and rebalancing costs.
- Extend the action mapping in `DecisionEnv` (prompt versions, department temperatures, function-calling strategy, etc.), upgrading the current weight-only action into a coordinated multi-parameter adjustment.
- Hook in Bandit/Bayesian optimisation to explore the action space, and fold the `portfolio_snapshots`/`portfolio_trades` outputs into the reward constraints (return, drawdown, turnover); a minimal bandit sketch follows this list.
- Build a real-time entry point for the position/trade write-back flow so online monitoring and offline tuning share the same data source, supporting incremental training and strategy replay.
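
As a hint of what the Bandit item above might look like against the current `DecisionEnv`, here is an ε-greedy sketch over a discretised action grid (illustrative only, not the planned implementation; `env` is assumed to be built as in the random-search sketch above, and each pull runs one full backtest):

```python
"""Epsilon-greedy bandit sketch over a fixed action grid (illustrative only)."""
import itertools
import random
from typing import List, Tuple

from app.backtest.decision_env import DecisionEnv


def epsilon_greedy_search(
    env: DecisionEnv,
    levels: List[float],
    rounds: int = 30,
    epsilon: float = 0.2,
) -> Tuple[List[float], float]:
    """Treat each discretised action vector as an arm and track its running mean reward."""
    arms = [list(combo) for combo in itertools.product(levels, repeat=env.action_dim)]
    counts = [0] * len(arms)
    means = [0.0] * len(arms)

    def best() -> int:
        # Prefer arms pulled at least once; unpulled arms score -inf.
        return max(range(len(arms)), key=lambda i: means[i] if counts[i] else float("-inf"))

    for _ in range(rounds):
        idx = random.randrange(len(arms)) if (random.random() < epsilon or not any(counts)) else best()
        env.reset()
        _, reward, _, _ = env.step(arms[idx])              # one pull = one full backtest episode
        counts[idx] += 1
        means[idx] += (reward - means[idx]) / counts[idx]  # incremental mean update

    winner = best()
    return arms[winner], means[winner]


# Usage (with `env` built as in the random-search sketch above):
# best_action, best_mean_reward = epsilon_greedy_search(env, levels=[0.2, 0.5, 0.8])
```
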
scripts/run_decision_env_example.py  (new file, 52 lines)
@@ -0,0 +1,52 @@
"""Quick example of using DecisionEnv for weight tuning experiments."""
from __future__ import annotations

import json
from datetime import date, timedelta

import sys
from pathlib import Path

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from app.backtest.decision_env import DecisionEnv, ParameterSpec
from app.backtest.engine import BtConfig
from app.agents.registry import default_agents
from app.utils.config import get_config


def main() -> None:
    cfg = get_config()
    agents = default_agents()
    baseline_weights = {agent.name: cfg.agent_weights.as_dict().get(agent.name, 1.0) for agent in agents}

    today = date.today()
    bt_cfg = BtConfig(
        id="decision_env_example",
        name="Decision Env Demo",
        start_date=today - timedelta(days=60),
        end_date=today,
        universe=["000001.SZ"],
        params={},
        method=cfg.decision_method,
    )

    specs = [
        ParameterSpec(name="momentum_weight", target="agent_weights.A_mom", minimum=0.1, maximum=0.6),
        ParameterSpec(name="value_weight", target="agent_weights.A_val", minimum=0.1, maximum=0.4),
    ]

    env = DecisionEnv(bt_config=bt_cfg, parameter_specs=specs, baseline_weights=baseline_weights)
    env.reset()
    observation, reward, done, info = env.step([0.5, 0.2])

    print("Observation:", json.dumps(observation, ensure_ascii=False, indent=2))
    print("Reward:", reward)
    print("Done:", done)
    print("Weights:", json.dumps(info.get("weights", {}), ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()