llm-quant/app/ui/views/backtest.py
2025-10-06 13:39:07 +08:00

1066 lines
56 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""回测与复盘相关视图。"""
from __future__ import annotations
import json
import uuid
from dataclasses import asdict
from datetime import date, datetime
from typing import Dict, List, Optional
import numpy as np
import pandas as pd
import plotly.express as px
import requests
from requests.exceptions import RequestException
import streamlit as st
from app.agents.base import AgentContext
from app.agents.game import Decision
from app.agents.registry import default_agents
from app.backtest.decision_env import DecisionEnv, ParameterSpec
from app.backtest.optimizer import BanditConfig, EpsilonGreedyBandit
from app.backtest.engine import BacktestEngine, PortfolioState, BtConfig, run_backtest
from app.ingest.checker import run_boot_check
from app.ingest.tushare import run_ingestion
from app.llm.client import run_llm
from app.llm.metrics import reset as reset_llm_metrics
from app.llm.metrics import snapshot as snapshot_llm_metrics
from app.utils import alerts
from app.utils.config import get_config, save_config
from app.utils.tuning import log_tuning_result
from app.utils.db import db_session
from app.ui.shared import LOGGER, LOG_EXTRA, default_backtest_range
from app.ui.views.dashboard import update_dashboard_sidebar
# st.session_state keys that cache the three kinds of tuning output
# (single run, batch grid, epsilon-greedy bandit) across Streamlit reruns,
# so results survive widget interactions until explicitly cleared.
_DECISION_ENV_SINGLE_RESULT_KEY = "decision_env_single_result"
_DECISION_ENV_BATCH_RESULTS_KEY = "decision_env_batch_results"
_DECISION_ENV_BANDIT_RESULTS_KEY = "decision_env_bandit_results"
# NOTE(review): leading indentation was stripped when this file was captured;
# the lines below follow the original logical nesting even though they appear flat.
def render_backtest_review() -> None:
"""Render the backtest execution, parameter-tuning and review page."""
st.header("回测与复盘")
st.caption("1. 基于历史数据复盘当前策略2. 借助强化学习/调参探索更优参数组合。")
# Loaded app configuration; used later for department settings (app_cfg.departments).
app_cfg = get_config()
# Default backtest window: the most recent 60 calendar days.
default_start, default_end = default_backtest_range(window_days=60)
LOGGER.debug(
"回测默认参数start=%s end=%s universe=%s target=%s stop=%s hold_days=%s",
default_start,
default_end,
"000001.SZ",
0.035,
-0.015,
10,
extra=LOG_EXTRA,
)
# --- Shared backtest parameter form (used by both tabs below) ---
st.markdown("### 回测参数")
col1, col2 = st.columns(2)
start_date = col1.date_input("开始日期", value=default_start, key="bt_start_date")
end_date = col2.date_input("结束日期", value=default_end, key="bt_end_date")
universe_text = st.text_input("股票列表(逗号分隔)", value="000001.SZ", key="bt_universe")
col_target, col_stop, col_hold = st.columns(3)
target = col_target.number_input("目标收益0.035 表示 3.5%", value=0.035, step=0.005, format="%.3f", key="bt_target")
stop = col_stop.number_input("止损收益(例:-0.015 表示 -1.5%", value=-0.015, step=0.005, format="%.3f", key="bt_stop")
hold_days = col_hold.number_input("持有期(交易日)", value=10, step=1, key="bt_hold_days")
LOGGER.debug(
"当前回测表单输入start=%s end=%s universe_text=%s target=%.3f stop=%.3f hold_days=%s",
start_date,
end_date,
universe_text,
target,
stop,
hold_days,
extra=LOG_EXTRA,
)
tab_backtest, tab_rl = st.tabs(["回测验证", "强化学习调参"])
with tab_backtest:
st.markdown("#### 回测执行")
if st.button("运行回测", key="bt_run_button"):
LOGGER.info("用户点击运行回测按钮", extra=LOG_EXTRA)
# Live UI surfaces: rolling decision log, status panel, LLM usage panel.
decision_log_container = st.container()
status_box = st.status("准备执行回测...", expanded=True)
llm_stats_placeholder = st.empty()
# Accumulated markdown entries, one per agent decision.
decision_entries: List[str] = []
# Per-decision callback: invoked by the backtest engine for every
# (stock, trade day) decision; streams the decision into the log
# container and refreshes the live LLM usage counters.
def _decision_callback(ts_code: str, trade_dt: date, ctx: AgentContext, decision: Decision) -> None:
ts_label = trade_dt.isoformat()
# Use the first department decision that carries a summary.
summary = ""
for dept_decision in decision.department_decisions.values():
if getattr(dept_decision, "summary", ""):
summary = str(dept_decision.summary)
break
entry_lines = [
f"**{ts_label} {ts_code}** → {decision.action.value} (信心 {decision.confidence:.2f})",
]
if summary:
entry_lines.append(f"摘要:{summary}")
dep_highlights = []
for dept_code, dept_decision in decision.department_decisions.items():
dep_highlights.append(
f"{dept_code}:{dept_decision.action.value}({dept_decision.confidence:.2f})"
)
if dep_highlights:
# NOTE(review): joined with an empty string, so department opinions run
# together; the original separator may have been lost in this capture — confirm.
entry_lines.append("部门意见:" + "".join(dep_highlights))
decision_entries.append(" \n".join(entry_lines))
# Only render the most recent 200 entries to keep the page responsive.
decision_log_container.markdown("\n\n".join(decision_entries[-200:]))
status_box.write(f"{ts_label} {ts_code}{decision.action.value} (信心 {decision.confidence:.2f})")
# NOTE(review): given the stripped indentation, this stats refresh appears
# to run inside the callback (live per-decision update) — confirm nesting.
stats = snapshot_llm_metrics()
llm_stats_placeholder.json(
{
"LLM 调用次数": stats.get("total_calls", 0),
"Prompt Tokens": stats.get("total_prompt_tokens", 0),
"Completion Tokens": stats.get("total_completion_tokens", 0),
"按 Provider": stats.get("provider_calls", {}),
"按模型": stats.get("model_calls", {}),
}
)
update_dashboard_sidebar(stats)
# Zero the LLM counters so the run reports only its own usage.
reset_llm_metrics()
status_box.update(label="执行回测中...", state="running")
try:
# Parse the comma-separated universe, dropping empty fragments.
universe = [code.strip() for code in universe_text.split(',') if code.strip()]
LOGGER.info(
"回测参数start=%s end=%s universe=%s target=%s stop=%s hold_days=%s",
start_date,
end_date,
universe,
target,
stop,
hold_days,
extra=LOG_EXTRA,
)
backtest_cfg = BtConfig(
id="streamlit_demo",
name="Streamlit Demo Strategy",
start_date=start_date,
end_date=end_date,
universe=universe,
params={
"target": target,
"stop": stop,
"hold_days": int(hold_days),
},
)
# Run the backtest; the callback streams per-decision output to the UI.
result = run_backtest(backtest_cfg, decision_callback=_decision_callback)
LOGGER.info(
"回测完成nav_records=%s trades=%s",
len(result.nav_series),
len(result.trades),
extra=LOG_EXTRA,
)
status_box.update(label="回测执行完成", state="complete")
st.success("回测执行完成,详见下方结果与统计。")
# Final LLM usage snapshot for the whole run.
metrics = snapshot_llm_metrics()
llm_stats_placeholder.json(
{
"LLM 调用次数": metrics.get("total_calls", 0),
"Prompt Tokens": metrics.get("total_prompt_tokens", 0),
"Completion Tokens": metrics.get("total_completion_tokens", 0),
"按 Provider": metrics.get("provider_calls", {}),
"按模型": metrics.get("model_calls", {}),
}
)
update_dashboard_sidebar(metrics)
# Cache the result so it remains visible after Streamlit reruns.
st.session_state["backtest_last_result"] = {"nav_records": result.nav_series, "trades": result.trades}
st.json(st.session_state["backtest_last_result"])
except Exception as exc: # noqa: BLE001
LOGGER.exception("回测执行失败", extra=LOG_EXTRA)
status_box.update(label="回测执行失败", state="error")
st.error(f"回测执行失败:{exc}")
# Re-display the most recent cached result (if any) on every rerun.
last_result = st.session_state.get("backtest_last_result")
if last_result:
st.markdown("#### 最近回测输出")
st.json(last_result)
st.divider()
# ADD: Comparison view for multiple backtest configurations
st.caption("从历史回测配置中选择多个进行净值曲线与指标对比。")
normalize_to_one = st.checkbox("归一化到 1 起点", value=True, key="bt_cmp_normalize")
use_log_y = st.checkbox("对数坐标", value=False, key="bt_cmp_log_y")
metric_options = ["总收益", "最大回撤", "交易数", "平均换手", "风险事件"]
selected_metrics = st.multiselect("显示指标列", metric_options, default=metric_options, key="bt_cmp_metrics")
# Load the 50 most recent saved configs; fall back to empty on DB errors.
try:
with db_session(read_only=True) as conn:
cfg_rows = conn.execute(
"SELECT id, name FROM bt_config ORDER BY rowid DESC LIMIT 50"
).fetchall()
except Exception: # noqa: BLE001
LOGGER.exception("读取 bt_config 失败", extra=LOG_EXTRA)
cfg_rows = []
cfg_options = [f"{row['id']} | {row['name']}" for row in cfg_rows]
selected_labels = st.multiselect("选择配置", cfg_options, default=cfg_options[:2], key="bt_cmp_configs")
# Recover the config id from the "id | name" label.
selected_ids = [label.split(" | ")[0].strip() for label in selected_labels]
nav_df = pd.DataFrame()
rpt_df = pd.DataFrame()
if selected_ids:
try:
with db_session(read_only=True) as conn:
# The %-format only expands "?" placeholders; ids are bound as
# parameters, so this stays a parameterized (injection-safe) query.
nav_df = pd.read_sql_query(
"SELECT cfg_id, trade_date, nav FROM bt_nav WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))),
conn,
params=tuple(selected_ids),
)
rpt_df = pd.read_sql_query(
"SELECT cfg_id, summary FROM bt_report WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))),
conn,
params=tuple(selected_ids),
)
except Exception: # noqa: BLE001
LOGGER.exception("读取回测结果失败", extra=LOG_EXTRA)
st.error("读取回测结果失败")
nav_df = pd.DataFrame()
rpt_df = pd.DataFrame()
if not nav_df.empty:
try:
nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"], errors="coerce")
# ADD: date window filter
overall_min = pd.to_datetime(nav_df["trade_date"].min()).date()
overall_max = pd.to_datetime(nav_df["trade_date"].max()).date()
col_d1, col_d2 = st.columns(2)
start_filter = col_d1.date_input("起始日期", value=overall_min, key="bt_cmp_start")
end_filter = col_d2.date_input("结束日期", value=overall_max, key="bt_cmp_end")
# Silently swap an inverted range instead of erroring.
if start_filter > end_filter:
start_filter, end_filter = end_filter, start_filter
mask = (nav_df["trade_date"].dt.date >= start_filter) & (nav_df["trade_date"].dt.date <= end_filter)
nav_df = nav_df.loc[mask]
# One NAV column per config id, indexed by trade date.
pivot = nav_df.pivot_table(index="trade_date", columns="cfg_id", values="nav")
if normalize_to_one:
# Rebase each curve to 1.0 at its first non-NaN point.
pivot = pivot.apply(lambda s: s / s.dropna().iloc[0] if s.dropna().size else s)
# Local import keeps plotly.graph_objects out of module import time.
import plotly.graph_objects as go
fig = go.Figure()
for col in pivot.columns:
fig.add_trace(go.Scatter(x=pivot.index, y=pivot[col], mode="lines", name=str(col)))
fig.update_layout(height=300, margin=dict(l=10, r=10, t=30, b=10))
if use_log_y:
fig.update_yaxes(type="log")
st.plotly_chart(fig, width='stretch')
# ADD: export pivot
# Best-effort CSV export; failures are deliberately swallowed.
try:
csv_buf = pivot.reset_index()
csv_buf.columns = ["trade_date"] + [str(c) for c in pivot.columns]
st.download_button(
"下载曲线(CSV)",
data=csv_buf.to_csv(index=False),
file_name="bt_nav_compare.csv",
mime="text/csv",
key="dl_nav_compare",
)
except Exception:
pass
except Exception: # noqa: BLE001
LOGGER.debug("绘制对比曲线失败", extra=LOG_EXTRA)
if not rpt_df.empty:
try:
metrics_rows: List[Dict[str, object]] = []
for _, row in rpt_df.iterrows():
cfg_id = row["cfg_id"]
# summary is stored as a JSON string; tolerate malformed payloads.
try:
summary = json.loads(row["summary"]) if isinstance(row["summary"], str) else (row["summary"] or {})
except json.JSONDecodeError:
summary = {}
record = {
"cfg_id": cfg_id,
"总收益": summary.get("total_return"),
"最大回撤": summary.get("max_drawdown"),
"交易数": summary.get("trade_count"),
"平均换手": summary.get("avg_turnover"),
"风险事件": summary.get("risk_events"),
}
# Keep cfg_id plus only the metric columns the user selected.
metrics_rows.append({k: v for k, v in record.items() if (k == "cfg_id" or k in selected_metrics)})
if metrics_rows:
dfm = pd.DataFrame(metrics_rows)
st.dataframe(dfm, hide_index=True, width='stretch')
try:
st.download_button(
"下载指标(CSV)",
data=dfm.to_csv(index=False),
file_name="bt_metrics_compare.csv",
mime="text/csv",
key="dl_metrics_compare",
)
except Exception:
pass
except Exception: # noqa: BLE001
LOGGER.debug("渲染指标表失败", extra=LOG_EXTRA)
else:
st.info("请选择至少一个配置进行对比。")
with tab_rl:
st.caption("使用 DecisionEnv 对代理权重进行强化学习调参,支持单次与批量实验。")
disable_departments = st.checkbox(
"禁用部门 LLM仅规则代理适合离线快速评估",
value=True,
help="关闭部门调用后不依赖外部 LLM 网络,仅根据规则代理权重模拟。",
)
default_experiment_id = f"streamlit_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
experiment_id = st.text_input(
"实验 ID",
value=default_experiment_id,
help="用于在 tuning_results 表中区分不同实验。",
)
strategy_label = st.text_input(
"策略说明",
value="DecisionEnv",
help="可选:为本次调参记录一个策略名称或备注。",
)
agent_objects = default_agents()
agent_names = [agent.name for agent in agent_objects]
if not agent_names:
st.info("暂无可调整的代理。")
else:
selected_agents = st.multiselect(
"选择调参的代理权重",
agent_names,
default=agent_names[:2],
key="decision_env_agents",
)
# These three lists are kept in lockstep: one ParameterSpec, one
# human-readable label, and one 0-1 action value per tunable dimension.
specs: List[ParameterSpec] = []
spec_labels: List[str] = []
action_values: List[float] = []
# Set to False when any agent's max weight <= min weight.
range_valid = True
for idx, agent_name in enumerate(selected_agents):
col_min, col_max, col_action = st.columns([1, 1, 2])
min_key = f"decision_env_min_{agent_name}"
max_key = f"decision_env_max_{agent_name}"
action_key = f"decision_env_action_{agent_name}"
default_min = 0.0
default_max = 1.0
min_val = col_min.number_input(
f"{agent_name} 最小权重",
min_value=0.0,
max_value=1.0,
value=default_min,
step=0.05,
key=min_key,
)
max_val = col_max.number_input(
f"{agent_name} 最大权重",
min_value=0.0,
max_value=1.0,
value=default_max,
step=0.05,
key=max_key,
)
if max_val <= min_val:
range_valid = False
action_val = col_action.slider(
f"{agent_name} 动作 (0-1)",
min_value=0.0,
max_value=1.0,
value=0.5,
step=0.01,
key=action_key,
)
specs.append(
ParameterSpec(
name=f"weight_{agent_name}",
target=f"agent_weights.{agent_name}",
minimum=min_val,
maximum=max_val,
)
)
spec_labels.append(f"agent:{agent_name}")
action_values.append(action_val)
# Set to False when any department temperature range is invalid.
controls_valid = True
with st.expander("部门 LLM 参数", expanded=False):
dept_codes = sorted(app_cfg.departments.keys())
if not dept_codes:
st.caption("当前未配置部门。")
else:
selected_departments = st.multiselect(
"选择需要调整的部门",
dept_codes,
default=[],
key="decision_env_departments",
)
tool_policy_values = ["auto", "none", "required"]
for dept_code in selected_departments:
settings = app_cfg.departments.get(dept_code)
if not settings:
continue
st.subheader(f"部门:{settings.title or dept_code}")
# Default temperature: the department's configured primary-LLM
# temperature when available, otherwise 0.2.
base_temp = 0.2
if settings.llm and settings.llm.primary and settings.llm.primary.temperature is not None:
base_temp = float(settings.llm.primary.temperature)
prefix = f"decision_env_dept_{dept_code}"
col_tmin, col_tmax, col_tslider = st.columns([1, 1, 2])
temp_min = col_tmin.number_input(
"温度最小值",
min_value=0.0,
max_value=2.0,
value=max(0.0, base_temp - 0.3),
step=0.05,
key=f"{prefix}_temp_min",
)
temp_max = col_tmax.number_input(
"温度最大值",
min_value=0.0,
max_value=2.0,
value=min(2.0, base_temp + 0.3),
step=0.05,
key=f"{prefix}_temp_max",
)
if temp_max <= temp_min:
# Flag invalid input but clamp to a tiny span so the UI stays usable.
controls_valid = False
st.warning("温度最大值必须大于最小值。")
temp_max = min(2.0, temp_min + 0.01)
# Map base_temp into [temp_min, temp_max] as the slider default ratio.
span = temp_max - temp_min
if span <= 0:
ratio_default = 0.0
else:
clamped = min(max(base_temp, temp_min), temp_max)
ratio_default = (clamped - temp_min) / span
temp_action = col_tslider.slider(
"动作值(映射至温度区间)",
min_value=0.0,
max_value=1.0,
value=float(ratio_default),
step=0.01,
key=f"{prefix}_temp_action",
)
specs.append(
ParameterSpec(
name=f"dept_temperature_{dept_code}",
target=f"department.{dept_code}.temperature",
minimum=temp_min,
maximum=temp_max,
)
)
spec_labels.append(f"department:{dept_code}:temperature")
action_values.append(temp_action)
col_tool, col_hint = st.columns([1, 2])
tool_choice = col_tool.selectbox(
"函数调用策略",
tool_policy_values,
index=tool_policy_values.index("auto"),
key=f"{prefix}_tool_choice",
)
col_hint.caption("映射提示0→auto0.5→none1→required。")
# Encode the categorical tool policy as an evenly spaced 0-1 action.
if len(tool_policy_values) > 1:
tool_value = tool_policy_values.index(tool_choice) / (len(tool_policy_values) - 1)
else:
tool_value = 0.0
specs.append(
ParameterSpec(
name=f"dept_tool_{dept_code}",
target=f"department.{dept_code}.function_policy",
values=tool_policy_values,
)
)
spec_labels.append(f"department:{dept_code}:tool_choice")
action_values.append(tool_value)
if specs:
# NOTE(review): joined with an empty string, so labels run together;
# a separator may have been lost in this capture — confirm.
st.caption("动作维度顺序:" + "".join(spec_labels))
run_decision_env = st.button("执行单次调参", key="run_decision_env_button")
# Distinguishes "just computed" from "restored from session cache" below.
just_finished_single = False
if run_decision_env:
# Validate the configured action space before launching anything.
if not specs:
st.warning("请至少配置一个动作维度(代理或部门参数)。")
elif selected_agents and not range_valid:
st.error("请确保所有代理的最大权重大于最小权重。")
elif not controls_valid:
st.error("请修正部门参数的取值范围。")
else:
LOGGER.info(
"离线调参(单次)按钮点击,已选择代理=%s 动作=%s disable_departments=%s",
selected_agents,
action_values,
disable_departments,
extra=LOG_EXTRA,
)
# NOTE(review): `cfg` is never defined in this function — the config is
# bound as `app_cfg` at the top. As written this raises NameError; it
# likely should be `app_cfg` (also below at method=cfg.decision_method
# and in the save-weights handler). Confirm against the original file.
baseline_weights = cfg.agent_weights.as_dict()
for agent in agent_objects:
baseline_weights.setdefault(agent.name, 1.0)
universe_env = [code.strip() for code in universe_text.split(',') if code.strip()]
if not universe_env:
st.error("请先指定至少一个股票代码。")
else:
bt_cfg_env = BtConfig(
id="decision_env_streamlit",
name="DecisionEnv Streamlit",
start_date=start_date,
end_date=end_date,
universe=universe_env,
params={
"target": target,
"stop": stop,
"hold_days": int(hold_days),
},
method=cfg.decision_method,
)
env = DecisionEnv(
bt_config=bt_cfg_env,
parameter_specs=specs,
baseline_weights=baseline_weights,
disable_departments=disable_departments,
)
env.reset()
LOGGER.debug(
"离线调参(单次)启动 DecisionEnvcfg=%s 参数维度=%s",
bt_cfg_env,
len(specs),
extra=LOG_EXTRA,
)
with st.spinner("正在执行离线调参……"):
try:
# Single environment step = one full backtest with the chosen actions.
observation, reward, done, info = env.step(action_values)
LOGGER.info(
"离线调参单次完成obs=%s reward=%.4f done=%s",
observation,
reward,
done,
extra=LOG_EXTRA,
)
except Exception as exc: # noqa: BLE001
LOGGER.exception("DecisionEnv 调用失败", extra=LOG_EXTRA)
st.error(f"离线调参失败:{exc}")
# Drop any stale cached result on failure.
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
else:
if observation.get("failure"):
st.error("调参失败:回测执行未完成,可能是 LLM 网络不可用或参数异常。")
st.json(observation)
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
else:
resolved_experiment_id = experiment_id or str(uuid.uuid4())
resolved_strategy = strategy_label or "DecisionEnv"
action_payload = {
label: value for label, value in zip(spec_labels, action_values)
}
metrics_payload = dict(observation)
metrics_payload["reward"] = reward
metrics_payload["department_controls"] = info.get("department_controls")
# Persist the run to tuning_results; UI still shows results if logging fails.
log_success = False
try:
log_tuning_result(
experiment_id=resolved_experiment_id,
strategy=resolved_strategy,
action=action_payload,
reward=reward,
metrics=metrics_payload,
weights=info.get("weights", {}),
)
except Exception: # noqa: BLE001
LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA)
else:
log_success = True
LOGGER.info(
"离线调参单次日志写入成功experiment=%s strategy=%s",
resolved_experiment_id,
resolved_strategy,
extra=LOG_EXTRA,
)
# Cache everything needed to re-render the result after a rerun.
st.session_state[_DECISION_ENV_SINGLE_RESULT_KEY] = {
"observation": dict(observation),
"reward": float(reward),
"weights": info.get("weights", {}),
"department_controls": info.get("department_controls"),
"actions": action_payload,
"nav_series": info.get("nav_series"),
"trades": info.get("trades"),
"portfolio_snapshots": info.get("portfolio_snapshots"),
"portfolio_trades": info.get("portfolio_trades"),
"risk_breakdown": info.get("risk_breakdown"),
"spec_labels": list(spec_labels),
"action_values": list(action_values),
"experiment_id": resolved_experiment_id,
"strategy_label": resolved_strategy,
"logged": log_success,
}
just_finished_single = True
# --- Render the cached single-run result (fresh or from a prior rerun) ---
single_result = st.session_state.get(_DECISION_ENV_SINGLE_RESULT_KEY)
if single_result:
if just_finished_single:
st.success("离线调参完成")
else:
st.success("离线调参结果(最近一次运行)")
st.caption(
f"实验 ID{single_result.get('experiment_id', '-') } | 策略:{single_result.get('strategy_label', 'DecisionEnv')}"
)
observation = single_result.get("observation", {})
reward = float(single_result.get("reward", 0.0))
col_metrics = st.columns(4)
col_metrics[0].metric("总收益", f"{observation.get('total_return', 0.0):+.2%}")
col_metrics[1].metric("最大回撤", f"{observation.get('max_drawdown', 0.0):+.2%}")
col_metrics[2].metric("波动率", f"{observation.get('volatility', 0.0):+.2%}")
col_metrics[3].metric("奖励", f"{reward:+.4f}")
# `or 0.0` guards against explicit None values in the observation.
turnover_ratio = float(observation.get("turnover", 0.0) or 0.0)
turnover_value = float(observation.get("turnover_value", 0.0) or 0.0)
risk_count = float(observation.get("risk_count", 0.0) or 0.0)
col_metrics_extra = st.columns(3)
col_metrics_extra[0].metric("平均换手率", f"{turnover_ratio:.2%}")
col_metrics_extra[1].metric("成交额", f"{turnover_value:,.0f}")
col_metrics_extra[2].metric("风险事件数", f"{int(risk_count)}")
weights_dict = single_result.get("weights") or {}
if weights_dict:
st.write("调参后权重:")
st.json(weights_dict)
if st.button("保存这些权重为默认配置", key="save_decision_env_weights_single"):
# NOTE(review): `cfg` undefined here as well — see note above.
try:
cfg.agent_weights.update_from_dict(weights_dict)
save_config(cfg)
except Exception as exc: # noqa: BLE001
LOGGER.exception("保存权重失败", extra={**LOG_EXTRA, "error": str(exc)})
st.error(f"写入配置失败:{exc}")
else:
st.success("代理权重已写入 config.json")
if single_result.get("logged"):
st.caption("调参结果已写入 tuning_results 表。")
nav_series = single_result.get("nav_series") or []
if nav_series:
# Best-effort NAV chart; skip silently on malformed data.
try:
nav_df = pd.DataFrame(nav_series)
if {"trade_date", "nav"}.issubset(nav_df.columns):
nav_df = nav_df.sort_values("trade_date")
nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"])
st.line_chart(nav_df.set_index("trade_date")["nav"], height=220)
except Exception: # noqa: BLE001
LOGGER.debug("导航曲线绘制失败", extra=LOG_EXTRA)
trades = single_result.get("trades") or []
if trades:
st.write("成交记录:")
st.dataframe(pd.DataFrame(trades), hide_index=True, width='stretch')
snapshots = single_result.get("portfolio_snapshots") or []
if snapshots:
with st.expander("投资组合快照", expanded=False):
st.dataframe(pd.DataFrame(snapshots), hide_index=True, width='stretch')
portfolio_trades = single_result.get("portfolio_trades") or []
if portfolio_trades:
with st.expander("组合成交明细", expanded=False):
st.dataframe(pd.DataFrame(portfolio_trades), hide_index=True, width='stretch')
risk_breakdown = single_result.get("risk_breakdown") or {}
if risk_breakdown:
with st.expander("风险事件统计", expanded=False):
st.json(risk_breakdown)
department_info = single_result.get("department_controls") or {}
if department_info:
with st.expander("部门控制参数", expanded=False):
st.json(department_info)
action_snapshot = single_result.get("actions") or {}
if action_snapshot:
with st.expander("动作明细", expanded=False):
st.json(action_snapshot)
if st.button("清除单次调参结果", key="clear_decision_env_single"):
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
st.success("已清除单次调参结果缓存。")
st.divider()
# --- Automated exploration via epsilon-greedy bandit over the action space ---
st.subheader("自动探索epsilon-greedy")
col_ep, col_eps, col_seed = st.columns([1, 1, 1])
bandit_episodes = int(
col_ep.number_input(
"迭代次数",
min_value=1,
max_value=200,
value=10,
step=1,
key="decision_env_bandit_episodes",
help="探索的回合数,越大越充分但耗时越久。",
)
)
bandit_epsilon = float(
col_eps.slider(
"探索比例 ε",
min_value=0.0,
max_value=1.0,
value=0.2,
step=0.05,
key="decision_env_bandit_epsilon",
help="ε 越大,随机探索概率越高。",
)
)
seed_text = col_seed.text_input(
"随机种子(可选)",
value="",
key="decision_env_bandit_seed",
help="填写整数可复现实验,不填写则随机。",
).strip()
# Optional reproducibility seed; non-integer input is warned about and ignored.
bandit_seed = None
if seed_text:
try:
bandit_seed = int(seed_text)
except ValueError:
st.warning("随机种子需为整数,已忽略该值。")
bandit_seed = None
run_bandit = st.button("执行自动探索", key="run_decision_env_bandit")
if run_bandit:
if not specs:
st.warning("请至少配置一个动作维度再执行探索。")
elif selected_agents and not range_valid:
st.error("请确保所有代理的最大权重大于最小权重。")
elif not controls_valid:
st.error("请修正部门参数的取值范围。")
else:
# NOTE(review): `cfg` is never defined in this function (config is bound
# as `app_cfg` at the top) — this branch raises NameError as written.
baseline_weights = cfg.agent_weights.as_dict()
for agent in agent_objects:
baseline_weights.setdefault(agent.name, 1.0)
universe_env = [code.strip() for code in universe_text.split(',') if code.strip()]
if not universe_env:
st.error("请先指定至少一个股票代码。")
else:
bt_cfg_env = BtConfig(
id="decision_env_bandit",
name="DecisionEnv Bandit",
start_date=start_date,
end_date=end_date,
universe=universe_env,
params={
"target": target,
"stop": stop,
"hold_days": int(hold_days),
},
method=cfg.decision_method,
)
env = DecisionEnv(
bt_config=bt_cfg_env,
parameter_specs=specs,
baseline_weights=baseline_weights,
disable_departments=disable_departments,
)
config = BanditConfig(
experiment_id=experiment_id or f"bandit_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
strategy=strategy_label or "DecisionEnv",
episodes=bandit_episodes,
epsilon=bandit_epsilon,
seed=bandit_seed,
)
bandit = EpsilonGreedyBandit(env, config)
with st.spinner("自动探索进行中,请稍候..."):
summary = bandit.run()
# Flatten each episode into a display row (JSON-encode nested payloads).
episodes_dump: List[Dict[str, object]] = []
for idx, episode in enumerate(summary.episodes, start=1):
episodes_dump.append(
{
"序号": idx,
"奖励": episode.reward,
"动作(raw)": json.dumps(episode.action, ensure_ascii=False),
"参数值": json.dumps(episode.resolved_action, ensure_ascii=False),
"总收益": episode.metrics.total_return,
"最大回撤": episode.metrics.max_drawdown,
"波动率": episode.metrics.volatility,
"权重": json.dumps(episode.weights or {}, ensure_ascii=False),
"部门控制": json.dumps(episode.department_controls or {}, ensure_ascii=False),
}
)
best_episode = summary.best_episode
# 1-based index of the best episode for display, or None when absent.
best_index = summary.episodes.index(best_episode) + 1 if best_episode else None
st.session_state[_DECISION_ENV_BANDIT_RESULTS_KEY] = {
"episodes": episodes_dump,
"best_index": best_index,
"best": {
"reward": best_episode.reward if best_episode else None,
"action": best_episode.action if best_episode else None,
"resolved_action": best_episode.resolved_action if best_episode else None,
"weights": best_episode.weights if best_episode else None,
"department_controls": best_episode.department_controls if best_episode else None,
},
"experiment_id": config.experiment_id,
"strategy": config.strategy,
}
st.success(f"自动探索完成,共执行 {len(episodes_dump)} 轮。")
# --- Render cached bandit results (survive reruns until cleared) ---
bandit_state = st.session_state.get(_DECISION_ENV_BANDIT_RESULTS_KEY)
if bandit_state:
st.caption(
f"实验 ID{bandit_state.get('experiment_id')} | 策略:{bandit_state.get('strategy')}"
)
episodes_dump = bandit_state.get("episodes") or []
if episodes_dump:
st.dataframe(pd.DataFrame(episodes_dump), hide_index=True, width='stretch')
best_payload = bandit_state.get("best") or {}
if best_payload.get("reward") is not None:
st.success(
f"最佳结果:第 {bandit_state.get('best_index')} 轮,奖励 {best_payload['reward']:+.4f}"
)
col_best1, col_best2 = st.columns(2)
col_best1.write("动作(raw)")
col_best1.json(best_payload.get("action") or {})
col_best2.write("参数值:")
col_best2.json(best_payload.get("resolved_action") or {})
weights_payload = best_payload.get("weights") or {}
if weights_payload:
st.write("对应代理权重:")
st.json(weights_payload)
if st.button(
"将最佳权重写入默认配置",
key="save_decision_env_bandit_weights",
):
# NOTE(review): `cfg` undefined here as well — see note above.
try:
cfg.agent_weights.update_from_dict(weights_payload)
save_config(cfg)
except Exception as exc: # noqa: BLE001
LOGGER.exception(
"保存 bandit 权重失败",
extra={**LOG_EXTRA, "error": str(exc)},
)
st.error(f"写入配置失败:{exc}")
else:
st.success("最佳权重已写入 config.json")
dept_ctrl = best_payload.get("department_controls") or {}
if dept_ctrl:
with st.expander("最佳部门控制参数", expanded=False):
st.json(dept_ctrl)
if st.button("清除自动探索结果", key="clear_decision_env_bandit"):
st.session_state.pop(_DECISION_ENV_BANDIT_RESULTS_KEY, None)
st.success("已清除自动探索结果。")
st.divider()
# --- Batch tuning: user supplies a grid of action rows, one run per row ---
st.caption("批量调参:在下方输入多组动作,每行表示一组 0-1 之间的值,用逗号分隔。")
# Default grid: three rows at 0.2 / 0.5 / 0.8 per action dimension.
default_grid = "\n".join(
[
",".join(["0.2" for _ in specs]),
",".join(["0.5" for _ in specs]),
",".join(["0.8" for _ in specs]),
]
) if specs else ""
action_grid_raw = st.text_area(
"动作列表",
value=default_grid,
height=120,
key="decision_env_batch_actions",
)
run_batch = st.button("批量执行调参", key="run_decision_env_batch")
batch_just_ran = False
if run_batch:
if not specs:
st.warning("请至少配置一个动作维度。")
elif selected_agents and not range_valid:
st.error("请确保所有代理的最大权重大于最小权重。")
elif not controls_valid:
st.error("请修正部门参数的取值范围。")
else:
LOGGER.info(
"离线调参(批量)按钮点击,已选择代理=%s disable_departments=%s",
selected_agents,
disable_departments,
extra=LOG_EXTRA,
)
lines = [line.strip() for line in action_grid_raw.splitlines() if line.strip()]
if not lines:
st.warning("请在文本框中输入至少一组动作。")
else:
LOGGER.debug(
"离线调参(批量)原始输入=%s",
lines,
extra=LOG_EXTRA,
)
# Parse every row; any malformed or wrong-arity row aborts the whole batch.
parsed_actions: List[List[float]] = []
for line in lines:
try:
values = [float(val.strip()) for val in line.split(',') if val.strip()]
except ValueError:
st.error(f"无法解析动作行:{line}")
parsed_actions = []
break
if len(values) != len(specs):
st.error(f"动作维度不匹配(期望 {len(specs)} 个值):{line}")
parsed_actions = []
break
parsed_actions.append(values)
if parsed_actions:
LOGGER.info(
"离线调参(批量)解析动作成功,数量=%s",
len(parsed_actions),
extra=LOG_EXTRA,
)
# NOTE(review): `cfg` is never defined in this function (config is
# bound as `app_cfg` at the top) — this branch raises NameError as written.
baseline_weights = cfg.agent_weights.as_dict()
for agent in agent_objects:
baseline_weights.setdefault(agent.name, 1.0)
universe_env = [code.strip() for code in universe_text.split(',') if code.strip()]
if not universe_env:
st.error("请先指定至少一个股票代码。")
else:
bt_cfg_env = BtConfig(
id="decision_env_streamlit_batch",
name="DecisionEnv Batch",
start_date=start_date,
end_date=end_date,
universe=universe_env,
params={
"target": target,
"stop": stop,
"hold_days": int(hold_days),
},
method=cfg.decision_method,
)
env = DecisionEnv(
bt_config=bt_cfg_env,
parameter_specs=specs,
baseline_weights=baseline_weights,
disable_departments=disable_departments,
)
results: List[Dict[str, object]] = []
resolved_experiment_id = experiment_id or str(uuid.uuid4())
resolved_strategy = strategy_label or "DecisionEnv"
LOGGER.debug(
"离线调参(批量)启动 DecisionEnvcfg=%s 动作组=%s",
bt_cfg_env,
len(parsed_actions),
extra=LOG_EXTRA,
)
with st.spinner("正在批量执行调参……"):
# One reset + step per action row; a failing row is recorded and skipped.
for idx, action_vals in enumerate(parsed_actions, start=1):
env.reset()
try:
observation, reward, done, info = env.step(action_vals)
except Exception as exc: # noqa: BLE001
LOGGER.exception("批量调参失败", extra=LOG_EXTRA)
results.append(
{
"序号": idx,
"动作": action_vals,
"状态": "error",
"错误": str(exc),
}
)
continue
if observation.get("failure"):
results.append(
{
"序号": idx,
"动作": action_vals,
"状态": "failure",
"奖励": -1.0,
}
)
else:
LOGGER.info(
"离线调参(批量)第 %s 组完成reward=%.4f obs=%s",
idx,
reward,
observation,
extra=LOG_EXTRA,
)
action_payload = {
label: value
for label, value in zip(spec_labels, action_vals)
}
metrics_payload = dict(observation)
metrics_payload["reward"] = reward
metrics_payload["department_controls"] = info.get("department_controls")
weights_payload = info.get("weights", {})
# Persist each successful row; logging failures don't abort the batch.
try:
log_tuning_result(
experiment_id=resolved_experiment_id,
strategy=resolved_strategy,
action=action_payload,
reward=reward,
metrics=metrics_payload,
weights=weights_payload,
)
except Exception: # noqa: BLE001
LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA)
results.append(
{
"序号": idx,
"动作": action_payload,
"状态": "ok",
"总收益": observation.get("total_return", 0.0),
"最大回撤": observation.get("max_drawdown", 0.0),
"波动率": observation.get("volatility", 0.0),
"奖励": reward,
"权重": weights_payload,
"部门控制": info.get("department_controls"),
}
)
# Cache results; "selectable" keeps only successful rows with weights.
st.session_state[_DECISION_ENV_BATCH_RESULTS_KEY] = {
"results": results,
"selectable": [
row
for row in results
if row.get("状态") == "ok" and row.get("权重")
],
"experiment_id": resolved_experiment_id,
"strategy_label": resolved_strategy,
}
batch_just_ran = True
LOGGER.info(
"离线调参(批量)执行结束,总结果条数=%s",
len(results),
extra=LOG_EXTRA,
)
# --- Render cached batch results ---
batch_state = st.session_state.get(_DECISION_ENV_BATCH_RESULTS_KEY)
if batch_state:
results = batch_state.get("results") or []
if results:
if batch_just_ran:
st.success("批量调参完成")
else:
st.success("批量调参结果(最近一次运行)")
st.caption(
f"实验 ID{batch_state.get('experiment_id', '-') } | 策略:{batch_state.get('strategy_label', 'DecisionEnv')}"
)
results_df = pd.DataFrame(results)
st.write("批量调参结果:")
st.dataframe(results_df, hide_index=True, width='stretch')
selectable = batch_state.get("selectable") or []
if selectable:
option_labels = [
f"序号 {row['序号']} | 奖励 {row.get('奖励', 0.0):+.4f}"
for row in selectable
]
selected_label = st.selectbox(
"选择要保存的记录",
option_labels,
key="decision_env_batch_select",
)
# Map the chosen label back to its result row.
selected_row = None
for label, row in zip(option_labels, selectable):
if label == selected_label:
selected_row = row
break
if selected_row and st.button(
"保存所选权重为默认配置",
key="save_decision_env_weights_batch",
):
# NOTE(review): `cfg` undefined here as well — see note above.
try:
cfg.agent_weights.update_from_dict(selected_row.get("权重", {}))
save_config(cfg)
except Exception as exc: # noqa: BLE001
LOGGER.exception("批量保存权重失败", extra={**LOG_EXTRA, "error": str(exc)})
st.error(f"写入配置失败:{exc}")
else:
st.success(
f"已将序号 {selected_row['序号']} 的权重写入 config.json"
)
else:
st.caption("暂无成功的结果可供保存。")
else:
st.caption("批量调参在最近一次执行中未产生结果。")
if st.button("清除批量调参结果", key="clear_decision_env_batch"):
st.session_state.pop(_DECISION_ENV_BATCH_RESULTS_KEY, None)
st.session_state.pop("decision_env_batch_select", None)
st.success("已清除批量调参结果缓存。")