update
This commit is contained in:
parent
bd1004c384
commit
a492d6a9f7
@ -1260,7 +1260,8 @@ def render_log_viewer() -> None:
|
|||||||
def render_backtest_review() -> None:
|
def render_backtest_review() -> None:
|
||||||
"""渲染回测执行、调参与结果复盘页面。"""
|
"""渲染回测执行、调参与结果复盘页面。"""
|
||||||
st.header("回测与复盘")
|
st.header("回测与复盘")
|
||||||
cfg = get_config()
|
st.caption("1. 基于历史数据复盘当前策略;2. 借助强化学习/调参探索更优参数组合。")
|
||||||
|
app_cfg = get_config()
|
||||||
default_start, default_end = _default_backtest_range(window_days=60)
|
default_start, default_end = _default_backtest_range(window_days=60)
|
||||||
LOGGER.debug(
|
LOGGER.debug(
|
||||||
"回测默认参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s",
|
"回测默认参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s",
|
||||||
@ -1273,13 +1274,15 @@ def render_backtest_review() -> None:
|
|||||||
extra=LOG_EXTRA,
|
extra=LOG_EXTRA,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
st.markdown("### 回测参数")
|
||||||
col1, col2 = st.columns(2)
|
col1, col2 = st.columns(2)
|
||||||
start_date = col1.date_input("开始日期", value=default_start)
|
start_date = col1.date_input("开始日期", value=default_start, key="bt_start_date")
|
||||||
end_date = col2.date_input("结束日期", value=default_end)
|
end_date = col2.date_input("结束日期", value=default_end, key="bt_end_date")
|
||||||
universe_text = st.text_input("股票列表(逗号分隔)", value="000001.SZ")
|
universe_text = st.text_input("股票列表(逗号分隔)", value="000001.SZ", key="bt_universe")
|
||||||
target = st.number_input("目标收益(例:0.035 表示 3.5%)", value=0.035, step=0.005, format="%.3f")
|
col_target, col_stop, col_hold = st.columns(3)
|
||||||
stop = st.number_input("止损收益(例:-0.015 表示 -1.5%)", value=-0.015, step=0.005, format="%.3f")
|
target = col_target.number_input("目标收益(例:0.035 表示 3.5%)", value=0.035, step=0.005, format="%.3f", key="bt_target")
|
||||||
hold_days = st.number_input("持有期(交易日)", value=10, step=1)
|
stop = col_stop.number_input("止损收益(例:-0.015 表示 -1.5%)", value=-0.015, step=0.005, format="%.3f", key="bt_stop")
|
||||||
|
hold_days = col_hold.number_input("持有期(交易日)", value=10, step=1, key="bt_hold_days")
|
||||||
LOGGER.debug(
|
LOGGER.debug(
|
||||||
"当前回测表单输入:start=%s end=%s universe_text=%s target=%.3f stop=%.3f hold_days=%s",
|
"当前回测表单输入:start=%s end=%s universe_text=%s target=%.3f stop=%.3f hold_days=%s",
|
||||||
start_date,
|
start_date,
|
||||||
@ -1291,7 +1294,11 @@ def render_backtest_review() -> None:
|
|||||||
extra=LOG_EXTRA,
|
extra=LOG_EXTRA,
|
||||||
)
|
)
|
||||||
|
|
||||||
if st.button("运行回测"):
|
tab_backtest, tab_rl = st.tabs(["回测验证", "强化学习调参"])
|
||||||
|
|
||||||
|
with tab_backtest:
|
||||||
|
st.markdown("#### 回测执行")
|
||||||
|
if st.button("运行回测", key="bt_run_button"):
|
||||||
LOGGER.info("用户点击运行回测按钮", extra=LOG_EXTRA)
|
LOGGER.info("用户点击运行回测按钮", extra=LOG_EXTRA)
|
||||||
decision_log_container = st.container()
|
decision_log_container = st.container()
|
||||||
status_box = st.status("准备执行回测...", expanded=True)
|
status_box = st.status("准备执行回测...", expanded=True)
|
||||||
@ -1346,7 +1353,7 @@ def render_backtest_review() -> None:
|
|||||||
hold_days,
|
hold_days,
|
||||||
extra=LOG_EXTRA,
|
extra=LOG_EXTRA,
|
||||||
)
|
)
|
||||||
cfg = BtConfig(
|
backtest_cfg = BtConfig(
|
||||||
id="streamlit_demo",
|
id="streamlit_demo",
|
||||||
name="Streamlit Demo Strategy",
|
name="Streamlit Demo Strategy",
|
||||||
start_date=start_date,
|
start_date=start_date,
|
||||||
@ -1358,7 +1365,7 @@ def render_backtest_review() -> None:
|
|||||||
"hold_days": int(hold_days),
|
"hold_days": int(hold_days),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
result = run_backtest(cfg, decision_callback=_decision_callback)
|
result = run_backtest(backtest_cfg, decision_callback=_decision_callback)
|
||||||
LOGGER.info(
|
LOGGER.info(
|
||||||
"回测完成:nav_records=%s trades=%s",
|
"回测完成:nav_records=%s trades=%s",
|
||||||
len(result.nav_series),
|
len(result.nav_series),
|
||||||
@ -1378,12 +1385,136 @@ def render_backtest_review() -> None:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
_update_dashboard_sidebar(metrics)
|
_update_dashboard_sidebar(metrics)
|
||||||
st.json({"nav_records": result.nav_series, "trades": result.trades})
|
st.session_state["backtest_last_result"] = {"nav_records": result.nav_series, "trades": result.trades}
|
||||||
|
st.json(st.session_state["backtest_last_result"])
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
LOGGER.exception("回测执行失败", extra=LOG_EXTRA)
|
LOGGER.exception("回测执行失败", extra=LOG_EXTRA)
|
||||||
status_box.update(label="回测执行失败", state="error")
|
status_box.update(label="回测执行失败", state="error")
|
||||||
st.error(f"回测执行失败:{exc}")
|
st.error(f"回测执行失败:{exc}")
|
||||||
|
|
||||||
|
last_result = st.session_state.get("backtest_last_result")
|
||||||
|
if last_result:
|
||||||
|
st.markdown("#### 最近回测输出")
|
||||||
|
st.json(last_result)
|
||||||
|
|
||||||
|
st.divider()
|
||||||
|
# ADD: Comparison view for multiple backtest configurations
|
||||||
|
with st.expander("历史回测结果对比", expanded=False):
|
||||||
|
st.caption("从历史回测配置中选择多个进行净值曲线与指标对比。")
|
||||||
|
normalize_to_one = st.checkbox("归一化到 1 起点", value=True, key="bt_cmp_normalize")
|
||||||
|
use_log_y = st.checkbox("对数坐标", value=False, key="bt_cmp_log_y")
|
||||||
|
metric_options = ["总收益", "最大回撤", "交易数", "平均换手", "风险事件"]
|
||||||
|
selected_metrics = st.multiselect("显示指标列", metric_options, default=metric_options, key="bt_cmp_metrics")
|
||||||
|
try:
|
||||||
|
with db_session(read_only=True) as conn:
|
||||||
|
cfg_rows = conn.execute(
|
||||||
|
"SELECT id, name FROM bt_config ORDER BY rowid DESC LIMIT 50"
|
||||||
|
).fetchall()
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
LOGGER.exception("读取 bt_config 失败", extra=LOG_EXTRA)
|
||||||
|
cfg_rows = []
|
||||||
|
cfg_options = [f"{row['id']} | {row['name']}" for row in cfg_rows]
|
||||||
|
selected_labels = st.multiselect("选择配置", cfg_options, default=cfg_options[:2], key="bt_cmp_configs")
|
||||||
|
selected_ids = [label.split(" | ")[0].strip() for label in selected_labels]
|
||||||
|
nav_df = pd.DataFrame()
|
||||||
|
rpt_df = pd.DataFrame()
|
||||||
|
if selected_ids:
|
||||||
|
try:
|
||||||
|
with db_session(read_only=True) as conn:
|
||||||
|
nav_df = pd.read_sql_query(
|
||||||
|
"SELECT cfg_id, trade_date, nav FROM bt_nav WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))),
|
||||||
|
conn,
|
||||||
|
params=tuple(selected_ids),
|
||||||
|
)
|
||||||
|
rpt_df = pd.read_sql_query(
|
||||||
|
"SELECT cfg_id, summary FROM bt_report WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))),
|
||||||
|
conn,
|
||||||
|
params=tuple(selected_ids),
|
||||||
|
)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
LOGGER.exception("读取回测结果失败", extra=LOG_EXTRA)
|
||||||
|
st.error("读取回测结果失败")
|
||||||
|
nav_df = pd.DataFrame()
|
||||||
|
rpt_df = pd.DataFrame()
|
||||||
|
if not nav_df.empty:
|
||||||
|
try:
|
||||||
|
nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"], errors="coerce")
|
||||||
|
# ADD: date window filter
|
||||||
|
overall_min = pd.to_datetime(nav_df["trade_date"].min()).date()
|
||||||
|
overall_max = pd.to_datetime(nav_df["trade_date"].max()).date()
|
||||||
|
col_d1, col_d2 = st.columns(2)
|
||||||
|
start_filter = col_d1.date_input("起始日期", value=overall_min, key="bt_cmp_start")
|
||||||
|
end_filter = col_d2.date_input("结束日期", value=overall_max, key="bt_cmp_end")
|
||||||
|
if start_filter > end_filter:
|
||||||
|
start_filter, end_filter = end_filter, start_filter
|
||||||
|
mask = (nav_df["trade_date"].dt.date >= start_filter) & (nav_df["trade_date"].dt.date <= end_filter)
|
||||||
|
nav_df = nav_df.loc[mask]
|
||||||
|
pivot = nav_df.pivot_table(index="trade_date", columns="cfg_id", values="nav")
|
||||||
|
if normalize_to_one:
|
||||||
|
pivot = pivot.apply(lambda s: s / s.dropna().iloc[0] if s.dropna().size else s)
|
||||||
|
import plotly.graph_objects as go
|
||||||
|
fig = go.Figure()
|
||||||
|
for col in pivot.columns:
|
||||||
|
fig.add_trace(go.Scatter(x=pivot.index, y=pivot[col], mode="lines", name=str(col)))
|
||||||
|
fig.update_layout(height=300, margin=dict(l=10, r=10, t=30, b=10))
|
||||||
|
if use_log_y:
|
||||||
|
fig.update_yaxes(type="log")
|
||||||
|
st.plotly_chart(fig, width='stretch')
|
||||||
|
# ADD: export pivot
|
||||||
|
try:
|
||||||
|
csv_buf = pivot.reset_index()
|
||||||
|
csv_buf.columns = ["trade_date"] + [str(c) for c in pivot.columns]
|
||||||
|
st.download_button(
|
||||||
|
"下载曲线(CSV)",
|
||||||
|
data=csv_buf.to_csv(index=False),
|
||||||
|
file_name="bt_nav_compare.csv",
|
||||||
|
mime="text/csv",
|
||||||
|
key="dl_nav_compare",
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
LOGGER.debug("绘制对比曲线失败", extra=LOG_EXTRA)
|
||||||
|
if not rpt_df.empty:
|
||||||
|
try:
|
||||||
|
metrics_rows: List[Dict[str, object]] = []
|
||||||
|
for _, row in rpt_df.iterrows():
|
||||||
|
cfg_id = row["cfg_id"]
|
||||||
|
try:
|
||||||
|
summary = json.loads(row["summary"]) if isinstance(row["summary"], str) else (row["summary"] or {})
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
summary = {}
|
||||||
|
record = {
|
||||||
|
"cfg_id": cfg_id,
|
||||||
|
"总收益": summary.get("total_return"),
|
||||||
|
"最大回撤": summary.get("max_drawdown"),
|
||||||
|
"交易数": summary.get("trade_count"),
|
||||||
|
"平均换手": summary.get("avg_turnover"),
|
||||||
|
"风险事件": summary.get("risk_events"),
|
||||||
|
}
|
||||||
|
metrics_rows.append({k: v for k, v in record.items() if (k == "cfg_id" or k in selected_metrics)})
|
||||||
|
if metrics_rows:
|
||||||
|
dfm = pd.DataFrame(metrics_rows)
|
||||||
|
st.dataframe(dfm, hide_index=True, width='stretch')
|
||||||
|
try:
|
||||||
|
st.download_button(
|
||||||
|
"下载指标(CSV)",
|
||||||
|
data=dfm.to_csv(index=False),
|
||||||
|
file_name="bt_metrics_compare.csv",
|
||||||
|
mime="text/csv",
|
||||||
|
key="dl_metrics_compare",
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
LOGGER.debug("渲染指标表失败", extra=LOG_EXTRA)
|
||||||
|
else:
|
||||||
|
st.info("请选择至少一个配置进行对比。")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
with tab_rl:
|
||||||
|
st.caption("使用 DecisionEnv 对代理权重进行强化学习调参,支持单次与批量实验。")
|
||||||
with st.expander("离线调参实验 (DecisionEnv)", expanded=False):
|
with st.expander("离线调参实验 (DecisionEnv)", expanded=False):
|
||||||
st.caption(
|
st.caption(
|
||||||
"使用 DecisionEnv 对代理权重做离线调参。请选择需要优化的代理并设定权重范围,"
|
"使用 DecisionEnv 对代理权重做离线调参。请选择需要优化的代理并设定权重范围,"
|
||||||
@ -1881,119 +2012,6 @@ def render_backtest_review() -> None:
|
|||||||
st.session_state.pop("decision_env_batch_select", None)
|
st.session_state.pop("decision_env_batch_select", None)
|
||||||
st.success("已清除批量调参结果缓存。")
|
st.success("已清除批量调参结果缓存。")
|
||||||
|
|
||||||
# ADD: Comparison view for multiple backtest configurations
|
|
||||||
with st.expander("回测结果对比", expanded=False):
|
|
||||||
st.caption("从历史回测配置中选择多个进行净值曲线与指标对比。")
|
|
||||||
normalize_to_one = st.checkbox("归一化到 1 起点", value=True)
|
|
||||||
use_log_y = st.checkbox("对数坐标", value=False)
|
|
||||||
metric_options = ["总收益", "最大回撤", "交易数", "平均换手", "风险事件"]
|
|
||||||
selected_metrics = st.multiselect("显示指标列", metric_options, default=metric_options)
|
|
||||||
try:
|
|
||||||
with db_session(read_only=True) as conn:
|
|
||||||
cfg_rows = conn.execute(
|
|
||||||
"SELECT id, name FROM bt_config ORDER BY rowid DESC LIMIT 50"
|
|
||||||
).fetchall()
|
|
||||||
except Exception: # noqa: BLE001
|
|
||||||
LOGGER.exception("读取 bt_config 失败", extra=LOG_EXTRA)
|
|
||||||
cfg_rows = []
|
|
||||||
cfg_options = [f"{row['id']} | {row['name']}" for row in cfg_rows]
|
|
||||||
selected_labels = st.multiselect("选择配置", cfg_options, default=cfg_options[:2])
|
|
||||||
selected_ids = [label.split(" | ")[0].strip() for label in selected_labels]
|
|
||||||
nav_df = pd.DataFrame()
|
|
||||||
rpt_df = pd.DataFrame()
|
|
||||||
if selected_ids:
|
|
||||||
try:
|
|
||||||
with db_session(read_only=True) as conn:
|
|
||||||
nav_df = pd.read_sql_query(
|
|
||||||
"SELECT cfg_id, trade_date, nav FROM bt_nav WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))),
|
|
||||||
conn,
|
|
||||||
params=tuple(selected_ids),
|
|
||||||
)
|
|
||||||
rpt_df = pd.read_sql_query(
|
|
||||||
"SELECT cfg_id, summary FROM bt_report WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))),
|
|
||||||
conn,
|
|
||||||
params=tuple(selected_ids),
|
|
||||||
)
|
|
||||||
except Exception: # noqa: BLE001
|
|
||||||
LOGGER.exception("读取回测结果失败", extra=LOG_EXTRA)
|
|
||||||
st.error("读取回测结果失败")
|
|
||||||
nav_df = pd.DataFrame()
|
|
||||||
rpt_df = pd.DataFrame()
|
|
||||||
if not nav_df.empty:
|
|
||||||
try:
|
|
||||||
nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"], errors="coerce")
|
|
||||||
# ADD: date window filter
|
|
||||||
overall_min = pd.to_datetime(nav_df["trade_date"].min()).date()
|
|
||||||
overall_max = pd.to_datetime(nav_df["trade_date"].max()).date()
|
|
||||||
col_d1, col_d2 = st.columns(2)
|
|
||||||
start_filter = col_d1.date_input("起始日期", value=overall_min)
|
|
||||||
end_filter = col_d2.date_input("结束日期", value=overall_max)
|
|
||||||
if start_filter > end_filter:
|
|
||||||
start_filter, end_filter = end_filter, start_filter
|
|
||||||
mask = (nav_df["trade_date"].dt.date >= start_filter) & (nav_df["trade_date"].dt.date <= end_filter)
|
|
||||||
nav_df = nav_df.loc[mask]
|
|
||||||
pivot = nav_df.pivot_table(index="trade_date", columns="cfg_id", values="nav")
|
|
||||||
if normalize_to_one:
|
|
||||||
pivot = pivot.apply(lambda s: s / s.dropna().iloc[0] if s.dropna().size else s)
|
|
||||||
import plotly.graph_objects as go
|
|
||||||
fig = go.Figure()
|
|
||||||
for col in pivot.columns:
|
|
||||||
fig.add_trace(go.Scatter(x=pivot.index, y=pivot[col], mode="lines", name=str(col)))
|
|
||||||
fig.update_layout(height=300, margin=dict(l=10, r=10, t=30, b=10))
|
|
||||||
if use_log_y:
|
|
||||||
fig.update_yaxes(type="log")
|
|
||||||
st.plotly_chart(fig, width='stretch')
|
|
||||||
# ADD: export pivot
|
|
||||||
try:
|
|
||||||
csv_buf = pivot.reset_index()
|
|
||||||
csv_buf.columns = ["trade_date"] + [str(c) for c in pivot.columns]
|
|
||||||
st.download_button(
|
|
||||||
"下载曲线(CSV)",
|
|
||||||
data=csv_buf.to_csv(index=False),
|
|
||||||
file_name="bt_nav_compare.csv",
|
|
||||||
mime="text/csv",
|
|
||||||
key="dl_nav_compare",
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
except Exception: # noqa: BLE001
|
|
||||||
LOGGER.debug("绘制对比曲线失败", extra=LOG_EXTRA)
|
|
||||||
if not rpt_df.empty:
|
|
||||||
try:
|
|
||||||
metrics_rows: List[Dict[str, object]] = []
|
|
||||||
for _, row in rpt_df.iterrows():
|
|
||||||
cfg_id = row["cfg_id"]
|
|
||||||
try:
|
|
||||||
summary = json.loads(row["summary"]) if isinstance(row["summary"], str) else (row["summary"] or {})
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
summary = {}
|
|
||||||
record = {
|
|
||||||
"cfg_id": cfg_id,
|
|
||||||
"总收益": summary.get("total_return"),
|
|
||||||
"最大回撤": summary.get("max_drawdown"),
|
|
||||||
"交易数": summary.get("trade_count"),
|
|
||||||
"平均换手": summary.get("avg_turnover"),
|
|
||||||
"风险事件": summary.get("risk_events"),
|
|
||||||
}
|
|
||||||
metrics_rows.append({k: v for k, v in record.items() if (k == "cfg_id" or k in selected_metrics)})
|
|
||||||
if metrics_rows:
|
|
||||||
dfm = pd.DataFrame(metrics_rows)
|
|
||||||
st.dataframe(dfm, hide_index=True, width='stretch')
|
|
||||||
try:
|
|
||||||
st.download_button(
|
|
||||||
"下载指标(CSV)",
|
|
||||||
data=dfm.to_csv(index=False),
|
|
||||||
file_name="bt_metrics_compare.csv",
|
|
||||||
mime="text/csv",
|
|
||||||
key="dl_metrics_compare",
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
except Exception: # noqa: BLE001
|
|
||||||
LOGGER.debug("渲染指标表失败", extra=LOG_EXTRA)
|
|
||||||
else:
|
|
||||||
st.info("请选择至少一个配置进行对比。")
|
|
||||||
|
|
||||||
|
|
||||||
def render_config_overview() -> None:
|
def render_config_overview() -> None:
|
||||||
"""Render a concise overview of persisted configuration values."""
|
"""Render a concise overview of persisted configuration values."""
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user