This commit is contained in:
sam 2025-10-06 15:46:51 +08:00
parent e234d66687
commit dc2d82f685

View File

@ -370,135 +370,137 @@ def render_backtest_review() -> None:
action_values.append(action_val)
controls_valid = True
with st.expander("部门 LLM 参数", expanded=False):
dept_codes = sorted(app_cfg.departments.keys())
if not dept_codes:
st.caption("当前未配置部门。")
else:
selected_departments = st.multiselect(
"选择需要调整的部门",
dept_codes,
default=[],
key="decision_env_departments",
st.divider()
st.subheader("部门参数调整(可选)")
dept_codes = sorted(app_cfg.departments.keys())
if not dept_codes:
st.caption("当前未配置部门。")
else:
selected_departments = st.multiselect(
"选择需要调整的部门",
dept_codes,
default=[],
key="decision_env_departments",
)
tool_policy_values = ["auto", "none", "required"]
for dept_code in selected_departments:
settings = app_cfg.departments.get(dept_code)
if not settings:
continue
st.subheader(f"部门:{settings.title or dept_code}")
base_temp = 0.2
if settings.llm and settings.llm.primary and settings.llm.primary.temperature is not None:
base_temp = float(settings.llm.primary.temperature)
prefix = f"decision_env_dept_{dept_code}"
col_tmin, col_tmax, col_tslider = st.columns([1, 1, 2])
temp_min = col_tmin.number_input(
"温度最小值",
min_value=0.0,
max_value=2.0,
value=max(0.0, base_temp - 0.3),
step=0.05,
key=f"{prefix}_temp_min",
)
tool_policy_values = ["auto", "none", "required"]
for dept_code in selected_departments:
settings = app_cfg.departments.get(dept_code)
if not settings:
continue
st.subheader(f"部门:{settings.title or dept_code}")
base_temp = 0.2
if settings.llm and settings.llm.primary and settings.llm.primary.temperature is not None:
base_temp = float(settings.llm.primary.temperature)
prefix = f"decision_env_dept_{dept_code}"
col_tmin, col_tmax, col_tslider = st.columns([1, 1, 2])
temp_min = col_tmin.number_input(
"温度最小值",
min_value=0.0,
max_value=2.0,
value=max(0.0, base_temp - 0.3),
step=0.05,
key=f"{prefix}_temp_min",
temp_max = col_tmax.number_input(
"温度最大值",
min_value=0.0,
max_value=2.0,
value=min(2.0, base_temp + 0.3),
step=0.05,
key=f"{prefix}_temp_max",
)
if temp_max <= temp_min:
controls_valid = False
st.warning("温度最大值必须大于最小值。")
temp_max = min(2.0, temp_min + 0.01)
span = temp_max - temp_min
if span <= 0:
ratio_default = 0.0
else:
clamped = min(max(base_temp, temp_min), temp_max)
ratio_default = (clamped - temp_min) / span
temp_action = col_tslider.slider(
"动作值(映射至温度区间)",
min_value=0.0,
max_value=1.0,
value=float(ratio_default),
step=0.01,
key=f"{prefix}_temp_action",
)
specs.append(
ParameterSpec(
name=f"dept_temperature_{dept_code}",
target=f"department.{dept_code}.temperature",
minimum=temp_min,
maximum=temp_max,
)
temp_max = col_tmax.number_input(
"温度最大值",
min_value=0.0,
max_value=2.0,
value=min(2.0, base_temp + 0.3),
step=0.05,
key=f"{prefix}_temp_max",
)
spec_labels.append(f"department:{dept_code}:temperature")
action_values.append(temp_action)
col_tool, col_hint = st.columns([1, 2])
tool_choice = col_tool.selectbox(
"函数调用策略",
tool_policy_values,
index=tool_policy_values.index("auto"),
key=f"{prefix}_tool_choice",
)
col_hint.caption("映射提示0→auto0.5→none1→required。")
if len(tool_policy_values) > 1:
tool_value = tool_policy_values.index(tool_choice) / (len(tool_policy_values) - 1)
else:
tool_value = 0.0
specs.append(
ParameterSpec(
name=f"dept_tool_{dept_code}",
target=f"department.{dept_code}.function_policy",
values=tool_policy_values,
)
if temp_max <= temp_min:
controls_valid = False
st.warning("温度最大值必须大于最小值。")
temp_max = min(2.0, temp_min + 0.01)
span = temp_max - temp_min
if span <= 0:
ratio_default = 0.0
else:
clamped = min(max(base_temp, temp_min), temp_max)
ratio_default = (clamped - temp_min) / span
temp_action = col_tslider.slider(
"动作值(映射至温度区间)",
min_value=0.0,
max_value=1.0,
value=float(ratio_default),
step=0.01,
key=f"{prefix}_temp_action",
)
spec_labels.append(f"department:{dept_code}:tool_choice")
action_values.append(tool_value)
template_id = (settings.prompt_template_id or f"{dept_code}_dept").strip()
versions = [ver for ver in TemplateRegistry.list_versions(template_id) if isinstance(ver, str)]
if versions:
active_version = TemplateRegistry.get_active_version(template_id)
default_version = (
settings.prompt_template_version
or active_version
or versions[0]
)
try:
default_index = versions.index(default_version)
except ValueError:
default_index = 0
version_choice = st.selectbox(
"提示模板版本",
versions,
index=default_index,
key=f"{prefix}_template_version",
help="离散动作将按版本列表顺序映射,可用于强化学习优化。",
)
selected_index = versions.index(version_choice)
ratio = (
0.0
if len(versions) == 1
else selected_index / (len(versions) - 1)
)
specs.append(
ParameterSpec(
name=f"dept_temperature_{dept_code}",
target=f"department.{dept_code}.temperature",
minimum=temp_min,
maximum=temp_max,
name=f"dept_prompt_version_{dept_code}",
target=f"department.{dept_code}.prompt_template_version",
values=list(versions),
)
)
spec_labels.append(f"department:{dept_code}:temperature")
action_values.append(temp_action)
col_tool, col_hint = st.columns([1, 2])
tool_choice = col_tool.selectbox(
"函数调用策略",
tool_policy_values,
index=tool_policy_values.index("auto"),
key=f"{prefix}_tool_choice",
spec_labels.append(f"department:{dept_code}:prompt_version")
action_values.append(ratio)
st.caption(
f"激活版本:{active_version or '默认'} 当前选择:{version_choice}"
)
col_hint.caption("映射提示0→auto0.5→none1→required。")
if len(tool_policy_values) > 1:
tool_value = tool_policy_values.index(tool_choice) / (len(tool_policy_values) - 1)
else:
tool_value = 0.0
specs.append(
ParameterSpec(
name=f"dept_tool_{dept_code}",
target=f"department.{dept_code}.function_policy",
values=tool_policy_values,
)
)
spec_labels.append(f"department:{dept_code}:tool_choice")
action_values.append(tool_value)
template_id = (settings.prompt_template_id or f"{dept_code}_dept").strip()
versions = [ver for ver in TemplateRegistry.list_versions(template_id) if isinstance(ver, str)]
if versions:
active_version = TemplateRegistry.get_active_version(template_id)
default_version = (
settings.prompt_template_version
or active_version
or versions[0]
)
try:
default_index = versions.index(default_version)
except ValueError:
default_index = 0
version_choice = st.selectbox(
"提示模板版本",
versions,
index=default_index,
key=f"{prefix}_template_version",
help="离散动作将按版本列表顺序映射,可用于强化学习优化。",
)
selected_index = versions.index(version_choice)
ratio = (
0.0
if len(versions) == 1
else selected_index / (len(versions) - 1)
)
specs.append(
ParameterSpec(
name=f"dept_prompt_version_{dept_code}",
target=f"department.{dept_code}.prompt_template_version",
values=list(versions),
)
)
spec_labels.append(f"department:{dept_code}:prompt_version")
action_values.append(ratio)
st.caption(
f"激活版本:{active_version or '默认'} 当前选择:{version_choice}"
)
else:
st.caption("当前模板未注册可选提示词版本,继续沿用激活版本。")
else:
st.caption("当前模板未注册可选提示词版本,继续沿用激活版本。")
if specs:
st.caption("动作维度顺序:" + "".join(spec_labels))
@ -509,213 +511,6 @@ def render_backtest_review() -> None:
help="关闭部门调用后不依赖外部 LLM 网络,仅根据规则代理权重模拟。",
)
st.divider()
st.subheader("单次调参")
run_decision_env = st.button("执行单次调参", key="run_decision_env_button")
just_finished_single = False
if run_decision_env:
if not specs:
st.warning("请至少配置一个动作维度(代理或部门参数)。")
elif selected_agents and not range_valid:
st.error("请确保所有代理的最大权重大于最小权重。")
elif not controls_valid:
st.error("请修正部门参数的取值范围。")
else:
LOGGER.info(
"离线调参(单次)按钮点击,已选择代理=%s 动作=%s disable_departments=%s",
selected_agents,
action_values,
disable_departments,
extra=LOG_EXTRA,
)
baseline_weights = app_cfg.agent_weights.as_dict()
for agent in agent_objects:
baseline_weights.setdefault(agent.name, 1.0)
universe_env = [code.strip() for code in universe_text.split(',') if code.strip()]
if not universe_env:
st.error("请先指定至少一个股票代码。")
else:
bt_cfg_env = BtConfig(
id="decision_env_streamlit",
name="DecisionEnv Streamlit",
start_date=start_date,
end_date=end_date,
universe=universe_env,
params={
"target": target,
"stop": stop,
"hold_days": int(hold_days),
},
method=app_cfg.decision_method,
)
env = DecisionEnv(
bt_config=bt_cfg_env,
parameter_specs=specs,
baseline_weights=baseline_weights,
disable_departments=disable_departments,
)
env.reset()
LOGGER.debug(
"离线调参(单次)启动 DecisionEnvcfg=%s 参数维度=%s",
bt_cfg_env,
len(specs),
extra=LOG_EXTRA,
)
with st.spinner("正在执行离线调参……"):
try:
observation, reward, done, info = env.step(action_values)
LOGGER.info(
"离线调参单次完成obs=%s reward=%.4f done=%s",
observation,
reward,
done,
extra=LOG_EXTRA,
)
except Exception as exc: # noqa: BLE001
LOGGER.exception("DecisionEnv 调用失败", extra=LOG_EXTRA)
st.error(f"离线调参失败:{exc}")
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
else:
if observation.get("failure"):
st.error("调参失败:回测执行未完成,可能是 LLM 网络不可用或参数异常。")
st.json(observation)
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
else:
resolved_experiment_id = experiment_id or str(uuid.uuid4())
resolved_strategy = strategy_label or "DecisionEnv"
action_payload = {
label: value for label, value in zip(spec_labels, action_values)
}
metrics_payload = dict(observation)
metrics_payload["reward"] = reward
metrics_payload["department_controls"] = info.get("department_controls")
log_success = False
try:
log_tuning_result(
experiment_id=resolved_experiment_id,
strategy=resolved_strategy,
action=action_payload,
reward=reward,
metrics=metrics_payload,
weights=info.get("weights", {}),
)
except Exception: # noqa: BLE001
LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA)
else:
log_success = True
LOGGER.info(
"离线调参单次日志写入成功experiment=%s strategy=%s",
resolved_experiment_id,
resolved_strategy,
extra=LOG_EXTRA,
)
st.session_state[_DECISION_ENV_SINGLE_RESULT_KEY] = {
"observation": dict(observation),
"reward": float(reward),
"weights": info.get("weights", {}),
"department_controls": info.get("department_controls"),
"actions": action_payload,
"nav_series": info.get("nav_series"),
"trades": info.get("trades"),
"portfolio_snapshots": info.get("portfolio_snapshots"),
"portfolio_trades": info.get("portfolio_trades"),
"risk_breakdown": info.get("risk_breakdown"),
"spec_labels": list(spec_labels),
"action_values": list(action_values),
"experiment_id": resolved_experiment_id,
"strategy_label": resolved_strategy,
"logged": log_success,
}
just_finished_single = True
single_result = st.session_state.get(_DECISION_ENV_SINGLE_RESULT_KEY)
if single_result:
if just_finished_single:
st.success("离线调参完成")
else:
st.success("离线调参结果(最近一次运行)")
st.caption(
f"实验 ID{single_result.get('experiment_id', '-') } | 策略:{single_result.get('strategy_label', 'DecisionEnv')}"
)
observation = single_result.get("observation", {})
reward = float(single_result.get("reward", 0.0))
col_metrics = st.columns(4)
col_metrics[0].metric("总收益", f"{observation.get('total_return', 0.0):+.2%}")
col_metrics[1].metric("最大回撤", f"{observation.get('max_drawdown', 0.0):+.2%}")
col_metrics[2].metric("波动率", f"{observation.get('volatility', 0.0):+.2%}")
col_metrics[3].metric("奖励", f"{reward:+.4f}")
turnover_ratio = float(observation.get("turnover", 0.0) or 0.0)
turnover_value = float(observation.get("turnover_value", 0.0) or 0.0)
risk_count = float(observation.get("risk_count", 0.0) or 0.0)
col_metrics_extra = st.columns(3)
col_metrics_extra[0].metric("平均换手率", f"{turnover_ratio:.2%}")
col_metrics_extra[1].metric("成交额", f"{turnover_value:,.0f}")
col_metrics_extra[2].metric("风险事件数", f"{int(risk_count)}")
weights_dict = single_result.get("weights") or {}
if weights_dict:
st.write("调参后权重:")
st.json(weights_dict)
if st.button("保存这些权重为默认配置", key="save_decision_env_weights_single"):
try:
app_cfg.agent_weights.update_from_dict(weights_dict)
save_config(app_cfg)
except Exception as exc: # noqa: BLE001
LOGGER.exception("保存权重失败", extra={**LOG_EXTRA, "error": str(exc)})
st.error(f"写入配置失败:{exc}")
else:
st.success("代理权重已写入 config.json")
if single_result.get("logged"):
st.caption("调参结果已写入 tuning_results 表。")
nav_series = single_result.get("nav_series") or []
if nav_series:
try:
nav_df = pd.DataFrame(nav_series)
if {"trade_date", "nav"}.issubset(nav_df.columns):
nav_df = nav_df.sort_values("trade_date")
nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"])
st.line_chart(nav_df.set_index("trade_date")["nav"], height=220)
except Exception: # noqa: BLE001
LOGGER.debug("导航曲线绘制失败", extra=LOG_EXTRA)
trades = single_result.get("trades") or []
if trades:
st.write("成交记录:")
st.dataframe(pd.DataFrame(trades), hide_index=True, width='stretch')
snapshots = single_result.get("portfolio_snapshots") or []
if snapshots:
with st.expander("投资组合快照", expanded=False):
st.dataframe(pd.DataFrame(snapshots), hide_index=True, width='stretch')
portfolio_trades = single_result.get("portfolio_trades") or []
if portfolio_trades:
with st.expander("组合成交明细", expanded=False):
st.dataframe(pd.DataFrame(portfolio_trades), hide_index=True, width='stretch')
risk_breakdown = single_result.get("risk_breakdown") or {}
if risk_breakdown:
with st.expander("风险事件统计", expanded=False):
st.json(risk_breakdown)
department_info = single_result.get("department_controls") or {}
if department_info:
with st.expander("部门控制参数", expanded=False):
st.json(department_info)
action_snapshot = single_result.get("actions") or {}
if action_snapshot:
with st.expander("动作明细", expanded=False):
st.json(action_snapshot)
if st.button("清除单次调参结果", key="clear_decision_env_single"):
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
st.success("已清除单次调参结果缓存。")
st.divider()
st.subheader("自动探索epsilon-greedy")
col_ep, col_eps, col_seed = st.columns([1, 1, 1])
@ -880,232 +675,4 @@ def render_backtest_review() -> None:
st.session_state.pop(_DECISION_ENV_BANDIT_RESULTS_KEY, None)
st.success("已清除自动探索结果。")
st.divider()
st.subheader("批量调参")
st.caption("批量调参:在下方输入多组动作,每行表示一组 0-1 之间的值,用逗号分隔。")
default_grid = "\n".join(
[
",".join(["0.2" for _ in specs]),
",".join(["0.5" for _ in specs]),
",".join(["0.8" for _ in specs]),
]
) if specs else ""
action_grid_raw = st.text_area(
"动作列表",
value=default_grid,
height=120,
key="decision_env_batch_actions",
)
run_batch = st.button("批量执行调参", key="run_decision_env_batch")
batch_just_ran = False
if run_batch:
if not specs:
st.warning("请至少配置一个动作维度。")
elif selected_agents and not range_valid:
st.error("请确保所有代理的最大权重大于最小权重。")
elif not controls_valid:
st.error("请修正部门参数的取值范围。")
else:
LOGGER.info(
"离线调参(批量)按钮点击,已选择代理=%s disable_departments=%s",
selected_agents,
disable_departments,
extra=LOG_EXTRA,
)
lines = [line.strip() for line in action_grid_raw.splitlines() if line.strip()]
if not lines:
st.warning("请在文本框中输入至少一组动作。")
else:
LOGGER.debug(
"离线调参(批量)原始输入=%s",
lines,
extra=LOG_EXTRA,
)
parsed_actions: List[List[float]] = []
for line in lines:
try:
values = [float(val.strip()) for val in line.split(',') if val.strip()]
except ValueError:
st.error(f"无法解析动作行:{line}")
parsed_actions = []
break
if len(values) != len(specs):
st.error(f"动作维度不匹配(期望 {len(specs)} 个值):{line}")
parsed_actions = []
break
parsed_actions.append(values)
if parsed_actions:
LOGGER.info(
"离线调参(批量)解析动作成功,数量=%s",
len(parsed_actions),
extra=LOG_EXTRA,
)
baseline_weights = app_cfg.agent_weights.as_dict()
for agent in agent_objects:
baseline_weights.setdefault(agent.name, 1.0)
universe_env = [code.strip() for code in universe_text.split(',') if code.strip()]
if not universe_env:
st.error("请先指定至少一个股票代码。")
else:
bt_cfg_env = BtConfig(
id="decision_env_streamlit_batch",
name="DecisionEnv Batch",
start_date=start_date,
end_date=end_date,
universe=universe_env,
params={
"target": target,
"stop": stop,
"hold_days": int(hold_days),
},
method=app_cfg.decision_method,
)
env = DecisionEnv(
bt_config=bt_cfg_env,
parameter_specs=specs,
baseline_weights=baseline_weights,
disable_departments=disable_departments,
)
results: List[Dict[str, object]] = []
resolved_experiment_id = experiment_id or str(uuid.uuid4())
resolved_strategy = strategy_label or "DecisionEnv"
LOGGER.debug(
"离线调参(批量)启动 DecisionEnvcfg=%s 动作组=%s",
bt_cfg_env,
len(parsed_actions),
extra=LOG_EXTRA,
)
with st.spinner("正在批量执行调参……"):
for idx, action_vals in enumerate(parsed_actions, start=1):
env.reset()
try:
observation, reward, done, info = env.step(action_vals)
except Exception as exc: # noqa: BLE001
LOGGER.exception("批量调参失败", extra=LOG_EXTRA)
results.append(
{
"序号": idx,
"动作": action_vals,
"状态": "error",
"错误": str(exc),
}
)
continue
if observation.get("failure"):
results.append(
{
"序号": idx,
"动作": action_vals,
"状态": "failure",
"奖励": -1.0,
}
)
else:
LOGGER.info(
"离线调参(批量)第 %s 组完成reward=%.4f obs=%s",
idx,
reward,
observation,
extra=LOG_EXTRA,
)
action_payload = {
label: value
for label, value in zip(spec_labels, action_vals)
}
metrics_payload = dict(observation)
metrics_payload["reward"] = reward
metrics_payload["department_controls"] = info.get("department_controls")
weights_payload = info.get("weights", {})
try:
log_tuning_result(
experiment_id=resolved_experiment_id,
strategy=resolved_strategy,
action=action_payload,
reward=reward,
metrics=metrics_payload,
weights=weights_payload,
)
except Exception: # noqa: BLE001
LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA)
results.append(
{
"序号": idx,
"动作": action_payload,
"状态": "ok",
"总收益": observation.get("total_return", 0.0),
"最大回撤": observation.get("max_drawdown", 0.0),
"波动率": observation.get("volatility", 0.0),
"奖励": reward,
"权重": weights_payload,
"部门控制": info.get("department_controls"),
}
)
st.session_state[_DECISION_ENV_BATCH_RESULTS_KEY] = {
"results": results,
"selectable": [
row
for row in results
if row.get("状态") == "ok" and row.get("权重")
],
"experiment_id": resolved_experiment_id,
"strategy_label": resolved_strategy,
}
batch_just_ran = True
LOGGER.info(
"离线调参(批量)执行结束,总结果条数=%s",
len(results),
extra=LOG_EXTRA,
)
batch_state = st.session_state.get(_DECISION_ENV_BATCH_RESULTS_KEY)
if batch_state:
results = batch_state.get("results") or []
if results:
if batch_just_ran:
st.success("批量调参完成")
else:
st.success("批量调参结果(最近一次运行)")
st.caption(
f"实验 ID{batch_state.get('experiment_id', '-') } | 策略:{batch_state.get('strategy_label', 'DecisionEnv')}"
)
results_df = pd.DataFrame(results)
st.write("批量调参结果:")
st.dataframe(results_df, hide_index=True, width='stretch')
selectable = batch_state.get("selectable") or []
if selectable:
option_labels = [
f"序号 {row['序号']} | 奖励 {row.get('奖励', 0.0):+.4f}"
for row in selectable
]
selected_label = st.selectbox(
"选择要保存的记录",
option_labels,
key="decision_env_batch_select",
)
selected_row = None
for label, row in zip(option_labels, selectable):
if label == selected_label:
selected_row = row
break
if selected_row and st.button(
"保存所选权重为默认配置",
key="save_decision_env_weights_batch",
):
try:
app_cfg.agent_weights.update_from_dict(selected_row.get("权重", {}))
save_config(app_cfg)
except Exception as exc: # noqa: BLE001
LOGGER.exception("批量保存权重失败", extra={**LOG_EXTRA, "error": str(exc)})
st.error(f"写入配置失败:{exc}")
else:
st.success(
f"已将序号 {selected_row['序号']} 的权重写入 config.json"
)
else:
st.caption("暂无成功的结果可供保存。")
else:
st.caption("批量调参在最近一次执行中未产生结果。")
if st.button("清除批量调参结果", key="clear_decision_env_batch"):
st.session_state.pop(_DECISION_ENV_BATCH_RESULTS_KEY, None)
st.session_state.pop("decision_env_batch_select", None)
st.success("已清除批量调参结果缓存。")