update
This commit is contained in:
parent
e234d66687
commit
dc2d82f685
@ -370,135 +370,137 @@ def render_backtest_review() -> None:
|
|||||||
action_values.append(action_val)
|
action_values.append(action_val)
|
||||||
|
|
||||||
controls_valid = True
|
controls_valid = True
|
||||||
with st.expander("部门 LLM 参数", expanded=False):
|
|
||||||
dept_codes = sorted(app_cfg.departments.keys())
|
st.divider()
|
||||||
if not dept_codes:
|
st.subheader("部门参数调整(可选)")
|
||||||
st.caption("当前未配置部门。")
|
dept_codes = sorted(app_cfg.departments.keys())
|
||||||
else:
|
if not dept_codes:
|
||||||
selected_departments = st.multiselect(
|
st.caption("当前未配置部门。")
|
||||||
"选择需要调整的部门",
|
else:
|
||||||
dept_codes,
|
selected_departments = st.multiselect(
|
||||||
default=[],
|
"选择需要调整的部门",
|
||||||
key="decision_env_departments",
|
dept_codes,
|
||||||
|
default=[],
|
||||||
|
key="decision_env_departments",
|
||||||
|
)
|
||||||
|
tool_policy_values = ["auto", "none", "required"]
|
||||||
|
for dept_code in selected_departments:
|
||||||
|
settings = app_cfg.departments.get(dept_code)
|
||||||
|
if not settings:
|
||||||
|
continue
|
||||||
|
st.subheader(f"部门:{settings.title or dept_code}")
|
||||||
|
base_temp = 0.2
|
||||||
|
if settings.llm and settings.llm.primary and settings.llm.primary.temperature is not None:
|
||||||
|
base_temp = float(settings.llm.primary.temperature)
|
||||||
|
prefix = f"decision_env_dept_{dept_code}"
|
||||||
|
col_tmin, col_tmax, col_tslider = st.columns([1, 1, 2])
|
||||||
|
temp_min = col_tmin.number_input(
|
||||||
|
"温度最小值",
|
||||||
|
min_value=0.0,
|
||||||
|
max_value=2.0,
|
||||||
|
value=max(0.0, base_temp - 0.3),
|
||||||
|
step=0.05,
|
||||||
|
key=f"{prefix}_temp_min",
|
||||||
)
|
)
|
||||||
tool_policy_values = ["auto", "none", "required"]
|
temp_max = col_tmax.number_input(
|
||||||
for dept_code in selected_departments:
|
"温度最大值",
|
||||||
settings = app_cfg.departments.get(dept_code)
|
min_value=0.0,
|
||||||
if not settings:
|
max_value=2.0,
|
||||||
continue
|
value=min(2.0, base_temp + 0.3),
|
||||||
st.subheader(f"部门:{settings.title or dept_code}")
|
step=0.05,
|
||||||
base_temp = 0.2
|
key=f"{prefix}_temp_max",
|
||||||
if settings.llm and settings.llm.primary and settings.llm.primary.temperature is not None:
|
)
|
||||||
base_temp = float(settings.llm.primary.temperature)
|
if temp_max <= temp_min:
|
||||||
prefix = f"decision_env_dept_{dept_code}"
|
controls_valid = False
|
||||||
col_tmin, col_tmax, col_tslider = st.columns([1, 1, 2])
|
st.warning("温度最大值必须大于最小值。")
|
||||||
temp_min = col_tmin.number_input(
|
temp_max = min(2.0, temp_min + 0.01)
|
||||||
"温度最小值",
|
span = temp_max - temp_min
|
||||||
min_value=0.0,
|
if span <= 0:
|
||||||
max_value=2.0,
|
ratio_default = 0.0
|
||||||
value=max(0.0, base_temp - 0.3),
|
else:
|
||||||
step=0.05,
|
clamped = min(max(base_temp, temp_min), temp_max)
|
||||||
key=f"{prefix}_temp_min",
|
ratio_default = (clamped - temp_min) / span
|
||||||
|
temp_action = col_tslider.slider(
|
||||||
|
"动作值(映射至温度区间)",
|
||||||
|
min_value=0.0,
|
||||||
|
max_value=1.0,
|
||||||
|
value=float(ratio_default),
|
||||||
|
step=0.01,
|
||||||
|
key=f"{prefix}_temp_action",
|
||||||
|
)
|
||||||
|
specs.append(
|
||||||
|
ParameterSpec(
|
||||||
|
name=f"dept_temperature_{dept_code}",
|
||||||
|
target=f"department.{dept_code}.temperature",
|
||||||
|
minimum=temp_min,
|
||||||
|
maximum=temp_max,
|
||||||
)
|
)
|
||||||
temp_max = col_tmax.number_input(
|
)
|
||||||
"温度最大值",
|
spec_labels.append(f"department:{dept_code}:temperature")
|
||||||
min_value=0.0,
|
action_values.append(temp_action)
|
||||||
max_value=2.0,
|
|
||||||
value=min(2.0, base_temp + 0.3),
|
col_tool, col_hint = st.columns([1, 2])
|
||||||
step=0.05,
|
tool_choice = col_tool.selectbox(
|
||||||
key=f"{prefix}_temp_max",
|
"函数调用策略",
|
||||||
|
tool_policy_values,
|
||||||
|
index=tool_policy_values.index("auto"),
|
||||||
|
key=f"{prefix}_tool_choice",
|
||||||
|
)
|
||||||
|
col_hint.caption("映射提示:0→auto,0.5→none,1→required。")
|
||||||
|
if len(tool_policy_values) > 1:
|
||||||
|
tool_value = tool_policy_values.index(tool_choice) / (len(tool_policy_values) - 1)
|
||||||
|
else:
|
||||||
|
tool_value = 0.0
|
||||||
|
specs.append(
|
||||||
|
ParameterSpec(
|
||||||
|
name=f"dept_tool_{dept_code}",
|
||||||
|
target=f"department.{dept_code}.function_policy",
|
||||||
|
values=tool_policy_values,
|
||||||
)
|
)
|
||||||
if temp_max <= temp_min:
|
)
|
||||||
controls_valid = False
|
spec_labels.append(f"department:{dept_code}:tool_choice")
|
||||||
st.warning("温度最大值必须大于最小值。")
|
action_values.append(tool_value)
|
||||||
temp_max = min(2.0, temp_min + 0.01)
|
|
||||||
span = temp_max - temp_min
|
template_id = (settings.prompt_template_id or f"{dept_code}_dept").strip()
|
||||||
if span <= 0:
|
versions = [ver for ver in TemplateRegistry.list_versions(template_id) if isinstance(ver, str)]
|
||||||
ratio_default = 0.0
|
if versions:
|
||||||
else:
|
active_version = TemplateRegistry.get_active_version(template_id)
|
||||||
clamped = min(max(base_temp, temp_min), temp_max)
|
default_version = (
|
||||||
ratio_default = (clamped - temp_min) / span
|
settings.prompt_template_version
|
||||||
temp_action = col_tslider.slider(
|
or active_version
|
||||||
"动作值(映射至温度区间)",
|
or versions[0]
|
||||||
min_value=0.0,
|
)
|
||||||
max_value=1.0,
|
try:
|
||||||
value=float(ratio_default),
|
default_index = versions.index(default_version)
|
||||||
step=0.01,
|
except ValueError:
|
||||||
key=f"{prefix}_temp_action",
|
default_index = 0
|
||||||
|
version_choice = st.selectbox(
|
||||||
|
"提示模板版本",
|
||||||
|
versions,
|
||||||
|
index=default_index,
|
||||||
|
key=f"{prefix}_template_version",
|
||||||
|
help="离散动作将按版本列表顺序映射,可用于强化学习优化。",
|
||||||
|
)
|
||||||
|
selected_index = versions.index(version_choice)
|
||||||
|
ratio = (
|
||||||
|
0.0
|
||||||
|
if len(versions) == 1
|
||||||
|
else selected_index / (len(versions) - 1)
|
||||||
)
|
)
|
||||||
specs.append(
|
specs.append(
|
||||||
ParameterSpec(
|
ParameterSpec(
|
||||||
name=f"dept_temperature_{dept_code}",
|
name=f"dept_prompt_version_{dept_code}",
|
||||||
target=f"department.{dept_code}.temperature",
|
target=f"department.{dept_code}.prompt_template_version",
|
||||||
minimum=temp_min,
|
values=list(versions),
|
||||||
maximum=temp_max,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
spec_labels.append(f"department:{dept_code}:temperature")
|
spec_labels.append(f"department:{dept_code}:prompt_version")
|
||||||
action_values.append(temp_action)
|
action_values.append(ratio)
|
||||||
|
st.caption(
|
||||||
col_tool, col_hint = st.columns([1, 2])
|
f"激活版本:{active_version or '默认'} | 当前选择:{version_choice}"
|
||||||
tool_choice = col_tool.selectbox(
|
|
||||||
"函数调用策略",
|
|
||||||
tool_policy_values,
|
|
||||||
index=tool_policy_values.index("auto"),
|
|
||||||
key=f"{prefix}_tool_choice",
|
|
||||||
)
|
)
|
||||||
col_hint.caption("映射提示:0→auto,0.5→none,1→required。")
|
else:
|
||||||
if len(tool_policy_values) > 1:
|
st.caption("当前模板未注册可选提示词版本,继续沿用激活版本。")
|
||||||
tool_value = tool_policy_values.index(tool_choice) / (len(tool_policy_values) - 1)
|
|
||||||
else:
|
|
||||||
tool_value = 0.0
|
|
||||||
specs.append(
|
|
||||||
ParameterSpec(
|
|
||||||
name=f"dept_tool_{dept_code}",
|
|
||||||
target=f"department.{dept_code}.function_policy",
|
|
||||||
values=tool_policy_values,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
spec_labels.append(f"department:{dept_code}:tool_choice")
|
|
||||||
action_values.append(tool_value)
|
|
||||||
|
|
||||||
template_id = (settings.prompt_template_id or f"{dept_code}_dept").strip()
|
|
||||||
versions = [ver for ver in TemplateRegistry.list_versions(template_id) if isinstance(ver, str)]
|
|
||||||
if versions:
|
|
||||||
active_version = TemplateRegistry.get_active_version(template_id)
|
|
||||||
default_version = (
|
|
||||||
settings.prompt_template_version
|
|
||||||
or active_version
|
|
||||||
or versions[0]
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
default_index = versions.index(default_version)
|
|
||||||
except ValueError:
|
|
||||||
default_index = 0
|
|
||||||
version_choice = st.selectbox(
|
|
||||||
"提示模板版本",
|
|
||||||
versions,
|
|
||||||
index=default_index,
|
|
||||||
key=f"{prefix}_template_version",
|
|
||||||
help="离散动作将按版本列表顺序映射,可用于强化学习优化。",
|
|
||||||
)
|
|
||||||
selected_index = versions.index(version_choice)
|
|
||||||
ratio = (
|
|
||||||
0.0
|
|
||||||
if len(versions) == 1
|
|
||||||
else selected_index / (len(versions) - 1)
|
|
||||||
)
|
|
||||||
specs.append(
|
|
||||||
ParameterSpec(
|
|
||||||
name=f"dept_prompt_version_{dept_code}",
|
|
||||||
target=f"department.{dept_code}.prompt_template_version",
|
|
||||||
values=list(versions),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
spec_labels.append(f"department:{dept_code}:prompt_version")
|
|
||||||
action_values.append(ratio)
|
|
||||||
st.caption(
|
|
||||||
f"激活版本:{active_version or '默认'} | 当前选择:{version_choice}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
st.caption("当前模板未注册可选提示词版本,继续沿用激活版本。")
|
|
||||||
|
|
||||||
if specs:
|
if specs:
|
||||||
st.caption("动作维度顺序:" + ",".join(spec_labels))
|
st.caption("动作维度顺序:" + ",".join(spec_labels))
|
||||||
@ -509,213 +511,6 @@ def render_backtest_review() -> None:
|
|||||||
help="关闭部门调用后不依赖外部 LLM 网络,仅根据规则代理权重模拟。",
|
help="关闭部门调用后不依赖外部 LLM 网络,仅根据规则代理权重模拟。",
|
||||||
)
|
)
|
||||||
|
|
||||||
st.divider()
|
|
||||||
st.subheader("单次调参")
|
|
||||||
|
|
||||||
run_decision_env = st.button("执行单次调参", key="run_decision_env_button")
|
|
||||||
just_finished_single = False
|
|
||||||
if run_decision_env:
|
|
||||||
if not specs:
|
|
||||||
st.warning("请至少配置一个动作维度(代理或部门参数)。")
|
|
||||||
elif selected_agents and not range_valid:
|
|
||||||
st.error("请确保所有代理的最大权重大于最小权重。")
|
|
||||||
elif not controls_valid:
|
|
||||||
st.error("请修正部门参数的取值范围。")
|
|
||||||
else:
|
|
||||||
LOGGER.info(
|
|
||||||
"离线调参(单次)按钮点击,已选择代理=%s 动作=%s disable_departments=%s",
|
|
||||||
selected_agents,
|
|
||||||
action_values,
|
|
||||||
disable_departments,
|
|
||||||
extra=LOG_EXTRA,
|
|
||||||
)
|
|
||||||
baseline_weights = app_cfg.agent_weights.as_dict()
|
|
||||||
for agent in agent_objects:
|
|
||||||
baseline_weights.setdefault(agent.name, 1.0)
|
|
||||||
|
|
||||||
universe_env = [code.strip() for code in universe_text.split(',') if code.strip()]
|
|
||||||
if not universe_env:
|
|
||||||
st.error("请先指定至少一个股票代码。")
|
|
||||||
else:
|
|
||||||
bt_cfg_env = BtConfig(
|
|
||||||
id="decision_env_streamlit",
|
|
||||||
name="DecisionEnv Streamlit",
|
|
||||||
start_date=start_date,
|
|
||||||
end_date=end_date,
|
|
||||||
universe=universe_env,
|
|
||||||
params={
|
|
||||||
"target": target,
|
|
||||||
"stop": stop,
|
|
||||||
"hold_days": int(hold_days),
|
|
||||||
},
|
|
||||||
method=app_cfg.decision_method,
|
|
||||||
)
|
|
||||||
env = DecisionEnv(
|
|
||||||
bt_config=bt_cfg_env,
|
|
||||||
parameter_specs=specs,
|
|
||||||
baseline_weights=baseline_weights,
|
|
||||||
disable_departments=disable_departments,
|
|
||||||
)
|
|
||||||
env.reset()
|
|
||||||
LOGGER.debug(
|
|
||||||
"离线调参(单次)启动 DecisionEnv:cfg=%s 参数维度=%s",
|
|
||||||
bt_cfg_env,
|
|
||||||
len(specs),
|
|
||||||
extra=LOG_EXTRA,
|
|
||||||
)
|
|
||||||
with st.spinner("正在执行离线调参……"):
|
|
||||||
try:
|
|
||||||
observation, reward, done, info = env.step(action_values)
|
|
||||||
LOGGER.info(
|
|
||||||
"离线调参(单次)完成,obs=%s reward=%.4f done=%s",
|
|
||||||
observation,
|
|
||||||
reward,
|
|
||||||
done,
|
|
||||||
extra=LOG_EXTRA,
|
|
||||||
)
|
|
||||||
except Exception as exc: # noqa: BLE001
|
|
||||||
LOGGER.exception("DecisionEnv 调用失败", extra=LOG_EXTRA)
|
|
||||||
st.error(f"离线调参失败:{exc}")
|
|
||||||
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
|
|
||||||
else:
|
|
||||||
if observation.get("failure"):
|
|
||||||
st.error("调参失败:回测执行未完成,可能是 LLM 网络不可用或参数异常。")
|
|
||||||
st.json(observation)
|
|
||||||
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
|
|
||||||
else:
|
|
||||||
resolved_experiment_id = experiment_id or str(uuid.uuid4())
|
|
||||||
resolved_strategy = strategy_label or "DecisionEnv"
|
|
||||||
action_payload = {
|
|
||||||
label: value for label, value in zip(spec_labels, action_values)
|
|
||||||
}
|
|
||||||
metrics_payload = dict(observation)
|
|
||||||
metrics_payload["reward"] = reward
|
|
||||||
metrics_payload["department_controls"] = info.get("department_controls")
|
|
||||||
log_success = False
|
|
||||||
try:
|
|
||||||
log_tuning_result(
|
|
||||||
experiment_id=resolved_experiment_id,
|
|
||||||
strategy=resolved_strategy,
|
|
||||||
action=action_payload,
|
|
||||||
reward=reward,
|
|
||||||
metrics=metrics_payload,
|
|
||||||
weights=info.get("weights", {}),
|
|
||||||
)
|
|
||||||
except Exception: # noqa: BLE001
|
|
||||||
LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA)
|
|
||||||
else:
|
|
||||||
log_success = True
|
|
||||||
LOGGER.info(
|
|
||||||
"离线调参(单次)日志写入成功:experiment=%s strategy=%s",
|
|
||||||
resolved_experiment_id,
|
|
||||||
resolved_strategy,
|
|
||||||
extra=LOG_EXTRA,
|
|
||||||
)
|
|
||||||
st.session_state[_DECISION_ENV_SINGLE_RESULT_KEY] = {
|
|
||||||
"observation": dict(observation),
|
|
||||||
"reward": float(reward),
|
|
||||||
"weights": info.get("weights", {}),
|
|
||||||
"department_controls": info.get("department_controls"),
|
|
||||||
"actions": action_payload,
|
|
||||||
"nav_series": info.get("nav_series"),
|
|
||||||
"trades": info.get("trades"),
|
|
||||||
"portfolio_snapshots": info.get("portfolio_snapshots"),
|
|
||||||
"portfolio_trades": info.get("portfolio_trades"),
|
|
||||||
"risk_breakdown": info.get("risk_breakdown"),
|
|
||||||
"spec_labels": list(spec_labels),
|
|
||||||
"action_values": list(action_values),
|
|
||||||
"experiment_id": resolved_experiment_id,
|
|
||||||
"strategy_label": resolved_strategy,
|
|
||||||
"logged": log_success,
|
|
||||||
}
|
|
||||||
just_finished_single = True
|
|
||||||
single_result = st.session_state.get(_DECISION_ENV_SINGLE_RESULT_KEY)
|
|
||||||
if single_result:
|
|
||||||
if just_finished_single:
|
|
||||||
st.success("离线调参完成")
|
|
||||||
else:
|
|
||||||
st.success("离线调参结果(最近一次运行)")
|
|
||||||
st.caption(
|
|
||||||
f"实验 ID:{single_result.get('experiment_id', '-') } | 策略:{single_result.get('strategy_label', 'DecisionEnv')}"
|
|
||||||
)
|
|
||||||
observation = single_result.get("observation", {})
|
|
||||||
reward = float(single_result.get("reward", 0.0))
|
|
||||||
col_metrics = st.columns(4)
|
|
||||||
col_metrics[0].metric("总收益", f"{observation.get('total_return', 0.0):+.2%}")
|
|
||||||
col_metrics[1].metric("最大回撤", f"{observation.get('max_drawdown', 0.0):+.2%}")
|
|
||||||
col_metrics[2].metric("波动率", f"{observation.get('volatility', 0.0):+.2%}")
|
|
||||||
col_metrics[3].metric("奖励", f"{reward:+.4f}")
|
|
||||||
|
|
||||||
turnover_ratio = float(observation.get("turnover", 0.0) or 0.0)
|
|
||||||
turnover_value = float(observation.get("turnover_value", 0.0) or 0.0)
|
|
||||||
risk_count = float(observation.get("risk_count", 0.0) or 0.0)
|
|
||||||
col_metrics_extra = st.columns(3)
|
|
||||||
col_metrics_extra[0].metric("平均换手率", f"{turnover_ratio:.2%}")
|
|
||||||
col_metrics_extra[1].metric("成交额", f"{turnover_value:,.0f}")
|
|
||||||
col_metrics_extra[2].metric("风险事件数", f"{int(risk_count)}")
|
|
||||||
|
|
||||||
weights_dict = single_result.get("weights") or {}
|
|
||||||
if weights_dict:
|
|
||||||
st.write("调参后权重:")
|
|
||||||
st.json(weights_dict)
|
|
||||||
if st.button("保存这些权重为默认配置", key="save_decision_env_weights_single"):
|
|
||||||
try:
|
|
||||||
app_cfg.agent_weights.update_from_dict(weights_dict)
|
|
||||||
save_config(app_cfg)
|
|
||||||
except Exception as exc: # noqa: BLE001
|
|
||||||
LOGGER.exception("保存权重失败", extra={**LOG_EXTRA, "error": str(exc)})
|
|
||||||
st.error(f"写入配置失败:{exc}")
|
|
||||||
else:
|
|
||||||
st.success("代理权重已写入 config.json")
|
|
||||||
|
|
||||||
if single_result.get("logged"):
|
|
||||||
st.caption("调参结果已写入 tuning_results 表。")
|
|
||||||
|
|
||||||
nav_series = single_result.get("nav_series") or []
|
|
||||||
if nav_series:
|
|
||||||
try:
|
|
||||||
nav_df = pd.DataFrame(nav_series)
|
|
||||||
if {"trade_date", "nav"}.issubset(nav_df.columns):
|
|
||||||
nav_df = nav_df.sort_values("trade_date")
|
|
||||||
nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"])
|
|
||||||
st.line_chart(nav_df.set_index("trade_date")["nav"], height=220)
|
|
||||||
except Exception: # noqa: BLE001
|
|
||||||
LOGGER.debug("导航曲线绘制失败", extra=LOG_EXTRA)
|
|
||||||
|
|
||||||
trades = single_result.get("trades") or []
|
|
||||||
if trades:
|
|
||||||
st.write("成交记录:")
|
|
||||||
st.dataframe(pd.DataFrame(trades), hide_index=True, width='stretch')
|
|
||||||
|
|
||||||
snapshots = single_result.get("portfolio_snapshots") or []
|
|
||||||
if snapshots:
|
|
||||||
with st.expander("投资组合快照", expanded=False):
|
|
||||||
st.dataframe(pd.DataFrame(snapshots), hide_index=True, width='stretch')
|
|
||||||
|
|
||||||
portfolio_trades = single_result.get("portfolio_trades") or []
|
|
||||||
if portfolio_trades:
|
|
||||||
with st.expander("组合成交明细", expanded=False):
|
|
||||||
st.dataframe(pd.DataFrame(portfolio_trades), hide_index=True, width='stretch')
|
|
||||||
|
|
||||||
risk_breakdown = single_result.get("risk_breakdown") or {}
|
|
||||||
if risk_breakdown:
|
|
||||||
with st.expander("风险事件统计", expanded=False):
|
|
||||||
st.json(risk_breakdown)
|
|
||||||
|
|
||||||
department_info = single_result.get("department_controls") or {}
|
|
||||||
if department_info:
|
|
||||||
with st.expander("部门控制参数", expanded=False):
|
|
||||||
st.json(department_info)
|
|
||||||
|
|
||||||
action_snapshot = single_result.get("actions") or {}
|
|
||||||
if action_snapshot:
|
|
||||||
with st.expander("动作明细", expanded=False):
|
|
||||||
st.json(action_snapshot)
|
|
||||||
|
|
||||||
if st.button("清除单次调参结果", key="clear_decision_env_single"):
|
|
||||||
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
|
|
||||||
st.success("已清除单次调参结果缓存。")
|
|
||||||
|
|
||||||
st.divider()
|
st.divider()
|
||||||
st.subheader("自动探索(epsilon-greedy)")
|
st.subheader("自动探索(epsilon-greedy)")
|
||||||
col_ep, col_eps, col_seed = st.columns([1, 1, 1])
|
col_ep, col_eps, col_seed = st.columns([1, 1, 1])
|
||||||
@ -880,232 +675,4 @@ def render_backtest_review() -> None:
|
|||||||
st.session_state.pop(_DECISION_ENV_BANDIT_RESULTS_KEY, None)
|
st.session_state.pop(_DECISION_ENV_BANDIT_RESULTS_KEY, None)
|
||||||
st.success("已清除自动探索结果。")
|
st.success("已清除自动探索结果。")
|
||||||
|
|
||||||
st.divider()
|
|
||||||
st.subheader("批量调参")
|
|
||||||
st.caption("批量调参:在下方输入多组动作,每行表示一组 0-1 之间的值,用逗号分隔。")
|
|
||||||
default_grid = "\n".join(
|
|
||||||
[
|
|
||||||
",".join(["0.2" for _ in specs]),
|
|
||||||
",".join(["0.5" for _ in specs]),
|
|
||||||
",".join(["0.8" for _ in specs]),
|
|
||||||
]
|
|
||||||
) if specs else ""
|
|
||||||
action_grid_raw = st.text_area(
|
|
||||||
"动作列表",
|
|
||||||
value=default_grid,
|
|
||||||
height=120,
|
|
||||||
key="decision_env_batch_actions",
|
|
||||||
)
|
|
||||||
run_batch = st.button("批量执行调参", key="run_decision_env_batch")
|
|
||||||
batch_just_ran = False
|
|
||||||
if run_batch:
|
|
||||||
if not specs:
|
|
||||||
st.warning("请至少配置一个动作维度。")
|
|
||||||
elif selected_agents and not range_valid:
|
|
||||||
st.error("请确保所有代理的最大权重大于最小权重。")
|
|
||||||
elif not controls_valid:
|
|
||||||
st.error("请修正部门参数的取值范围。")
|
|
||||||
else:
|
|
||||||
LOGGER.info(
|
|
||||||
"离线调参(批量)按钮点击,已选择代理=%s disable_departments=%s",
|
|
||||||
selected_agents,
|
|
||||||
disable_departments,
|
|
||||||
extra=LOG_EXTRA,
|
|
||||||
)
|
|
||||||
lines = [line.strip() for line in action_grid_raw.splitlines() if line.strip()]
|
|
||||||
if not lines:
|
|
||||||
st.warning("请在文本框中输入至少一组动作。")
|
|
||||||
else:
|
|
||||||
LOGGER.debug(
|
|
||||||
"离线调参(批量)原始输入=%s",
|
|
||||||
lines,
|
|
||||||
extra=LOG_EXTRA,
|
|
||||||
)
|
|
||||||
parsed_actions: List[List[float]] = []
|
|
||||||
for line in lines:
|
|
||||||
try:
|
|
||||||
values = [float(val.strip()) for val in line.split(',') if val.strip()]
|
|
||||||
except ValueError:
|
|
||||||
st.error(f"无法解析动作行:{line}")
|
|
||||||
parsed_actions = []
|
|
||||||
break
|
|
||||||
if len(values) != len(specs):
|
|
||||||
st.error(f"动作维度不匹配(期望 {len(specs)} 个值):{line}")
|
|
||||||
parsed_actions = []
|
|
||||||
break
|
|
||||||
parsed_actions.append(values)
|
|
||||||
if parsed_actions:
|
|
||||||
LOGGER.info(
|
|
||||||
"离线调参(批量)解析动作成功,数量=%s",
|
|
||||||
len(parsed_actions),
|
|
||||||
extra=LOG_EXTRA,
|
|
||||||
)
|
|
||||||
baseline_weights = app_cfg.agent_weights.as_dict()
|
|
||||||
for agent in agent_objects:
|
|
||||||
baseline_weights.setdefault(agent.name, 1.0)
|
|
||||||
|
|
||||||
universe_env = [code.strip() for code in universe_text.split(',') if code.strip()]
|
|
||||||
if not universe_env:
|
|
||||||
st.error("请先指定至少一个股票代码。")
|
|
||||||
else:
|
|
||||||
bt_cfg_env = BtConfig(
|
|
||||||
id="decision_env_streamlit_batch",
|
|
||||||
name="DecisionEnv Batch",
|
|
||||||
start_date=start_date,
|
|
||||||
end_date=end_date,
|
|
||||||
universe=universe_env,
|
|
||||||
params={
|
|
||||||
"target": target,
|
|
||||||
"stop": stop,
|
|
||||||
"hold_days": int(hold_days),
|
|
||||||
},
|
|
||||||
method=app_cfg.decision_method,
|
|
||||||
)
|
|
||||||
env = DecisionEnv(
|
|
||||||
bt_config=bt_cfg_env,
|
|
||||||
parameter_specs=specs,
|
|
||||||
baseline_weights=baseline_weights,
|
|
||||||
disable_departments=disable_departments,
|
|
||||||
)
|
|
||||||
results: List[Dict[str, object]] = []
|
|
||||||
resolved_experiment_id = experiment_id or str(uuid.uuid4())
|
|
||||||
resolved_strategy = strategy_label or "DecisionEnv"
|
|
||||||
LOGGER.debug(
|
|
||||||
"离线调参(批量)启动 DecisionEnv:cfg=%s 动作组=%s",
|
|
||||||
bt_cfg_env,
|
|
||||||
len(parsed_actions),
|
|
||||||
extra=LOG_EXTRA,
|
|
||||||
)
|
|
||||||
with st.spinner("正在批量执行调参……"):
|
|
||||||
for idx, action_vals in enumerate(parsed_actions, start=1):
|
|
||||||
env.reset()
|
|
||||||
try:
|
|
||||||
observation, reward, done, info = env.step(action_vals)
|
|
||||||
except Exception as exc: # noqa: BLE001
|
|
||||||
LOGGER.exception("批量调参失败", extra=LOG_EXTRA)
|
|
||||||
results.append(
|
|
||||||
{
|
|
||||||
"序号": idx,
|
|
||||||
"动作": action_vals,
|
|
||||||
"状态": "error",
|
|
||||||
"错误": str(exc),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
if observation.get("failure"):
|
|
||||||
results.append(
|
|
||||||
{
|
|
||||||
"序号": idx,
|
|
||||||
"动作": action_vals,
|
|
||||||
"状态": "failure",
|
|
||||||
"奖励": -1.0,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
LOGGER.info(
|
|
||||||
"离线调参(批量)第 %s 组完成,reward=%.4f obs=%s",
|
|
||||||
idx,
|
|
||||||
reward,
|
|
||||||
observation,
|
|
||||||
extra=LOG_EXTRA,
|
|
||||||
)
|
|
||||||
action_payload = {
|
|
||||||
label: value
|
|
||||||
for label, value in zip(spec_labels, action_vals)
|
|
||||||
}
|
|
||||||
metrics_payload = dict(observation)
|
|
||||||
metrics_payload["reward"] = reward
|
|
||||||
metrics_payload["department_controls"] = info.get("department_controls")
|
|
||||||
weights_payload = info.get("weights", {})
|
|
||||||
try:
|
|
||||||
log_tuning_result(
|
|
||||||
experiment_id=resolved_experiment_id,
|
|
||||||
strategy=resolved_strategy,
|
|
||||||
action=action_payload,
|
|
||||||
reward=reward,
|
|
||||||
metrics=metrics_payload,
|
|
||||||
weights=weights_payload,
|
|
||||||
)
|
|
||||||
except Exception: # noqa: BLE001
|
|
||||||
LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA)
|
|
||||||
results.append(
|
|
||||||
{
|
|
||||||
"序号": idx,
|
|
||||||
"动作": action_payload,
|
|
||||||
"状态": "ok",
|
|
||||||
"总收益": observation.get("total_return", 0.0),
|
|
||||||
"最大回撤": observation.get("max_drawdown", 0.0),
|
|
||||||
"波动率": observation.get("volatility", 0.0),
|
|
||||||
"奖励": reward,
|
|
||||||
"权重": weights_payload,
|
|
||||||
"部门控制": info.get("department_controls"),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
st.session_state[_DECISION_ENV_BATCH_RESULTS_KEY] = {
|
|
||||||
"results": results,
|
|
||||||
"selectable": [
|
|
||||||
row
|
|
||||||
for row in results
|
|
||||||
if row.get("状态") == "ok" and row.get("权重")
|
|
||||||
],
|
|
||||||
"experiment_id": resolved_experiment_id,
|
|
||||||
"strategy_label": resolved_strategy,
|
|
||||||
}
|
|
||||||
batch_just_ran = True
|
|
||||||
LOGGER.info(
|
|
||||||
"离线调参(批量)执行结束,总结果条数=%s",
|
|
||||||
len(results),
|
|
||||||
extra=LOG_EXTRA,
|
|
||||||
)
|
|
||||||
batch_state = st.session_state.get(_DECISION_ENV_BATCH_RESULTS_KEY)
|
|
||||||
if batch_state:
|
|
||||||
results = batch_state.get("results") or []
|
|
||||||
if results:
|
|
||||||
if batch_just_ran:
|
|
||||||
st.success("批量调参完成")
|
|
||||||
else:
|
|
||||||
st.success("批量调参结果(最近一次运行)")
|
|
||||||
st.caption(
|
|
||||||
f"实验 ID:{batch_state.get('experiment_id', '-') } | 策略:{batch_state.get('strategy_label', 'DecisionEnv')}"
|
|
||||||
)
|
|
||||||
results_df = pd.DataFrame(results)
|
|
||||||
st.write("批量调参结果:")
|
|
||||||
st.dataframe(results_df, hide_index=True, width='stretch')
|
|
||||||
selectable = batch_state.get("selectable") or []
|
|
||||||
if selectable:
|
|
||||||
option_labels = [
|
|
||||||
f"序号 {row['序号']} | 奖励 {row.get('奖励', 0.0):+.4f}"
|
|
||||||
for row in selectable
|
|
||||||
]
|
|
||||||
selected_label = st.selectbox(
|
|
||||||
"选择要保存的记录",
|
|
||||||
option_labels,
|
|
||||||
key="decision_env_batch_select",
|
|
||||||
)
|
|
||||||
selected_row = None
|
|
||||||
for label, row in zip(option_labels, selectable):
|
|
||||||
if label == selected_label:
|
|
||||||
selected_row = row
|
|
||||||
break
|
|
||||||
if selected_row and st.button(
|
|
||||||
"保存所选权重为默认配置",
|
|
||||||
key="save_decision_env_weights_batch",
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
app_cfg.agent_weights.update_from_dict(selected_row.get("权重", {}))
|
|
||||||
save_config(app_cfg)
|
|
||||||
except Exception as exc: # noqa: BLE001
|
|
||||||
LOGGER.exception("批量保存权重失败", extra={**LOG_EXTRA, "error": str(exc)})
|
|
||||||
st.error(f"写入配置失败:{exc}")
|
|
||||||
else:
|
|
||||||
st.success(
|
|
||||||
f"已将序号 {selected_row['序号']} 的权重写入 config.json"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
st.caption("暂无成功的结果可供保存。")
|
|
||||||
else:
|
|
||||||
st.caption("批量调参在最近一次执行中未产生结果。")
|
|
||||||
if st.button("清除批量调参结果", key="clear_decision_env_batch"):
|
|
||||||
st.session_state.pop(_DECISION_ENV_BATCH_RESULTS_KEY, None)
|
|
||||||
st.session_state.pop("decision_env_batch_select", None)
|
|
||||||
st.success("已清除批量调参结果缓存。")
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user