update
This commit is contained in:
parent
e234d66687
commit
dc2d82f685
@ -370,7 +370,9 @@ def render_backtest_review() -> None:
|
||||
action_values.append(action_val)
|
||||
|
||||
controls_valid = True
|
||||
with st.expander("部门 LLM 参数", expanded=False):
|
||||
|
||||
st.divider()
|
||||
st.subheader("部门参数调整(可选)")
|
||||
dept_codes = sorted(app_cfg.departments.keys())
|
||||
if not dept_codes:
|
||||
st.caption("当前未配置部门。")
|
||||
@ -509,213 +511,6 @@ def render_backtest_review() -> None:
|
||||
help="关闭部门调用后不依赖外部 LLM 网络,仅根据规则代理权重模拟。",
|
||||
)
|
||||
|
||||
st.divider()
|
||||
st.subheader("单次调参")
|
||||
|
||||
run_decision_env = st.button("执行单次调参", key="run_decision_env_button")
|
||||
just_finished_single = False
|
||||
if run_decision_env:
|
||||
if not specs:
|
||||
st.warning("请至少配置一个动作维度(代理或部门参数)。")
|
||||
elif selected_agents and not range_valid:
|
||||
st.error("请确保所有代理的最大权重大于最小权重。")
|
||||
elif not controls_valid:
|
||||
st.error("请修正部门参数的取值范围。")
|
||||
else:
|
||||
LOGGER.info(
|
||||
"离线调参(单次)按钮点击,已选择代理=%s 动作=%s disable_departments=%s",
|
||||
selected_agents,
|
||||
action_values,
|
||||
disable_departments,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
baseline_weights = app_cfg.agent_weights.as_dict()
|
||||
for agent in agent_objects:
|
||||
baseline_weights.setdefault(agent.name, 1.0)
|
||||
|
||||
universe_env = [code.strip() for code in universe_text.split(',') if code.strip()]
|
||||
if not universe_env:
|
||||
st.error("请先指定至少一个股票代码。")
|
||||
else:
|
||||
bt_cfg_env = BtConfig(
|
||||
id="decision_env_streamlit",
|
||||
name="DecisionEnv Streamlit",
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
universe=universe_env,
|
||||
params={
|
||||
"target": target,
|
||||
"stop": stop,
|
||||
"hold_days": int(hold_days),
|
||||
},
|
||||
method=app_cfg.decision_method,
|
||||
)
|
||||
env = DecisionEnv(
|
||||
bt_config=bt_cfg_env,
|
||||
parameter_specs=specs,
|
||||
baseline_weights=baseline_weights,
|
||||
disable_departments=disable_departments,
|
||||
)
|
||||
env.reset()
|
||||
LOGGER.debug(
|
||||
"离线调参(单次)启动 DecisionEnv:cfg=%s 参数维度=%s",
|
||||
bt_cfg_env,
|
||||
len(specs),
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
with st.spinner("正在执行离线调参……"):
|
||||
try:
|
||||
observation, reward, done, info = env.step(action_values)
|
||||
LOGGER.info(
|
||||
"离线调参(单次)完成,obs=%s reward=%.4f done=%s",
|
||||
observation,
|
||||
reward,
|
||||
done,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
LOGGER.exception("DecisionEnv 调用失败", extra=LOG_EXTRA)
|
||||
st.error(f"离线调参失败:{exc}")
|
||||
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
|
||||
else:
|
||||
if observation.get("failure"):
|
||||
st.error("调参失败:回测执行未完成,可能是 LLM 网络不可用或参数异常。")
|
||||
st.json(observation)
|
||||
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
|
||||
else:
|
||||
resolved_experiment_id = experiment_id or str(uuid.uuid4())
|
||||
resolved_strategy = strategy_label or "DecisionEnv"
|
||||
action_payload = {
|
||||
label: value for label, value in zip(spec_labels, action_values)
|
||||
}
|
||||
metrics_payload = dict(observation)
|
||||
metrics_payload["reward"] = reward
|
||||
metrics_payload["department_controls"] = info.get("department_controls")
|
||||
log_success = False
|
||||
try:
|
||||
log_tuning_result(
|
||||
experiment_id=resolved_experiment_id,
|
||||
strategy=resolved_strategy,
|
||||
action=action_payload,
|
||||
reward=reward,
|
||||
metrics=metrics_payload,
|
||||
weights=info.get("weights", {}),
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA)
|
||||
else:
|
||||
log_success = True
|
||||
LOGGER.info(
|
||||
"离线调参(单次)日志写入成功:experiment=%s strategy=%s",
|
||||
resolved_experiment_id,
|
||||
resolved_strategy,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
st.session_state[_DECISION_ENV_SINGLE_RESULT_KEY] = {
|
||||
"observation": dict(observation),
|
||||
"reward": float(reward),
|
||||
"weights": info.get("weights", {}),
|
||||
"department_controls": info.get("department_controls"),
|
||||
"actions": action_payload,
|
||||
"nav_series": info.get("nav_series"),
|
||||
"trades": info.get("trades"),
|
||||
"portfolio_snapshots": info.get("portfolio_snapshots"),
|
||||
"portfolio_trades": info.get("portfolio_trades"),
|
||||
"risk_breakdown": info.get("risk_breakdown"),
|
||||
"spec_labels": list(spec_labels),
|
||||
"action_values": list(action_values),
|
||||
"experiment_id": resolved_experiment_id,
|
||||
"strategy_label": resolved_strategy,
|
||||
"logged": log_success,
|
||||
}
|
||||
just_finished_single = True
|
||||
single_result = st.session_state.get(_DECISION_ENV_SINGLE_RESULT_KEY)
|
||||
if single_result:
|
||||
if just_finished_single:
|
||||
st.success("离线调参完成")
|
||||
else:
|
||||
st.success("离线调参结果(最近一次运行)")
|
||||
st.caption(
|
||||
f"实验 ID:{single_result.get('experiment_id', '-') } | 策略:{single_result.get('strategy_label', 'DecisionEnv')}"
|
||||
)
|
||||
observation = single_result.get("observation", {})
|
||||
reward = float(single_result.get("reward", 0.0))
|
||||
col_metrics = st.columns(4)
|
||||
col_metrics[0].metric("总收益", f"{observation.get('total_return', 0.0):+.2%}")
|
||||
col_metrics[1].metric("最大回撤", f"{observation.get('max_drawdown', 0.0):+.2%}")
|
||||
col_metrics[2].metric("波动率", f"{observation.get('volatility', 0.0):+.2%}")
|
||||
col_metrics[3].metric("奖励", f"{reward:+.4f}")
|
||||
|
||||
turnover_ratio = float(observation.get("turnover", 0.0) or 0.0)
|
||||
turnover_value = float(observation.get("turnover_value", 0.0) or 0.0)
|
||||
risk_count = float(observation.get("risk_count", 0.0) or 0.0)
|
||||
col_metrics_extra = st.columns(3)
|
||||
col_metrics_extra[0].metric("平均换手率", f"{turnover_ratio:.2%}")
|
||||
col_metrics_extra[1].metric("成交额", f"{turnover_value:,.0f}")
|
||||
col_metrics_extra[2].metric("风险事件数", f"{int(risk_count)}")
|
||||
|
||||
weights_dict = single_result.get("weights") or {}
|
||||
if weights_dict:
|
||||
st.write("调参后权重:")
|
||||
st.json(weights_dict)
|
||||
if st.button("保存这些权重为默认配置", key="save_decision_env_weights_single"):
|
||||
try:
|
||||
app_cfg.agent_weights.update_from_dict(weights_dict)
|
||||
save_config(app_cfg)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
LOGGER.exception("保存权重失败", extra={**LOG_EXTRA, "error": str(exc)})
|
||||
st.error(f"写入配置失败:{exc}")
|
||||
else:
|
||||
st.success("代理权重已写入 config.json")
|
||||
|
||||
if single_result.get("logged"):
|
||||
st.caption("调参结果已写入 tuning_results 表。")
|
||||
|
||||
nav_series = single_result.get("nav_series") or []
|
||||
if nav_series:
|
||||
try:
|
||||
nav_df = pd.DataFrame(nav_series)
|
||||
if {"trade_date", "nav"}.issubset(nav_df.columns):
|
||||
nav_df = nav_df.sort_values("trade_date")
|
||||
nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"])
|
||||
st.line_chart(nav_df.set_index("trade_date")["nav"], height=220)
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.debug("导航曲线绘制失败", extra=LOG_EXTRA)
|
||||
|
||||
trades = single_result.get("trades") or []
|
||||
if trades:
|
||||
st.write("成交记录:")
|
||||
st.dataframe(pd.DataFrame(trades), hide_index=True, width='stretch')
|
||||
|
||||
snapshots = single_result.get("portfolio_snapshots") or []
|
||||
if snapshots:
|
||||
with st.expander("投资组合快照", expanded=False):
|
||||
st.dataframe(pd.DataFrame(snapshots), hide_index=True, width='stretch')
|
||||
|
||||
portfolio_trades = single_result.get("portfolio_trades") or []
|
||||
if portfolio_trades:
|
||||
with st.expander("组合成交明细", expanded=False):
|
||||
st.dataframe(pd.DataFrame(portfolio_trades), hide_index=True, width='stretch')
|
||||
|
||||
risk_breakdown = single_result.get("risk_breakdown") or {}
|
||||
if risk_breakdown:
|
||||
with st.expander("风险事件统计", expanded=False):
|
||||
st.json(risk_breakdown)
|
||||
|
||||
department_info = single_result.get("department_controls") or {}
|
||||
if department_info:
|
||||
with st.expander("部门控制参数", expanded=False):
|
||||
st.json(department_info)
|
||||
|
||||
action_snapshot = single_result.get("actions") or {}
|
||||
if action_snapshot:
|
||||
with st.expander("动作明细", expanded=False):
|
||||
st.json(action_snapshot)
|
||||
|
||||
if st.button("清除单次调参结果", key="clear_decision_env_single"):
|
||||
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
|
||||
st.success("已清除单次调参结果缓存。")
|
||||
|
||||
st.divider()
|
||||
st.subheader("自动探索(epsilon-greedy)")
|
||||
col_ep, col_eps, col_seed = st.columns([1, 1, 1])
|
||||
@ -880,232 +675,4 @@ def render_backtest_review() -> None:
|
||||
st.session_state.pop(_DECISION_ENV_BANDIT_RESULTS_KEY, None)
|
||||
st.success("已清除自动探索结果。")
|
||||
|
||||
st.divider()
|
||||
st.subheader("批量调参")
|
||||
st.caption("批量调参:在下方输入多组动作,每行表示一组 0-1 之间的值,用逗号分隔。")
|
||||
default_grid = "\n".join(
|
||||
[
|
||||
",".join(["0.2" for _ in specs]),
|
||||
",".join(["0.5" for _ in specs]),
|
||||
",".join(["0.8" for _ in specs]),
|
||||
]
|
||||
) if specs else ""
|
||||
action_grid_raw = st.text_area(
|
||||
"动作列表",
|
||||
value=default_grid,
|
||||
height=120,
|
||||
key="decision_env_batch_actions",
|
||||
)
|
||||
run_batch = st.button("批量执行调参", key="run_decision_env_batch")
|
||||
batch_just_ran = False
|
||||
if run_batch:
|
||||
if not specs:
|
||||
st.warning("请至少配置一个动作维度。")
|
||||
elif selected_agents and not range_valid:
|
||||
st.error("请确保所有代理的最大权重大于最小权重。")
|
||||
elif not controls_valid:
|
||||
st.error("请修正部门参数的取值范围。")
|
||||
else:
|
||||
LOGGER.info(
|
||||
"离线调参(批量)按钮点击,已选择代理=%s disable_departments=%s",
|
||||
selected_agents,
|
||||
disable_departments,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
lines = [line.strip() for line in action_grid_raw.splitlines() if line.strip()]
|
||||
if not lines:
|
||||
st.warning("请在文本框中输入至少一组动作。")
|
||||
else:
|
||||
LOGGER.debug(
|
||||
"离线调参(批量)原始输入=%s",
|
||||
lines,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
parsed_actions: List[List[float]] = []
|
||||
for line in lines:
|
||||
try:
|
||||
values = [float(val.strip()) for val in line.split(',') if val.strip()]
|
||||
except ValueError:
|
||||
st.error(f"无法解析动作行:{line}")
|
||||
parsed_actions = []
|
||||
break
|
||||
if len(values) != len(specs):
|
||||
st.error(f"动作维度不匹配(期望 {len(specs)} 个值):{line}")
|
||||
parsed_actions = []
|
||||
break
|
||||
parsed_actions.append(values)
|
||||
if parsed_actions:
|
||||
LOGGER.info(
|
||||
"离线调参(批量)解析动作成功,数量=%s",
|
||||
len(parsed_actions),
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
baseline_weights = app_cfg.agent_weights.as_dict()
|
||||
for agent in agent_objects:
|
||||
baseline_weights.setdefault(agent.name, 1.0)
|
||||
|
||||
universe_env = [code.strip() for code in universe_text.split(',') if code.strip()]
|
||||
if not universe_env:
|
||||
st.error("请先指定至少一个股票代码。")
|
||||
else:
|
||||
bt_cfg_env = BtConfig(
|
||||
id="decision_env_streamlit_batch",
|
||||
name="DecisionEnv Batch",
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
universe=universe_env,
|
||||
params={
|
||||
"target": target,
|
||||
"stop": stop,
|
||||
"hold_days": int(hold_days),
|
||||
},
|
||||
method=app_cfg.decision_method,
|
||||
)
|
||||
env = DecisionEnv(
|
||||
bt_config=bt_cfg_env,
|
||||
parameter_specs=specs,
|
||||
baseline_weights=baseline_weights,
|
||||
disable_departments=disable_departments,
|
||||
)
|
||||
results: List[Dict[str, object]] = []
|
||||
resolved_experiment_id = experiment_id or str(uuid.uuid4())
|
||||
resolved_strategy = strategy_label or "DecisionEnv"
|
||||
LOGGER.debug(
|
||||
"离线调参(批量)启动 DecisionEnv:cfg=%s 动作组=%s",
|
||||
bt_cfg_env,
|
||||
len(parsed_actions),
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
with st.spinner("正在批量执行调参……"):
|
||||
for idx, action_vals in enumerate(parsed_actions, start=1):
|
||||
env.reset()
|
||||
try:
|
||||
observation, reward, done, info = env.step(action_vals)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
LOGGER.exception("批量调参失败", extra=LOG_EXTRA)
|
||||
results.append(
|
||||
{
|
||||
"序号": idx,
|
||||
"动作": action_vals,
|
||||
"状态": "error",
|
||||
"错误": str(exc),
|
||||
}
|
||||
)
|
||||
continue
|
||||
if observation.get("failure"):
|
||||
results.append(
|
||||
{
|
||||
"序号": idx,
|
||||
"动作": action_vals,
|
||||
"状态": "failure",
|
||||
"奖励": -1.0,
|
||||
}
|
||||
)
|
||||
else:
|
||||
LOGGER.info(
|
||||
"离线调参(批量)第 %s 组完成,reward=%.4f obs=%s",
|
||||
idx,
|
||||
reward,
|
||||
observation,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
action_payload = {
|
||||
label: value
|
||||
for label, value in zip(spec_labels, action_vals)
|
||||
}
|
||||
metrics_payload = dict(observation)
|
||||
metrics_payload["reward"] = reward
|
||||
metrics_payload["department_controls"] = info.get("department_controls")
|
||||
weights_payload = info.get("weights", {})
|
||||
try:
|
||||
log_tuning_result(
|
||||
experiment_id=resolved_experiment_id,
|
||||
strategy=resolved_strategy,
|
||||
action=action_payload,
|
||||
reward=reward,
|
||||
metrics=metrics_payload,
|
||||
weights=weights_payload,
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA)
|
||||
results.append(
|
||||
{
|
||||
"序号": idx,
|
||||
"动作": action_payload,
|
||||
"状态": "ok",
|
||||
"总收益": observation.get("total_return", 0.0),
|
||||
"最大回撤": observation.get("max_drawdown", 0.0),
|
||||
"波动率": observation.get("volatility", 0.0),
|
||||
"奖励": reward,
|
||||
"权重": weights_payload,
|
||||
"部门控制": info.get("department_controls"),
|
||||
}
|
||||
)
|
||||
st.session_state[_DECISION_ENV_BATCH_RESULTS_KEY] = {
|
||||
"results": results,
|
||||
"selectable": [
|
||||
row
|
||||
for row in results
|
||||
if row.get("状态") == "ok" and row.get("权重")
|
||||
],
|
||||
"experiment_id": resolved_experiment_id,
|
||||
"strategy_label": resolved_strategy,
|
||||
}
|
||||
batch_just_ran = True
|
||||
LOGGER.info(
|
||||
"离线调参(批量)执行结束,总结果条数=%s",
|
||||
len(results),
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
batch_state = st.session_state.get(_DECISION_ENV_BATCH_RESULTS_KEY)
|
||||
if batch_state:
|
||||
results = batch_state.get("results") or []
|
||||
if results:
|
||||
if batch_just_ran:
|
||||
st.success("批量调参完成")
|
||||
else:
|
||||
st.success("批量调参结果(最近一次运行)")
|
||||
st.caption(
|
||||
f"实验 ID:{batch_state.get('experiment_id', '-') } | 策略:{batch_state.get('strategy_label', 'DecisionEnv')}"
|
||||
)
|
||||
results_df = pd.DataFrame(results)
|
||||
st.write("批量调参结果:")
|
||||
st.dataframe(results_df, hide_index=True, width='stretch')
|
||||
selectable = batch_state.get("selectable") or []
|
||||
if selectable:
|
||||
option_labels = [
|
||||
f"序号 {row['序号']} | 奖励 {row.get('奖励', 0.0):+.4f}"
|
||||
for row in selectable
|
||||
]
|
||||
selected_label = st.selectbox(
|
||||
"选择要保存的记录",
|
||||
option_labels,
|
||||
key="decision_env_batch_select",
|
||||
)
|
||||
selected_row = None
|
||||
for label, row in zip(option_labels, selectable):
|
||||
if label == selected_label:
|
||||
selected_row = row
|
||||
break
|
||||
if selected_row and st.button(
|
||||
"保存所选权重为默认配置",
|
||||
key="save_decision_env_weights_batch",
|
||||
):
|
||||
try:
|
||||
app_cfg.agent_weights.update_from_dict(selected_row.get("权重", {}))
|
||||
save_config(app_cfg)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
LOGGER.exception("批量保存权重失败", extra={**LOG_EXTRA, "error": str(exc)})
|
||||
st.error(f"写入配置失败:{exc}")
|
||||
else:
|
||||
st.success(
|
||||
f"已将序号 {selected_row['序号']} 的权重写入 config.json"
|
||||
)
|
||||
else:
|
||||
st.caption("暂无成功的结果可供保存。")
|
||||
else:
|
||||
st.caption("批量调参在最近一次执行中未产生结果。")
|
||||
if st.button("清除批量调参结果", key="clear_decision_env_batch"):
|
||||
st.session_state.pop(_DECISION_ENV_BATCH_RESULTS_KEY, None)
|
||||
st.session_state.pop("decision_env_batch_select", None)
|
||||
st.success("已清除批量调参结果缓存。")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user