791 lines
44 KiB
Python
791 lines
44 KiB
Python
"""回测与复盘相关视图。"""
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import uuid
|
||
from dataclasses import asdict
|
||
from datetime import date
|
||
from typing import Dict, List, Optional
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
import plotly.express as px
|
||
import requests
|
||
from requests.exceptions import RequestException
|
||
import streamlit as st
|
||
|
||
from app.agents.base import AgentContext
|
||
from app.agents.game import Decision
|
||
from app.agents.registry import default_agents
|
||
from app.backtest.decision_env import DecisionEnv, ParameterSpec
|
||
from app.backtest.engine import BacktestEngine, PortfolioState, BtConfig, run_backtest
|
||
from app.ingest.checker import run_boot_check
|
||
from app.ingest.tushare import run_ingestion
|
||
from app.llm.client import run_llm
|
||
from app.llm.metrics import reset as reset_llm_metrics
|
||
from app.llm.metrics import snapshot as snapshot_llm_metrics
|
||
from app.utils import alerts
|
||
from app.utils.config import get_config, save_config
|
||
from app.utils.tuning import log_tuning_result
|
||
|
||
from app.ui.shared import LOGGER, LOG_EXTRA, default_backtest_range
|
||
from app.ui.views.dashboard import update_dashboard_sidebar
|
||
|
||
# st.session_state key holding the result of the most recent single DecisionEnv
# tuning run; written after a successful step and popped on failure or when the
# user clears the cached result.
_DECISION_ENV_SINGLE_RESULT_KEY = "decision_env_single_result"
|
||
# st.session_state key holding the results of the most recent batch DecisionEnv
# tuning run (result rows, the rows whose weights can be saved, experiment id
# and strategy label); popped when the user clears the cached batch results.
_DECISION_ENV_BATCH_RESULTS_KEY = "decision_env_batch_results"
|
||
|
||
def render_backtest_review() -> None:
|
||
"""渲染回测执行、调参与结果复盘页面。"""
|
||
st.header("回测与复盘")
|
||
st.caption("1. 基于历史数据复盘当前策略;2. 借助强化学习/调参探索更优参数组合。")
|
||
app_cfg = get_config()
|
||
default_start, default_end = default_backtest_range(window_days=60)
|
||
LOGGER.debug(
|
||
"回测默认参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s",
|
||
default_start,
|
||
default_end,
|
||
"000001.SZ",
|
||
0.035,
|
||
-0.015,
|
||
10,
|
||
extra=LOG_EXTRA,
|
||
)
|
||
|
||
st.markdown("### 回测参数")
|
||
col1, col2 = st.columns(2)
|
||
start_date = col1.date_input("开始日期", value=default_start, key="bt_start_date")
|
||
end_date = col2.date_input("结束日期", value=default_end, key="bt_end_date")
|
||
universe_text = st.text_input("股票列表(逗号分隔)", value="000001.SZ", key="bt_universe")
|
||
col_target, col_stop, col_hold = st.columns(3)
|
||
target = col_target.number_input("目标收益(例:0.035 表示 3.5%)", value=0.035, step=0.005, format="%.3f", key="bt_target")
|
||
stop = col_stop.number_input("止损收益(例:-0.015 表示 -1.5%)", value=-0.015, step=0.005, format="%.3f", key="bt_stop")
|
||
hold_days = col_hold.number_input("持有期(交易日)", value=10, step=1, key="bt_hold_days")
|
||
LOGGER.debug(
|
||
"当前回测表单输入:start=%s end=%s universe_text=%s target=%.3f stop=%.3f hold_days=%s",
|
||
start_date,
|
||
end_date,
|
||
universe_text,
|
||
target,
|
||
stop,
|
||
hold_days,
|
||
extra=LOG_EXTRA,
|
||
)
|
||
|
||
tab_backtest, tab_rl = st.tabs(["回测验证", "强化学习调参"])
|
||
|
||
with tab_backtest:
|
||
st.markdown("#### 回测执行")
|
||
if st.button("运行回测", key="bt_run_button"):
|
||
LOGGER.info("用户点击运行回测按钮", extra=LOG_EXTRA)
|
||
decision_log_container = st.container()
|
||
status_box = st.status("准备执行回测...", expanded=True)
|
||
llm_stats_placeholder = st.empty()
|
||
decision_entries: List[str] = []
|
||
|
||
def _decision_callback(ts_code: str, trade_dt: date, ctx: AgentContext, decision: Decision) -> None:
    """Stream one backtest decision into the page while the run is in progress.

    Passed to ``run_backtest`` as ``decision_callback``. For each decision it
    appends a markdown entry to ``decision_entries``, renders the most recent
    200 entries, writes a one-line status message, and refreshes the LLM
    usage counters shown in the page and in the dashboard sidebar.

    Args:
        ts_code: Code shown verbatim in the entry (presumably the security
            code; confirm against the backtest engine).
        trade_dt: Date of the decision; rendered in ISO format.
        ctx: Agent context handed over by the caller; not read here.
        decision: Decision whose action, confidence and per-department
            decisions are rendered.
    """
    day_label = trade_dt.isoformat()

    # Use the first department summary that is non-empty, if there is one.
    summary_text = next(
        (
            str(dept.summary)
            for dept in decision.department_decisions.values()
            if getattr(dept, "summary", "")
        ),
        "",
    )

    lines = [f"**{day_label} {ts_code}** → {decision.action.value} (信心 {decision.confidence:.2f})"]
    if summary_text:
        lines.append(f"摘要:{summary_text}")

    # One "code:action(confidence)" fragment per department.
    dept_views = [
        f"{code}:{dept.action.value}({dept.confidence:.2f})"
        for code, dept in decision.department_decisions.items()
    ]
    if dept_views:
        lines.append("部门意见:" + ";".join(dept_views))

    decision_entries.append(" \n".join(lines))
    # NOTE(review): decision_log_container is created with st.container(); if
    # container.markdown() appends a new element on every call, the log is
    # re-emitted each time instead of replaced -- confirm this is intended.
    decision_log_container.markdown("\n\n".join(decision_entries[-200:]))
    status_box.write(f"{day_label} {ts_code} → {decision.action.value} (信心 {decision.confidence:.2f})")

    # Refresh the LLM usage counters after every decision.
    llm_usage = snapshot_llm_metrics()
    usage_view = {
        "LLM 调用次数": llm_usage.get("total_calls", 0),
        "Prompt Tokens": llm_usage.get("total_prompt_tokens", 0),
        "Completion Tokens": llm_usage.get("total_completion_tokens", 0),
        "按 Provider": llm_usage.get("provider_calls", {}),
        "按模型": llm_usage.get("model_calls", {}),
    }
    llm_stats_placeholder.json(usage_view)
    update_dashboard_sidebar(llm_usage)
