This commit is contained in:
sam 2025-09-29 22:29:00 +08:00
parent b3f2f5b4fc
commit 37fd7f80ce

View File

@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
import sys import sys
import time
from dataclasses import asdict from dataclasses import asdict
from datetime import date, datetime, timedelta from datetime import date, datetime, timedelta
from pathlib import Path from pathlib import Path
@ -62,58 +61,42 @@ from app.utils.tuning import log_tuning_result
LOGGER = get_logger(__name__) LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "ui"} LOG_EXTRA = {"stage": "ui"}
_SIDEBAR_THROTTLE_SECONDS = 0.75 _DECISION_ENV_SINGLE_RESULT_KEY = "decision_env_single_result"
_DECISION_ENV_BATCH_RESULTS_KEY = "decision_env_batch_results"
_DASHBOARD_CONTAINERS: Optional[tuple[object, object]] = None
def _sidebar_metrics_listener(metrics: Dict[str, object]) -> None: _DASHBOARD_ELEMENTS: Optional[Dict[str, object]] = None
_update_dashboard_sidebar(metrics, throttled=True)
def render_global_dashboard() -> None: def render_global_dashboard() -> None:
"""Render a persistent sidebar with realtime LLM stats and recent decisions.""" """Render a persistent sidebar with realtime LLM stats and recent decisions."""
global _DASHBOARD_CONTAINERS
global _DASHBOARD_ELEMENTS
metrics_container = st.sidebar.container() metrics_container = st.sidebar.container()
decisions_container = st.sidebar.container() decisions_container = st.sidebar.container()
st.session_state["dashboard_containers"] = (metrics_container, decisions_container) _DASHBOARD_CONTAINERS = (metrics_container, decisions_container)
_ensure_dashboard_elements(metrics_container, decisions_container) _DASHBOARD_ELEMENTS = _ensure_dashboard_elements(metrics_container, decisions_container)
if not st.session_state.get("dashboard_listener_registered"):
register_llm_metrics_listener(_sidebar_metrics_listener)
st.session_state["dashboard_listener_registered"] = True
_update_dashboard_sidebar() _update_dashboard_sidebar()
def _update_dashboard_sidebar( def _update_dashboard_sidebar(
metrics: Optional[Dict[str, object]] = None, metrics: Optional[Dict[str, object]] = None,
*,
throttled: bool = False,
) -> None: ) -> None:
containers = st.session_state.get("dashboard_containers") global _DASHBOARD_CONTAINERS
global _DASHBOARD_ELEMENTS
containers = _DASHBOARD_CONTAINERS
if not containers: if not containers:
return return
metrics_container, decisions_container = containers metrics_container, decisions_container = containers
elements = st.session_state.get("dashboard_elements") elements = _DASHBOARD_ELEMENTS
if elements is None: if elements is None:
elements = _ensure_dashboard_elements(metrics_container, decisions_container) elements = _ensure_dashboard_elements(metrics_container, decisions_container)
_DASHBOARD_ELEMENTS = elements
if throttled:
now = time.monotonic()
last_update = st.session_state.get("dashboard_last_update_ts", 0.0)
if now - last_update < _SIDEBAR_THROTTLE_SECONDS:
if metrics is not None:
st.session_state["dashboard_pending_metrics"] = metrics
return
st.session_state["dashboard_last_update_ts"] = now
else:
st.session_state["dashboard_last_update_ts"] = time.monotonic()
if metrics is None: if metrics is None:
metrics = st.session_state.pop("dashboard_pending_metrics", None) metrics = snapshot_llm_metrics()
if metrics is None:
metrics = snapshot_llm_metrics()
else:
st.session_state.pop("dashboard_pending_metrics", None)
metrics = metrics or snapshot_llm_metrics()
elements["metrics_calls"].metric("LLM 调用", metrics.get("total_calls", 0)) elements["metrics_calls"].metric("LLM 调用", metrics.get("total_calls", 0))
elements["metrics_prompt"].metric("Prompt Tokens", metrics.get("total_prompt_tokens", 0)) elements["metrics_prompt"].metric("Prompt Tokens", metrics.get("total_prompt_tokens", 0))
@ -160,10 +143,6 @@ def _update_dashboard_sidebar(
def _ensure_dashboard_elements(metrics_container, decisions_container) -> Dict[str, object]: def _ensure_dashboard_elements(metrics_container, decisions_container) -> Dict[str, object]:
elements = st.session_state.get("dashboard_elements")
if elements:
return elements
metrics_container.header("系统监控") metrics_container.header("系统监控")
col_a, col_b, col_c = metrics_container.columns(3) col_a, col_b, col_c = metrics_container.columns(3)
metrics_calls = col_a.empty() metrics_calls = col_a.empty()
@ -184,7 +163,6 @@ def _ensure_dashboard_elements(metrics_container, decisions_container) -> Dict[s
"model_distribution": model_distribution, "model_distribution": model_distribution,
"decisions_list": decisions_list, "decisions_list": decisions_list,
} }
st.session_state["dashboard_elements"] = elements
return elements return elements
def _discover_provider_models(provider: LLMProvider, base_override: str = "", api_override: Optional[str] = None) -> tuple[list[str], Optional[str]]: def _discover_provider_models(provider: LLMProvider, base_override: str = "", api_override: Optional[str] = None) -> tuple[list[str], Optional[str]]:
@ -835,12 +813,20 @@ def render_backtest() -> None:
action_values.append(action_val) action_values.append(action_val)
run_decision_env = st.button("执行单次调参", key="run_decision_env_button") run_decision_env = st.button("执行单次调参", key="run_decision_env_button")
just_finished_single = False
if run_decision_env: if run_decision_env:
if not selected_agents: if not selected_agents:
st.warning("请至少选择一个代理进行调参。") st.warning("请至少选择一个代理进行调参。")
elif not range_valid: elif not range_valid:
st.error("请确保所有代理的最大权重大于最小权重。") st.error("请确保所有代理的最大权重大于最小权重。")
else: else:
LOGGER.info(
"离线调参(单次)按钮点击,已选择代理=%s 动作=%s disable_departments=%s",
selected_agents,
action_values,
disable_departments,
extra=LOG_EXTRA,
)
baseline_weights = cfg.agent_weights.as_dict() baseline_weights = cfg.agent_weights.as_dict()
for agent in agent_objects: for agent in agent_objects:
baseline_weights.setdefault(agent.name, 1.0) baseline_weights.setdefault(agent.name, 1.0)
@ -869,74 +855,126 @@ def render_backtest() -> None:
disable_departments=disable_departments, disable_departments=disable_departments,
) )
env.reset() env.reset()
LOGGER.debug(
"离线调参(单次)启动 DecisionEnvcfg=%s 参数维度=%s",
bt_cfg_env,
len(specs),
extra=LOG_EXTRA,
)
with st.spinner("正在执行离线调参……"): with st.spinner("正在执行离线调参……"):
try: try:
observation, reward, done, info = env.step(action_values) observation, reward, done, info = env.step(action_values)
LOGGER.info(
"离线调参单次完成obs=%s reward=%.4f done=%s",
observation,
reward,
done,
extra=LOG_EXTRA,
)
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
LOGGER.exception("DecisionEnv 调用失败", extra=LOG_EXTRA) LOGGER.exception("DecisionEnv 调用失败", extra=LOG_EXTRA)
st.error(f"离线调参失败:{exc}") st.error(f"离线调参失败:{exc}")
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
else: else:
if observation.get("failure"): if observation.get("failure"):
st.error("调参失败:回测执行未完成,可能是 LLM 网络不可用或参数异常。") st.error("调参失败:回测执行未完成,可能是 LLM 网络不可用或参数异常。")
st.json(observation) st.json(observation)
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
else: else:
st.success("离线调参完成") resolved_experiment_id = experiment_id or str(uuid.uuid4())
col_metrics = st.columns(4) resolved_strategy = strategy_label or "DecisionEnv"
col_metrics[0].metric("总收益", f"{observation.get('total_return', 0.0):+.2%}")
col_metrics[1].metric("最大回撤", f"{observation.get('max_drawdown', 0.0):+.2%}")
col_metrics[2].metric("波动率", f"{observation.get('volatility', 0.0):+.2%}")
col_metrics[3].metric("奖励", f"{reward:+.4f}")
st.write("调参后权重:")
weights_dict = info.get("weights", {})
st.json(weights_dict)
action_payload = { action_payload = {
name: value name: value
for name, value in zip(selected_agents, action_values) for name, value in zip(selected_agents, action_values)
} }
metrics_payload = dict(observation) metrics_payload = dict(observation)
metrics_payload["reward"] = reward metrics_payload["reward"] = reward
log_success = False
try: try:
log_tuning_result( log_tuning_result(
experiment_id=experiment_id or str(uuid.uuid4()), experiment_id=resolved_experiment_id,
strategy=strategy_label or "DecisionEnv", strategy=resolved_strategy,
action=action_payload, action=action_payload,
reward=reward, reward=reward,
metrics=metrics_payload, metrics=metrics_payload,
weights=weights_dict, weights=info.get("weights", {}),
) )
st.caption("调参结果已写入 tuning_results 表。")
except Exception: # noqa: BLE001 except Exception: # noqa: BLE001
LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA) LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA)
else:
log_success = True
LOGGER.info(
"离线调参单次日志写入成功experiment=%s strategy=%s",
resolved_experiment_id,
resolved_strategy,
extra=LOG_EXTRA,
)
st.session_state[_DECISION_ENV_SINGLE_RESULT_KEY] = {
"observation": dict(observation),
"reward": float(reward),
"weights": info.get("weights", {}),
"nav_series": info.get("nav_series"),
"trades": info.get("trades"),
"selected_agents": list(selected_agents),
"action_values": list(action_values),
"experiment_id": resolved_experiment_id,
"strategy_label": resolved_strategy,
"logged": log_success,
}
just_finished_single = True
single_result = st.session_state.get(_DECISION_ENV_SINGLE_RESULT_KEY)
if single_result:
if just_finished_single:
st.success("离线调参完成")
else:
st.success("离线调参结果(最近一次运行)")
st.caption(
f"实验 ID{single_result.get('experiment_id', '-') } | 策略:{single_result.get('strategy_label', 'DecisionEnv')}"
)
observation = single_result.get("observation", {})
reward = float(single_result.get("reward", 0.0))
col_metrics = st.columns(4)
col_metrics[0].metric("总收益", f"{observation.get('total_return', 0.0):+.2%}")
col_metrics[1].metric("最大回撤", f"{observation.get('max_drawdown', 0.0):+.2%}")
col_metrics[2].metric("波动率", f"{observation.get('volatility', 0.0):+.2%}")
col_metrics[3].metric("奖励", f"{reward:+.4f}")
if weights_dict: weights_dict = single_result.get("weights") or {}
if st.button( if weights_dict:
"保存这些权重为默认配置", st.write("调参后权重:")
key="save_decision_env_weights_single", st.json(weights_dict)
): if st.button("保存这些权重为默认配置", key="save_decision_env_weights_single"):
try: try:
cfg.agent_weights.update_from_dict(weights_dict) cfg.agent_weights.update_from_dict(weights_dict)
save_config(cfg) save_config(cfg)
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
LOGGER.exception("保存权重失败", extra={**LOG_EXTRA, "error": str(exc)}) LOGGER.exception("保存权重失败", extra={**LOG_EXTRA, "error": str(exc)})
st.error(f"写入配置失败:{exc}") st.error(f"写入配置失败:{exc}")
else: else:
st.success("代理权重已写入 config.json") st.success("代理权重已写入 config.json")
nav_series = info.get("nav_series") if single_result.get("logged"):
if nav_series: st.caption("调参结果已写入 tuning_results 表。")
try:
nav_df = pd.DataFrame(nav_series) nav_series = single_result.get("nav_series") or []
if {"trade_date", "nav"}.issubset(nav_df.columns): if nav_series:
nav_df = nav_df.sort_values("trade_date") try:
nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"]) nav_df = pd.DataFrame(nav_series)
st.line_chart(nav_df.set_index("trade_date")["nav"], height=220) if {"trade_date", "nav"}.issubset(nav_df.columns):
except Exception: # noqa: BLE001 nav_df = nav_df.sort_values("trade_date")
LOGGER.debug("导航曲线绘制失败", extra=LOG_EXTRA) nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"])
trades = info.get("trades") st.line_chart(nav_df.set_index("trade_date")["nav"], height=220)
if trades: except Exception: # noqa: BLE001
st.write("成交记录:") LOGGER.debug("导航曲线绘制失败", extra=LOG_EXTRA)
st.dataframe(pd.DataFrame(trades), hide_index=True, width='stretch')
trades = single_result.get("trades") or []
if trades:
st.write("成交记录:")
st.dataframe(pd.DataFrame(trades), hide_index=True, width='stretch')
if st.button("清除单次调参结果", key="clear_decision_env_single"):
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
st.success("已清除单次调参结果缓存。")
st.divider() st.divider()
st.caption("批量调参:在下方输入多组动作,每行表示一组 0-1 之间的值,用逗号分隔。") st.caption("批量调参:在下方输入多组动作,每行表示一组 0-1 之间的值,用逗号分隔。")
@ -954,16 +992,28 @@ def render_backtest() -> None:
key="decision_env_batch_actions", key="decision_env_batch_actions",
) )
run_batch = st.button("批量执行调参", key="run_decision_env_batch") run_batch = st.button("批量执行调参", key="run_decision_env_batch")
batch_just_ran = False
if run_batch: if run_batch:
if not selected_agents: if not selected_agents:
st.warning("请先选择调参代理。") st.warning("请先选择调参代理。")
elif not range_valid: elif not range_valid:
st.error("请确保所有代理的最大权重大于最小权重。") st.error("请确保所有代理的最大权重大于最小权重。")
else: else:
LOGGER.info(
"离线调参(批量)按钮点击,已选择代理=%s disable_departments=%s",
selected_agents,
disable_departments,
extra=LOG_EXTRA,
)
lines = [line.strip() for line in action_grid_raw.splitlines() if line.strip()] lines = [line.strip() for line in action_grid_raw.splitlines() if line.strip()]
if not lines: if not lines:
st.warning("请在文本框中输入至少一组动作。") st.warning("请在文本框中输入至少一组动作。")
else: else:
LOGGER.debug(
"离线调参(批量)原始输入=%s",
lines,
extra=LOG_EXTRA,
)
parsed_actions: List[List[float]] = [] parsed_actions: List[List[float]] = []
for line in lines: for line in lines:
try: try:
@ -978,6 +1028,11 @@ def render_backtest() -> None:
break break
parsed_actions.append(values) parsed_actions.append(values)
if parsed_actions: if parsed_actions:
LOGGER.info(
"离线调参(批量)解析动作成功,数量=%s",
len(parsed_actions),
extra=LOG_EXTRA,
)
baseline_weights = cfg.agent_weights.as_dict() baseline_weights = cfg.agent_weights.as_dict()
for agent in agent_objects: for agent in agent_objects:
baseline_weights.setdefault(agent.name, 1.0) baseline_weights.setdefault(agent.name, 1.0)
@ -1006,6 +1061,14 @@ def render_backtest() -> None:
disable_departments=disable_departments, disable_departments=disable_departments,
) )
results: List[Dict[str, object]] = [] results: List[Dict[str, object]] = []
resolved_experiment_id = experiment_id or str(uuid.uuid4())
resolved_strategy = strategy_label or "DecisionEnv"
LOGGER.debug(
"离线调参(批量)启动 DecisionEnvcfg=%s 动作组=%s",
bt_cfg_env,
len(parsed_actions),
extra=LOG_EXTRA,
)
with st.spinner("正在批量执行调参……"): with st.spinner("正在批量执行调参……"):
for idx, action_vals in enumerate(parsed_actions, start=1): for idx, action_vals in enumerate(parsed_actions, start=1):
env.reset() env.reset()
@ -1032,6 +1095,13 @@ def render_backtest() -> None:
} }
) )
else: else:
LOGGER.info(
"离线调参(批量)第 %s 组完成reward=%.4f obs=%s",
idx,
reward,
observation,
extra=LOG_EXTRA,
)
action_payload = { action_payload = {
name: value name: value
for name, value in zip(selected_agents, action_vals) for name, value in zip(selected_agents, action_vals)
@ -1041,8 +1111,8 @@ def render_backtest() -> None:
weights_payload = info.get("weights", {}) weights_payload = info.get("weights", {})
try: try:
log_tuning_result( log_tuning_result(
experiment_id=experiment_id or str(uuid.uuid4()), experiment_id=resolved_experiment_id,
strategy=strategy_label or "DecisionEnv", strategy=resolved_strategy,
action=action_payload, action=action_payload,
reward=reward, reward=reward,
metrics=metrics_payload, metrics=metrics_payload,
@ -1062,46 +1132,74 @@ def render_backtest() -> None:
"权重": weights_payload, "权重": weights_payload,
} }
) )
if results: st.session_state[_DECISION_ENV_BATCH_RESULTS_KEY] = {
st.write("批量调参结果:") "results": results,
results_df = pd.DataFrame(results) "selectable": [
st.dataframe(results_df, hide_index=True, width='stretch')
selectable = [
row row
for row in results for row in results
if row.get("状态") == "ok" and row.get("权重") if row.get("状态") == "ok" and row.get("权重")
] ],
if selectable: "experiment_id": resolved_experiment_id,
option_labels = [ "strategy_label": resolved_strategy,
f"序号 {row['序号']} | 奖励 {row.get('奖励', 0.0):+.4f}" }
for row in selectable batch_just_ran = True
] LOGGER.info(
selected_label = st.selectbox( "离线调参(批量)执行结束,总结果条数=%s",
"选择要保存的记录", len(results),
option_labels, extra=LOG_EXTRA,
key="decision_env_batch_select", )
) batch_state = st.session_state.get(_DECISION_ENV_BATCH_RESULTS_KEY)
selected_row = None if batch_state:
for label, row in zip(option_labels, selectable): results = batch_state.get("results") or []
if label == selected_label: if results:
selected_row = row if batch_just_ran:
break st.success("批量调参完成")
if selected_row and st.button( else:
"保存所选权重为默认配置", st.success("批量调参结果(最近一次运行)")
key="save_decision_env_weights_batch", st.caption(
): f"实验 ID{batch_state.get('experiment_id', '-') } | 策略:{batch_state.get('strategy_label', 'DecisionEnv')}"
try: )
cfg.agent_weights.update_from_dict(selected_row.get("权重", {})) results_df = pd.DataFrame(results)
save_config(cfg) st.write("批量调参结果:")
except Exception as exc: # noqa: BLE001 st.dataframe(results_df, hide_index=True, width='stretch')
LOGGER.exception("批量保存权重失败", extra={**LOG_EXTRA, "error": str(exc)}) selectable = batch_state.get("selectable") or []
st.error(f"写入配置失败:{exc}") if selectable:
else: option_labels = [
st.success( f"序号 {row['序号']} | 奖励 {row.get('奖励', 0.0):+.4f}"
f"已将序号 {selected_row['序号']} 的权重写入 config.json" for row in selectable
) ]
else: selected_label = st.selectbox(
st.caption("暂无成功的结果可供保存。") "选择要保存的记录",
option_labels,
key="decision_env_batch_select",
)
selected_row = None
for label, row in zip(option_labels, selectable):
if label == selected_label:
selected_row = row
break
if selected_row and st.button(
"保存所选权重为默认配置",
key="save_decision_env_weights_batch",
):
try:
cfg.agent_weights.update_from_dict(selected_row.get("权重", {}))
save_config(cfg)
except Exception as exc: # noqa: BLE001
LOGGER.exception("批量保存权重失败", extra={**LOG_EXTRA, "error": str(exc)})
st.error(f"写入配置失败:{exc}")
else:
st.success(
f"已将序号 {selected_row['序号']} 的权重写入 config.json"
)
else:
st.caption("暂无成功的结果可供保存。")
else:
st.caption("批量调参在最近一次执行中未产生结果。")
if st.button("清除批量调参结果", key="clear_decision_env_batch"):
st.session_state.pop(_DECISION_ENV_BATCH_RESULTS_KEY, None)
st.session_state.pop("decision_env_batch_select", None)
st.success("已清除批量调参结果缓存。")
def render_settings() -> None: def render_settings() -> None:
@ -1673,7 +1771,7 @@ def render_tests() -> None:
] ]
) )
candle_fig.update_layout(height=420, margin=dict(l=10, r=10, t=40, b=10)) candle_fig.update_layout(height=420, margin=dict(l=10, r=10, t=40, b=10))
st.plotly_chart(candle_fig, width='stretch') st.plotly_chart(candle_fig, use_container_width=True)
vol_fig = px.bar( vol_fig = px.bar(
df_reset, df_reset,
@ -1683,7 +1781,7 @@ def render_tests() -> None:
title="成交量", title="成交量",
) )
vol_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10)) vol_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10))
st.plotly_chart(vol_fig, width='stretch') st.plotly_chart(vol_fig, use_container_width=True)
amt_fig = px.bar( amt_fig = px.bar(
df_reset, df_reset,
@ -1693,7 +1791,7 @@ def render_tests() -> None:
title="成交额", title="成交额",
) )
amt_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10)) amt_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10))
st.plotly_chart(amt_fig, width='stretch') st.plotly_chart(amt_fig, use_container_width=True)
df_reset["月份"] = df_reset["交易日"].dt.to_period("M").astype(str) df_reset["月份"] = df_reset["交易日"].dt.to_period("M").astype(str)
box_fig = px.box( box_fig = px.box(
@ -1704,7 +1802,7 @@ def render_tests() -> None:
title="月度收盘价分布", title="月度收盘价分布",
) )
box_fig.update_layout(height=320, margin=dict(l=10, r=10, t=40, b=10)) box_fig.update_layout(height=320, margin=dict(l=10, r=10, t=40, b=10))
st.plotly_chart(box_fig, width='stretch') st.plotly_chart(box_fig, use_container_width=True)
st.caption("提示:成交量单位为手,成交额以千元显示。箱线图按月展示收盘价分布。") st.caption("提示:成交量单位为手,成交额以千元显示。箱线图按月展示收盘价分布。")
st.dataframe(df_reset.tail(20), width='stretch') st.dataframe(df_reset.tail(20), width='stretch')