|
||
|
||
reset_llm_metrics()
|
||
status_box.update(label="执行回测中...", state="running")
|
||
try:
|
||
universe = [code.strip() for code in universe_text.split(',') if code.strip()]
|
||
LOGGER.info(
|
||
"回测参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s",
|
||
start_date,
|
||
end_date,
|
||
universe,
|
||
target,
|
||
stop,
|
||
hold_days,
|
||
extra=LOG_EXTRA,
|
||
)
|
||
backtest_cfg = BtConfig(
|
||
id="streamlit_demo",
|
||
name="Streamlit Demo Strategy",
|
||
start_date=start_date,
|
||
end_date=end_date,
|
||
universe=universe,
|
||
params={
|
||
"target": target,
|
||
"stop": stop,
|
||
"hold_days": int(hold_days),
|
||
},
|
||
)
|
||
result = run_backtest(backtest_cfg, decision_callback=_decision_callback)
|
||
LOGGER.info(
|
||
"回测完成:nav_records=%s trades=%s",
|
||
len(result.nav_series),
|
||
len(result.trades),
|
||
extra=LOG_EXTRA,
|
||
)
|
||
status_box.update(label="回测执行完成", state="complete")
|
||
st.success("回测执行完成,详见下方结果与统计。")
|
||
metrics = snapshot_llm_metrics()
|
||
llm_stats_placeholder.json(
|
||
{
|
||
"LLM 调用次数": metrics.get("total_calls", 0),
|
||
"Prompt Tokens": metrics.get("total_prompt_tokens", 0),
|
||
"Completion Tokens": metrics.get("total_completion_tokens", 0),
|
||
"按 Provider": metrics.get("provider_calls", {}),
|
||
"按模型": metrics.get("model_calls", {}),
|
||
}
|
||
)
|
||
update_dashboard_sidebar(metrics)
|
||
st.session_state["backtest_last_result"] = {"nav_records": result.nav_series, "trades": result.trades}
|
||
st.json(st.session_state["backtest_last_result"])
|
||
except Exception as exc: # noqa: BLE001
|
||
LOGGER.exception("回测执行失败", extra=LOG_EXTRA)
|
||
status_box.update(label="回测执行失败", state="error")
|
||
st.error(f"回测执行失败:{exc}")
|
||
|
||
last_result = st.session_state.get("backtest_last_result")
|
||
if last_result:
|
||
st.markdown("#### 最近回测输出")
|
||
st.json(last_result)
|
||
|
||
st.divider()
|
||
# ADD: Comparison view for multiple backtest configurations
|
||
with st.expander("历史回测结果对比", expanded=False):
|
||
st.caption("从历史回测配置中选择多个进行净值曲线与指标对比。")
|
||
normalize_to_one = st.checkbox("归一化到 1 起点", value=True, key="bt_cmp_normalize")
|
||
use_log_y = st.checkbox("对数坐标", value=False, key="bt_cmp_log_y")
|
||
metric_options = ["总收益", "最大回撤", "交易数", "平均换手", "风险事件"]
|
||
selected_metrics = st.multiselect("显示指标列", metric_options, default=metric_options, key="bt_cmp_metrics")
|
||
try:
|
||
with db_session(read_only=True) as conn:
|
||
cfg_rows = conn.execute(
|
||
"SELECT id, name FROM bt_config ORDER BY rowid DESC LIMIT 50"
|
||
).fetchall()
|
||
except Exception: # noqa: BLE001
|
||
LOGGER.exception("读取 bt_config 失败", extra=LOG_EXTRA)
|
||
cfg_rows = []
|
||
cfg_options = [f"{row['id']} | {row['name']}" for row in cfg_rows]
|
||
selected_labels = st.multiselect("选择配置", cfg_options, default=cfg_options[:2], key="bt_cmp_configs")
|
||
selected_ids = [label.split(" | ")[0].strip() for label in selected_labels]
|
||
nav_df = pd.DataFrame()
|
||
rpt_df = pd.DataFrame()
|
||
if selected_ids:
|
||
try:
|
||
with db_session(read_only=True) as conn:
|
||
nav_df = pd.read_sql_query(
|
||
"SELECT cfg_id, trade_date, nav FROM bt_nav WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))),
|
||
conn,
|
||
params=tuple(selected_ids),
|
||
)
|
||
rpt_df = pd.read_sql_query(
|
||
"SELECT cfg_id, summary FROM bt_report WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))),
|
||
conn,
|
||
params=tuple(selected_ids),
|
||
)
|
||
except Exception: # noqa: BLE001
|
||
LOGGER.exception("读取回测结果失败", extra=LOG_EXTRA)
|
||
st.error("读取回测结果失败")
|
||
nav_df = pd.DataFrame()
|
||
rpt_df = pd.DataFrame()
|
||
if not nav_df.empty:
|
||
try:
|
||
nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"], errors="coerce")
|
||
# ADD: date window filter
|
||
overall_min = pd.to_datetime(nav_df["trade_date"].min()).date()
|
||
overall_max = pd.to_datetime(nav_df["trade_date"].max()).date()
|
||
col_d1, col_d2 = st.columns(2)
|
||
start_filter = col_d1.date_input("起始日期", value=overall_min, key="bt_cmp_start")
|
||
end_filter = col_d2.date_input("结束日期", value=overall_max, key="bt_cmp_end")
|
||
if start_filter > end_filter:
|
||
start_filter, end_filter = end_filter, start_filter
|
||
mask = (nav_df["trade_date"].dt.date >= start_filter) & (nav_df["trade_date"].dt.date <= end_filter)
|
||
nav_df = nav_df.loc[mask]
|
||
pivot = nav_df.pivot_table(index="trade_date", columns="cfg_id", values="nav")
|
||
if normalize_to_one:
|
||
pivot = pivot.apply(lambda s: s / s.dropna().iloc[0] if s.dropna().size else s)
|
||
import plotly.graph_objects as go
|
||
fig = go.Figure()
|
||
for col in pivot.columns:
|
||
fig.add_trace(go.Scatter(x=pivot.index, y=pivot[col], mode="lines", name=str(col)))
|
||
fig.update_layout(height=300, margin=dict(l=10, r=10, t=30, b=10))
|
||
if use_log_y:
|
||
fig.update_yaxes(type="log")
|
||
st.plotly_chart(fig, width='stretch')
|
||
# ADD: export pivot
|
||
try:
|
||
csv_buf = pivot.reset_index()
|
||
csv_buf.columns = ["trade_date"] + [str(c) for c in pivot.columns]
|
||
st.download_button(
|
||
"下载曲线(CSV)",
|
||
data=csv_buf.to_csv(index=False),
|
||
file_name="bt_nav_compare.csv",
|
||
mime="text/csv",
|
||
key="dl_nav_compare",
|
||
)
|
||
except Exception:
|
||
pass
|
||
except Exception: # noqa: BLE001
|
||
LOGGER.debug("绘制对比曲线失败", extra=LOG_EXTRA)
|
||
if not rpt_df.empty:
|
||
try:
|
||
metrics_rows: List[Dict[str, object]] = []
|
||
for _, row in rpt_df.iterrows():
|
||
cfg_id = row["cfg_id"]
|
||
try:
|
||
summary = json.loads(row["summary"]) if isinstance(row["summary"], str) else (row["summary"] or {})
|
||
except json.JSONDecodeError:
|
||
summary = {}
|
||
record = {
|
||
"cfg_id": cfg_id,
|
||
"总收益": summary.get("total_return"),
|
||
"最大回撤": summary.get("max_drawdown"),
|
||
"交易数": summary.get("trade_count"),
|
||
"平均换手": summary.get("avg_turnover"),
|
||
"风险事件": summary.get("risk_events"),
|
||
}
|
||
metrics_rows.append({k: v for k, v in record.items() if (k == "cfg_id" or k in selected_metrics)})
|
||
if metrics_rows:
|
||
dfm = pd.DataFrame(metrics_rows)
|
||
st.dataframe(dfm, hide_index=True, width='stretch')
|
||
try:
|
||
st.download_button(
|
||
"下载指标(CSV)",
|
||
data=dfm.to_csv(index=False),
|
||
file_name="bt_metrics_compare.csv",
|
||
mime="text/csv",
|
||
key="dl_metrics_compare",
|
||
)
|
||
except Exception:
|
||
pass
|
||
except Exception: # noqa: BLE001
|
||
LOGGER.debug("渲染指标表失败", extra=LOG_EXTRA)
|
||
else:
|
||
st.info("请选择至少一个配置进行对比。")
|
||
|
||
|
||
|
||
with tab_rl:
|
||
st.caption("使用 DecisionEnv 对代理权重进行强化学习调参,支持单次与批量实验。")
|
||
with st.expander("离线调参实验 (DecisionEnv)", expanded=False):
|
||
st.caption(
|
||
"使用 DecisionEnv 对代理权重做离线调参。请选择需要优化的代理并设定权重范围,"
|
||
"系统会运行一次回测并返回收益、回撤等指标。若 LLM 网络不可用,将返回失败标记。"
|
||
)
|
||
|
||
disable_departments = st.checkbox(
|
||
"禁用部门 LLM(仅规则代理,适合离线快速评估)",
|
||
value=True,
|
||
help="关闭部门调用后不依赖外部 LLM 网络,仅根据规则代理权重模拟。",
|
||
)
|
||
|
||
default_experiment_id = f"streamlit_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||
experiment_id = st.text_input(
|
||
"实验 ID",
|
||
value=default_experiment_id,
|
||
help="用于在 tuning_results 表中区分不同实验。",
|
||
)
|
||
strategy_label = st.text_input(
|
||
"策略说明",
|
||
value="DecisionEnv",
|
||
help="可选:为本次调参记录一个策略名称或备注。",
|
||
)
|
||
|
||
agent_objects = default_agents()
|
||
agent_names = [agent.name for agent in agent_objects]
|
||
if not agent_names:
|
||
st.info("暂无可调整的代理。")
|
||
else:
|
||
selected_agents = st.multiselect(
|
||
"选择调参的代理权重",
|
||
agent_names,
|
||
default=agent_names[:2],
|
||
key="decision_env_agents",
|
||
)
|
||
|
||
specs: List[ParameterSpec] = []
|
||
action_values: List[float] = []
|
||
range_valid = True
|
||
for idx, agent_name in enumerate(selected_agents):
|
||
col_min, col_max, col_action = st.columns([1, 1, 2])
|
||
min_key = f"decision_env_min_{agent_name}"
|
||
max_key = f"decision_env_max_{agent_name}"
|
||
action_key = f"decision_env_action_{agent_name}"
|
||
default_min = 0.0
|
||
default_max = 1.0
|
||
min_val = col_min.number_input(
|
||
f"{agent_name} 最小权重",
|
||
min_value=0.0,
|
||
max_value=1.0,
|
||
value=default_min,
|
||
step=0.05,
|
||
key=min_key,
|
||
)
|
||
max_val = col_max.number_input(
|
||
f"{agent_name} 最大权重",
|
||
min_value=0.0,
|
||
max_value=1.0,
|
||
value=default_max,
|
||
step=0.05,
|
||
key=max_key,
|
||
)
|
||
if max_val <= min_val:
|
||
range_valid = False
|
||
action_val = col_action.slider(
|
||
f"{agent_name} 动作 (0-1)",
|
||
min_value=0.0,
|
||
max_value=1.0,
|
||
value=0.5,
|
||
step=0.01,
|
||
key=action_key,
|
||
)
|
||
specs.append(
|
||
ParameterSpec(
|
||
name=f"weight_{agent_name}",
|
||
target=f"agent_weights.{agent_name}",
|
||
minimum=min_val,
|
||
maximum=max_val,
|
||
)
|
||
)
|
||
action_values.append(action_val)
|
||
|
||
run_decision_env = st.button("执行单次调参", key="run_decision_env_button")
|
||
just_finished_single = False
|
||
if run_decision_env:
|
||
if not selected_agents:
|
||
st.warning("请至少选择一个代理进行调参。")
|
||
elif not range_valid:
|
||
st.error("请确保所有代理的最大权重大于最小权重。")
|
||
else:
|
||
LOGGER.info(
|
||
"离线调参(单次)按钮点击,已选择代理=%s 动作=%s disable_departments=%s",
|
||
selected_agents,
|
||
action_values,
|
||
disable_departments,
|
||
extra=LOG_EXTRA,
|
||
)
|
||
baseline_weights = cfg.agent_weights.as_dict()
|
||
for agent in agent_objects:
|
||
baseline_weights.setdefault(agent.name, 1.0)
|
||
|
||
universe_env = [code.strip() for code in universe_text.split(',') if code.strip()]
|
||
if not universe_env:
|
||
st.error("请先指定至少一个股票代码。")
|
||
else:
|
||
bt_cfg_env = BtConfig(
|
||
id="decision_env_streamlit",
|
||
name="DecisionEnv Streamlit",
|
||
start_date=start_date,
|
||
end_date=end_date,
|
||
universe=universe_env,
|
||
params={
|
||
"target": target,
|
||
"stop": stop,
|
||
"hold_days": int(hold_days),
|
||
},
|
||
method=cfg.decision_method,
|
||
)
|
||
env = DecisionEnv(
|
||
bt_config=bt_cfg_env,
|
||
parameter_specs=specs,
|
||
baseline_weights=baseline_weights,
|
||
disable_departments=disable_departments,
|
||
)
|
||
env.reset()
|
||
LOGGER.debug(
|
||
"离线调参(单次)启动 DecisionEnv:cfg=%s 参数维度=%s",
|
||
bt_cfg_env,
|
||
len(specs),
|
||
extra=LOG_EXTRA,
|
||
)
|
||
with st.spinner("正在执行离线调参……"):
|
||
try:
|
||
observation, reward, done, info = env.step(action_values)
|
||
LOGGER.info(
|
||
"离线调参(单次)完成,obs=%s reward=%.4f done=%s",
|
||
observation,
|
||
reward,
|
||
done,
|
||
extra=LOG_EXTRA,
|
||
)
|
||
except Exception as exc: # noqa: BLE001
|
||
LOGGER.exception("DecisionEnv 调用失败", extra=LOG_EXTRA)
|
||
st.error(f"离线调参失败:{exc}")
|
||
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
|
||
else:
|
||
if observation.get("failure"):
|
||
st.error("调参失败:回测执行未完成,可能是 LLM 网络不可用或参数异常。")
|
||
st.json(observation)
|
||
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
|
||
else:
|
||
resolved_experiment_id = experiment_id or str(uuid.uuid4())
|
||
resolved_strategy = strategy_label or "DecisionEnv"
|
||
action_payload = {
|
||
name: value
|
||
for name, value in zip(selected_agents, action_values)
|
||
}
|
||
metrics_payload = dict(observation)
|
||
metrics_payload["reward"] = reward
|
||
log_success = False
|
||
try:
|
||
log_tuning_result(
|
||
experiment_id=resolved_experiment_id,
|
||
strategy=resolved_strategy,
|
||
action=action_payload,
|
||
reward=reward,
|
||
metrics=metrics_payload,
|
||
weights=info.get("weights", {}),
|
||
)
|
||
except Exception: # noqa: BLE001
|
||
LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA)
|
||
else:
|
||
log_success = True
|
||
LOGGER.info(
|
||
"离线调参(单次)日志写入成功:experiment=%s strategy=%s",
|
||
resolved_experiment_id,
|
||
resolved_strategy,
|
||
extra=LOG_EXTRA,
|
||
)
|
||
st.session_state[_DECISION_ENV_SINGLE_RESULT_KEY] = {
|
||
"observation": dict(observation),
|
||
"reward": float(reward),
|
||
"weights": info.get("weights", {}),
|
||
"nav_series": info.get("nav_series"),
|
||
"trades": info.get("trades"),
|
||
"portfolio_snapshots": info.get("portfolio_snapshots"),
|
||
"portfolio_trades": info.get("portfolio_trades"),
|
||
"risk_breakdown": info.get("risk_breakdown"),
|
||
"selected_agents": list(selected_agents),
|
||
"action_values": list(action_values),
|
||
"experiment_id": resolved_experiment_id,
|
||
"strategy_label": resolved_strategy,
|
||
"logged": log_success,
|
||
}
|
||
just_finished_single = True
|
||
single_result = st.session_state.get(_DECISION_ENV_SINGLE_RESULT_KEY)
|
||
if single_result:
|
||
if just_finished_single:
|
||
st.success("离线调参完成")
|
||
else:
|
||
st.success("离线调参结果(最近一次运行)")
|
||
st.caption(
|
||
f"实验 ID:{single_result.get('experiment_id', '-') } | 策略:{single_result.get('strategy_label', 'DecisionEnv')}"
|
||
)
|
||
observation = single_result.get("observation", {})
|
||
reward = float(single_result.get("reward", 0.0))
|
||
col_metrics = st.columns(4)
|
||
col_metrics[0].metric("总收益", f"{observation.get('total_return', 0.0):+.2%}")
|
||
col_metrics[1].metric("最大回撤", f"{observation.get('max_drawdown', 0.0):+.2%}")
|
||
col_metrics[2].metric("波动率", f"{observation.get('volatility', 0.0):+.2%}")
|
||
col_metrics[3].metric("奖励", f"{reward:+.4f}")
|
||
|
||
turnover_ratio = float(observation.get("turnover", 0.0) or 0.0)
|
||
turnover_value = float(observation.get("turnover_value", 0.0) or 0.0)
|
||
risk_count = float(observation.get("risk_count", 0.0) or 0.0)
|
||
col_metrics_extra = st.columns(3)
|
||
col_metrics_extra[0].metric("平均换手率", f"{turnover_ratio:.2%}")
|
||
col_metrics_extra[1].metric("成交额", f"{turnover_value:,.0f}")
|
||
col_metrics_extra[2].metric("风险事件数", f"{int(risk_count)}")
|
||
|
||
weights_dict = single_result.get("weights") or {}
|
||
if weights_dict:
|
||
st.write("调参后权重:")
|
||
st.json(weights_dict)
|
||
if st.button("保存这些权重为默认配置", key="save_decision_env_weights_single"):
|
||
try:
|
||
cfg.agent_weights.update_from_dict(weights_dict)
|
||
save_config(cfg)
|
||
except Exception as exc: # noqa: BLE001
|
||
LOGGER.exception("保存权重失败", extra={**LOG_EXTRA, "error": str(exc)})
|
||
st.error(f"写入配置失败:{exc}")
|
||
else:
|
||
st.success("代理权重已写入 config.json")
|
||
|
||
if single_result.get("logged"):
|
||
st.caption("调参结果已写入 tuning_results 表。")
|
||
|
||
nav_series = single_result.get("nav_series") or []
|
||
if nav_series:
|
||
try:
|
||
nav_df = pd.DataFrame(nav_series)
|
||
if {"trade_date", "nav"}.issubset(nav_df.columns):
|
||
nav_df = nav_df.sort_values("trade_date")
|
||
nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"])
|
||
st.line_chart(nav_df.set_index("trade_date")["nav"], height=220)
|
||
except Exception: # noqa: BLE001
|
||
LOGGER.debug("导航曲线绘制失败", extra=LOG_EXTRA)
|
||
|
||
trades = single_result.get("trades") or []
|
||
if trades:
|
||
st.write("成交记录:")
|
||
st.dataframe(pd.DataFrame(trades), hide_index=True, width='stretch')
|
||
|
||
snapshots = single_result.get("portfolio_snapshots") or []
|
||
if snapshots:
|
||
with st.expander("投资组合快照", expanded=False):
|
||
st.dataframe(pd.DataFrame(snapshots), hide_index=True, width='stretch')
|
||
|
||
portfolio_trades = single_result.get("portfolio_trades") or []
|
||
if portfolio_trades:
|
||
with st.expander("组合成交明细", expanded=False):
|
||
st.dataframe(pd.DataFrame(portfolio_trades), hide_index=True, width='stretch')
|
||
|
||
risk_breakdown = single_result.get("risk_breakdown") or {}
|
||
if risk_breakdown:
|
||
with st.expander("风险事件统计", expanded=False):
|
||
st.json(risk_breakdown)
|
||
|
||
if st.button("清除单次调参结果", key="clear_decision_env_single"):
|
||
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
|
||
st.success("已清除单次调参结果缓存。")
|
||
|
||
st.divider()
|
||
st.caption("批量调参:在下方输入多组动作,每行表示一组 0-1 之间的值,用逗号分隔。")
|
||
default_grid = "\n".join(
|
||
[
|
||
",".join(["0.2" for _ in specs]),
|
||
",".join(["0.5" for _ in specs]),
|
||
",".join(["0.8" for _ in specs]),
|
||
]
|
||
) if specs else ""
|
||
action_grid_raw = st.text_area(
|
||
"动作列表",
|
||
value=default_grid,
|
||
height=120,
|
||
key="decision_env_batch_actions",
|
||
)
|
||
run_batch = st.button("批量执行调参", key="run_decision_env_batch")
|
||
batch_just_ran = False
|
||
if run_batch:
|
||
if not selected_agents:
|
||
st.warning("请先选择调参代理。")
|
||
elif not range_valid:
|
||
st.error("请确保所有代理的最大权重大于最小权重。")
|
||
else:
|
||
LOGGER.info(
|
||
"离线调参(批量)按钮点击,已选择代理=%s disable_departments=%s",
|
||
selected_agents,
|
||
disable_departments,
|
||
extra=LOG_EXTRA,
|
||
)
|
||
lines = [line.strip() for line in action_grid_raw.splitlines() if line.strip()]
|
||
if not lines:
|
||
st.warning("请在文本框中输入至少一组动作。")
|
||
else:
|
||
LOGGER.debug(
|
||
"离线调参(批量)原始输入=%s",
|
||
lines,
|
||
extra=LOG_EXTRA,
|
||
)
|
||
parsed_actions: List[List[float]] = []
|
||
for line in lines:
|
||
try:
|
||
values = [float(val.strip()) for val in line.split(',') if val.strip()]
|
||
except ValueError:
|
||
st.error(f"无法解析动作行:{line}")
|
||
parsed_actions = []
|
||
break
|
||
if len(values) != len(specs):
|
||
st.error(f"动作维度不匹配(期望 {len(specs)} 个值):{line}")
|
||
parsed_actions = []
|
||
break
|
||
parsed_actions.append(values)
|
||
if parsed_actions:
|
||
LOGGER.info(
|
||
"离线调参(批量)解析动作成功,数量=%s",
|
||
len(parsed_actions),
|
||
extra=LOG_EXTRA,
|
||
)
|
||
baseline_weights = cfg.agent_weights.as_dict()
|
||
for agent in agent_objects:
|
||
baseline_weights.setdefault(agent.name, 1.0)
|
||
|
||
universe_env = [code.strip() for code in universe_text.split(',') if code.strip()]
|
||
if not universe_env:
|
||
st.error("请先指定至少一个股票代码。")
|
||
else:
|
||
bt_cfg_env = BtConfig(
|
||
id="decision_env_streamlit_batch",
|
||
name="DecisionEnv Batch",
|
||
start_date=start_date,
|
||
end_date=end_date,
|
||
universe=universe_env,
|
||
params={
|
||
"target": target,
|
||
"stop": stop,
|
||
"hold_days": int(hold_days),
|
||
},
|
||
method=cfg.decision_method,
|
||
)
|
||
env = DecisionEnv(
|
||
bt_config=bt_cfg_env,
|
||
parameter_specs=specs,
|
||
baseline_weights=baseline_weights,
|
||
disable_departments=disable_departments,
|
||
)
|
||
results: List[Dict[str, object]] = []
|
||
resolved_experiment_id = experiment_id or str(uuid.uuid4())
|
||
resolved_strategy = strategy_label or "DecisionEnv"
|
||
LOGGER.debug(
|
||
"离线调参(批量)启动 DecisionEnv:cfg=%s 动作组=%s",
|
||
bt_cfg_env,
|
||
len(parsed_actions),
|
||
extra=LOG_EXTRA,
|
||
)
|
||
with st.spinner("正在批量执行调参……"):
|
||
for idx, action_vals in enumerate(parsed_actions, start=1):
|
||
env.reset()
|
||
try:
|
||
observation, reward, done, info = env.step(action_vals)
|
||
except Exception as exc: # noqa: BLE001
|
||
LOGGER.exception("批量调参失败", extra=LOG_EXTRA)
|
||
results.append(
|
||
{
|
||
"序号": idx,
|
||
"动作": action_vals,
|
||
"状态": "error",
|
||
"错误": str(exc),
|
||
}
|
||
)
|
||
continue
|
||
if observation.get("failure"):
|
||
results.append(
|
||
{
|
||
"序号": idx,
|
||
"动作": action_vals,
|
||
"状态": "failure",
|
||
"奖励": -1.0,
|
||
}
|
||
)
|
||
else:
|
||
LOGGER.info(
|
||
"离线调参(批量)第 %s 组完成,reward=%.4f obs=%s",
|
||
idx,
|
||
reward,
|
||
observation,
|
||
extra=LOG_EXTRA,
|
||
)
|
||
action_payload = {
|
||
name: value
|
||
for name, value in zip(selected_agents, action_vals)
|
||
}
|
||
metrics_payload = dict(observation)
|
||
metrics_payload["reward"] = reward
|
||
weights_payload = info.get("weights", {})
|
||
try:
|
||
log_tuning_result(
|
||
experiment_id=resolved_experiment_id,
|
||
strategy=resolved_strategy,
|
||
action=action_payload,
|
||
reward=reward,
|
||
metrics=metrics_payload,
|
||
weights=weights_payload,
|
||
)
|
||
except Exception: # noqa: BLE001
|
||
LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA)
|
||
results.append(
|
||
{
|
||
"序号": idx,
|
||
"动作": action_vals,
|
||
"状态": "ok",
|
||
"总收益": observation.get("total_return", 0.0),
|
||
"最大回撤": observation.get("max_drawdown", 0.0),
|
||
"波动率": observation.get("volatility", 0.0),
|
||
"奖励": reward,
|
||
"权重": weights_payload,
|
||
}
|
||
)
|
||
st.session_state[_DECISION_ENV_BATCH_RESULTS_KEY] = {
|
||
"results": results,
|
||
"selectable": [
|
||
row
|
||
for row in results
|
||
if row.get("状态") == "ok" and row.get("权重")
|
||
],
|
||
"experiment_id": resolved_experiment_id,
|
||
"strategy_label": resolved_strategy,
|
||
}
|
||
batch_just_ran = True
|
||
LOGGER.info(
|
||
"离线调参(批量)执行结束,总结果条数=%s",
|
||
len(results),
|
||
extra=LOG_EXTRA,
|
||
)
|
||
batch_state = st.session_state.get(_DECISION_ENV_BATCH_RESULTS_KEY)
|
||
if batch_state:
|
||
results = batch_state.get("results") or []
|
||
if results:
|
||
if batch_just_ran:
|
||
st.success("批量调参完成")
|
||
else:
|
||
st.success("批量调参结果(最近一次运行)")
|
||
st.caption(
|
||
f"实验 ID:{batch_state.get('experiment_id', '-') } | 策略:{batch_state.get('strategy_label', 'DecisionEnv')}"
|
||
)
|
||
results_df = pd.DataFrame(results)
|
||
st.write("批量调参结果:")
|
||
st.dataframe(results_df, hide_index=True, width='stretch')
|
||
selectable = batch_state.get("selectable") or []
|
||
if selectable:
|
||
option_labels = [
|
||
f"序号 {row['序号']} | 奖励 {row.get('奖励', 0.0):+.4f}"
|
||
for row in selectable
|
||
]
|
||
selected_label = st.selectbox(
|
||
"选择要保存的记录",
|
||
option_labels,
|
||
key="decision_env_batch_select",
|
||
)
|
||
selected_row = None
|
||
for label, row in zip(option_labels, selectable):
|
||
if label == selected_label:
|
||
selected_row = row
|
||
break
|
||
if selected_row and st.button(
|
||
"保存所选权重为默认配置",
|
||
key="save_decision_env_weights_batch",
|
||
):
|
||
try:
|
||
cfg.agent_weights.update_from_dict(selected_row.get("权重", {}))
|
||
save_config(cfg)
|
||
except Exception as exc: # noqa: BLE001
|
||
LOGGER.exception("批量保存权重失败", extra={**LOG_EXTRA, "error": str(exc)})
|
||
st.error(f"写入配置失败:{exc}")
|
||
else:
|
||
st.success(
|
||
f"已将序号 {selected_row['序号']} 的权重写入 config.json"
|
||
)
|
||
else:
|
||
st.caption("暂无成功的结果可供保存。")
|
||
else:
|
||
st.caption("批量调参在最近一次执行中未产生结果。")
|
||
if st.button("清除批量调参结果", key="clear_decision_env_batch"):
|
||
st.session_state.pop(_DECISION_ENV_BATCH_RESULTS_KEY, None)
|
||
st.session_state.pop("decision_env_batch_select", None)
|
||
st.success("已清除批量调参结果缓存。")
|