This commit is contained in:
sam 2025-09-29 22:29:00 +08:00
parent b3f2f5b4fc
commit 37fd7f80ce

View File

@ -2,7 +2,6 @@
from __future__ import annotations
import sys
import time
from dataclasses import asdict
from datetime import date, datetime, timedelta
from pathlib import Path
@ -62,58 +61,42 @@ from app.utils.tuning import log_tuning_result
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "ui"}
_SIDEBAR_THROTTLE_SECONDS = 0.75
def _sidebar_metrics_listener(metrics: Dict[str, object]) -> None:
_update_dashboard_sidebar(metrics, throttled=True)
_DECISION_ENV_SINGLE_RESULT_KEY = "decision_env_single_result"
_DECISION_ENV_BATCH_RESULTS_KEY = "decision_env_batch_results"
_DASHBOARD_CONTAINERS: Optional[tuple[object, object]] = None
_DASHBOARD_ELEMENTS: Optional[Dict[str, object]] = None
def render_global_dashboard() -> None:
"""Render a persistent sidebar with realtime LLM stats and recent decisions."""
global _DASHBOARD_CONTAINERS
global _DASHBOARD_ELEMENTS
metrics_container = st.sidebar.container()
decisions_container = st.sidebar.container()
st.session_state["dashboard_containers"] = (metrics_container, decisions_container)
_ensure_dashboard_elements(metrics_container, decisions_container)
if not st.session_state.get("dashboard_listener_registered"):
register_llm_metrics_listener(_sidebar_metrics_listener)
st.session_state["dashboard_listener_registered"] = True
_DASHBOARD_CONTAINERS = (metrics_container, decisions_container)
_DASHBOARD_ELEMENTS = _ensure_dashboard_elements(metrics_container, decisions_container)
_update_dashboard_sidebar()
def _update_dashboard_sidebar(
metrics: Optional[Dict[str, object]] = None,
*,
throttled: bool = False,
) -> None:
containers = st.session_state.get("dashboard_containers")
global _DASHBOARD_CONTAINERS
global _DASHBOARD_ELEMENTS
containers = _DASHBOARD_CONTAINERS
if not containers:
return
metrics_container, decisions_container = containers
elements = st.session_state.get("dashboard_elements")
elements = _DASHBOARD_ELEMENTS
if elements is None:
elements = _ensure_dashboard_elements(metrics_container, decisions_container)
if throttled:
now = time.monotonic()
last_update = st.session_state.get("dashboard_last_update_ts", 0.0)
if now - last_update < _SIDEBAR_THROTTLE_SECONDS:
if metrics is not None:
st.session_state["dashboard_pending_metrics"] = metrics
return
st.session_state["dashboard_last_update_ts"] = now
else:
st.session_state["dashboard_last_update_ts"] = time.monotonic()
_DASHBOARD_ELEMENTS = elements
if metrics is None:
metrics = st.session_state.pop("dashboard_pending_metrics", None)
if metrics is None:
metrics = snapshot_llm_metrics()
else:
st.session_state.pop("dashboard_pending_metrics", None)
metrics = metrics or snapshot_llm_metrics()
metrics = snapshot_llm_metrics()
elements["metrics_calls"].metric("LLM 调用", metrics.get("total_calls", 0))
elements["metrics_prompt"].metric("Prompt Tokens", metrics.get("total_prompt_tokens", 0))
@ -160,10 +143,6 @@ def _update_dashboard_sidebar(
def _ensure_dashboard_elements(metrics_container, decisions_container) -> Dict[str, object]:
elements = st.session_state.get("dashboard_elements")
if elements:
return elements
metrics_container.header("系统监控")
col_a, col_b, col_c = metrics_container.columns(3)
metrics_calls = col_a.empty()
@ -184,7 +163,6 @@ def _ensure_dashboard_elements(metrics_container, decisions_container) -> Dict[s
"model_distribution": model_distribution,
"decisions_list": decisions_list,
}
st.session_state["dashboard_elements"] = elements
return elements
def _discover_provider_models(provider: LLMProvider, base_override: str = "", api_override: Optional[str] = None) -> tuple[list[str], Optional[str]]:
@ -835,12 +813,20 @@ def render_backtest() -> None:
action_values.append(action_val)
run_decision_env = st.button("执行单次调参", key="run_decision_env_button")
just_finished_single = False
if run_decision_env:
if not selected_agents:
st.warning("请至少选择一个代理进行调参。")
elif not range_valid:
st.error("请确保所有代理的最大权重大于最小权重。")
else:
LOGGER.info(
"离线调参(单次)按钮点击,已选择代理=%s 动作=%s disable_departments=%s",
selected_agents,
action_values,
disable_departments,
extra=LOG_EXTRA,
)
baseline_weights = cfg.agent_weights.as_dict()
for agent in agent_objects:
baseline_weights.setdefault(agent.name, 1.0)
@ -869,74 +855,126 @@ def render_backtest() -> None:
disable_departments=disable_departments,
)
env.reset()
LOGGER.debug(
"离线调参(单次)启动 DecisionEnvcfg=%s 参数维度=%s",
bt_cfg_env,
len(specs),
extra=LOG_EXTRA,
)
with st.spinner("正在执行离线调参……"):
try:
observation, reward, done, info = env.step(action_values)
LOGGER.info(
"离线调参单次完成obs=%s reward=%.4f done=%s",
observation,
reward,
done,
extra=LOG_EXTRA,
)
except Exception as exc: # noqa: BLE001
LOGGER.exception("DecisionEnv 调用失败", extra=LOG_EXTRA)
st.error(f"离线调参失败:{exc}")
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
else:
if observation.get("failure"):
st.error("调参失败:回测执行未完成,可能是 LLM 网络不可用或参数异常。")
st.json(observation)
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
else:
st.success("离线调参完成")
col_metrics = st.columns(4)
col_metrics[0].metric("总收益", f"{observation.get('total_return', 0.0):+.2%}")
col_metrics[1].metric("最大回撤", f"{observation.get('max_drawdown', 0.0):+.2%}")
col_metrics[2].metric("波动率", f"{observation.get('volatility', 0.0):+.2%}")
col_metrics[3].metric("奖励", f"{reward:+.4f}")
st.write("调参后权重:")
weights_dict = info.get("weights", {})
st.json(weights_dict)
resolved_experiment_id = experiment_id or str(uuid.uuid4())
resolved_strategy = strategy_label or "DecisionEnv"
action_payload = {
name: value
for name, value in zip(selected_agents, action_values)
}
metrics_payload = dict(observation)
metrics_payload["reward"] = reward
log_success = False
try:
log_tuning_result(
experiment_id=experiment_id or str(uuid.uuid4()),
strategy=strategy_label or "DecisionEnv",
experiment_id=resolved_experiment_id,
strategy=resolved_strategy,
action=action_payload,
reward=reward,
metrics=metrics_payload,
weights=weights_dict,
weights=info.get("weights", {}),
)
st.caption("调参结果已写入 tuning_results 表。")
except Exception: # noqa: BLE001
LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA)
else:
log_success = True
LOGGER.info(
"离线调参单次日志写入成功experiment=%s strategy=%s",
resolved_experiment_id,
resolved_strategy,
extra=LOG_EXTRA,
)
st.session_state[_DECISION_ENV_SINGLE_RESULT_KEY] = {
"observation": dict(observation),
"reward": float(reward),
"weights": info.get("weights", {}),
"nav_series": info.get("nav_series"),
"trades": info.get("trades"),
"selected_agents": list(selected_agents),
"action_values": list(action_values),
"experiment_id": resolved_experiment_id,
"strategy_label": resolved_strategy,
"logged": log_success,
}
just_finished_single = True
single_result = st.session_state.get(_DECISION_ENV_SINGLE_RESULT_KEY)
if single_result:
if just_finished_single:
st.success("离线调参完成")
else:
st.success("离线调参结果(最近一次运行)")
st.caption(
f"实验 ID{single_result.get('experiment_id', '-') } | 策略:{single_result.get('strategy_label', 'DecisionEnv')}"
)
observation = single_result.get("observation", {})
reward = float(single_result.get("reward", 0.0))
col_metrics = st.columns(4)
col_metrics[0].metric("总收益", f"{observation.get('total_return', 0.0):+.2%}")
col_metrics[1].metric("最大回撤", f"{observation.get('max_drawdown', 0.0):+.2%}")
col_metrics[2].metric("波动率", f"{observation.get('volatility', 0.0):+.2%}")
col_metrics[3].metric("奖励", f"{reward:+.4f}")
if weights_dict:
if st.button(
"保存这些权重为默认配置",
key="save_decision_env_weights_single",
):
try:
cfg.agent_weights.update_from_dict(weights_dict)
save_config(cfg)
except Exception as exc: # noqa: BLE001
LOGGER.exception("保存权重失败", extra={**LOG_EXTRA, "error": str(exc)})
st.error(f"写入配置失败:{exc}")
else:
st.success("代理权重已写入 config.json")
weights_dict = single_result.get("weights") or {}
if weights_dict:
st.write("调参后权重:")
st.json(weights_dict)
if st.button("保存这些权重为默认配置", key="save_decision_env_weights_single"):
try:
cfg.agent_weights.update_from_dict(weights_dict)
save_config(cfg)
except Exception as exc: # noqa: BLE001
LOGGER.exception("保存权重失败", extra={**LOG_EXTRA, "error": str(exc)})
st.error(f"写入配置失败:{exc}")
else:
st.success("代理权重已写入 config.json")
nav_series = info.get("nav_series")
if nav_series:
try:
nav_df = pd.DataFrame(nav_series)
if {"trade_date", "nav"}.issubset(nav_df.columns):
nav_df = nav_df.sort_values("trade_date")
nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"])
st.line_chart(nav_df.set_index("trade_date")["nav"], height=220)
except Exception: # noqa: BLE001
LOGGER.debug("导航曲线绘制失败", extra=LOG_EXTRA)
trades = info.get("trades")
if trades:
st.write("成交记录:")
st.dataframe(pd.DataFrame(trades), hide_index=True, width='stretch')
if single_result.get("logged"):
st.caption("调参结果已写入 tuning_results 表。")
nav_series = single_result.get("nav_series") or []
if nav_series:
try:
nav_df = pd.DataFrame(nav_series)
if {"trade_date", "nav"}.issubset(nav_df.columns):
nav_df = nav_df.sort_values("trade_date")
nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"])
st.line_chart(nav_df.set_index("trade_date")["nav"], height=220)
except Exception: # noqa: BLE001
LOGGER.debug("导航曲线绘制失败", extra=LOG_EXTRA)
trades = single_result.get("trades") or []
if trades:
st.write("成交记录:")
st.dataframe(pd.DataFrame(trades), hide_index=True, width='stretch')
if st.button("清除单次调参结果", key="clear_decision_env_single"):
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
st.success("已清除单次调参结果缓存。")
st.divider()
st.caption("批量调参:在下方输入多组动作,每行表示一组 0-1 之间的值,用逗号分隔。")
@ -954,16 +992,28 @@ def render_backtest() -> None:
key="decision_env_batch_actions",
)
run_batch = st.button("批量执行调参", key="run_decision_env_batch")
batch_just_ran = False
if run_batch:
if not selected_agents:
st.warning("请先选择调参代理。")
elif not range_valid:
st.error("请确保所有代理的最大权重大于最小权重。")
else:
LOGGER.info(
"离线调参(批量)按钮点击,已选择代理=%s disable_departments=%s",
selected_agents,
disable_departments,
extra=LOG_EXTRA,
)
lines = [line.strip() for line in action_grid_raw.splitlines() if line.strip()]
if not lines:
st.warning("请在文本框中输入至少一组动作。")
else:
LOGGER.debug(
"离线调参(批量)原始输入=%s",
lines,
extra=LOG_EXTRA,
)
parsed_actions: List[List[float]] = []
for line in lines:
try:
@ -978,6 +1028,11 @@ def render_backtest() -> None:
break
parsed_actions.append(values)
if parsed_actions:
LOGGER.info(
"离线调参(批量)解析动作成功,数量=%s",
len(parsed_actions),
extra=LOG_EXTRA,
)
baseline_weights = cfg.agent_weights.as_dict()
for agent in agent_objects:
baseline_weights.setdefault(agent.name, 1.0)
@ -1006,6 +1061,14 @@ def render_backtest() -> None:
disable_departments=disable_departments,
)
results: List[Dict[str, object]] = []
resolved_experiment_id = experiment_id or str(uuid.uuid4())
resolved_strategy = strategy_label or "DecisionEnv"
LOGGER.debug(
"离线调参(批量)启动 DecisionEnvcfg=%s 动作组=%s",
bt_cfg_env,
len(parsed_actions),
extra=LOG_EXTRA,
)
with st.spinner("正在批量执行调参……"):
for idx, action_vals in enumerate(parsed_actions, start=1):
env.reset()
@ -1032,6 +1095,13 @@ def render_backtest() -> None:
}
)
else:
LOGGER.info(
"离线调参(批量)第 %s 组完成reward=%.4f obs=%s",
idx,
reward,
observation,
extra=LOG_EXTRA,
)
action_payload = {
name: value
for name, value in zip(selected_agents, action_vals)
@ -1041,8 +1111,8 @@ def render_backtest() -> None:
weights_payload = info.get("weights", {})
try:
log_tuning_result(
experiment_id=experiment_id or str(uuid.uuid4()),
strategy=strategy_label or "DecisionEnv",
experiment_id=resolved_experiment_id,
strategy=resolved_strategy,
action=action_payload,
reward=reward,
metrics=metrics_payload,
@ -1062,46 +1132,74 @@ def render_backtest() -> None:
"权重": weights_payload,
}
)
if results:
st.write("批量调参结果:")
results_df = pd.DataFrame(results)
st.dataframe(results_df, hide_index=True, width='stretch')
selectable = [
st.session_state[_DECISION_ENV_BATCH_RESULTS_KEY] = {
"results": results,
"selectable": [
row
for row in results
if row.get("状态") == "ok" and row.get("权重")
]
if selectable:
option_labels = [
f"序号 {row['序号']} | 奖励 {row.get('奖励', 0.0):+.4f}"
for row in selectable
]
selected_label = st.selectbox(
"选择要保存的记录",
option_labels,
key="decision_env_batch_select",
)
selected_row = None
for label, row in zip(option_labels, selectable):
if label == selected_label:
selected_row = row
break
if selected_row and st.button(
"保存所选权重为默认配置",
key="save_decision_env_weights_batch",
):
try:
cfg.agent_weights.update_from_dict(selected_row.get("权重", {}))
save_config(cfg)
except Exception as exc: # noqa: BLE001
LOGGER.exception("批量保存权重失败", extra={**LOG_EXTRA, "error": str(exc)})
st.error(f"写入配置失败:{exc}")
else:
st.success(
f"已将序号 {selected_row['序号']} 的权重写入 config.json"
)
else:
st.caption("暂无成功的结果可供保存。")
],
"experiment_id": resolved_experiment_id,
"strategy_label": resolved_strategy,
}
batch_just_ran = True
LOGGER.info(
"离线调参(批量)执行结束,总结果条数=%s",
len(results),
extra=LOG_EXTRA,
)
batch_state = st.session_state.get(_DECISION_ENV_BATCH_RESULTS_KEY)
if batch_state:
results = batch_state.get("results") or []
if results:
if batch_just_ran:
st.success("批量调参完成")
else:
st.success("批量调参结果(最近一次运行)")
st.caption(
f"实验 ID{batch_state.get('experiment_id', '-') } | 策略:{batch_state.get('strategy_label', 'DecisionEnv')}"
)
results_df = pd.DataFrame(results)
st.write("批量调参结果:")
st.dataframe(results_df, hide_index=True, width='stretch')
selectable = batch_state.get("selectable") or []
if selectable:
option_labels = [
f"序号 {row['序号']} | 奖励 {row.get('奖励', 0.0):+.4f}"
for row in selectable
]
selected_label = st.selectbox(
"选择要保存的记录",
option_labels,
key="decision_env_batch_select",
)
selected_row = None
for label, row in zip(option_labels, selectable):
if label == selected_label:
selected_row = row
break
if selected_row and st.button(
"保存所选权重为默认配置",
key="save_decision_env_weights_batch",
):
try:
cfg.agent_weights.update_from_dict(selected_row.get("权重", {}))
save_config(cfg)
except Exception as exc: # noqa: BLE001
LOGGER.exception("批量保存权重失败", extra={**LOG_EXTRA, "error": str(exc)})
st.error(f"写入配置失败:{exc}")
else:
st.success(
f"已将序号 {selected_row['序号']} 的权重写入 config.json"
)
else:
st.caption("暂无成功的结果可供保存。")
else:
st.caption("批量调参在最近一次执行中未产生结果。")
if st.button("清除批量调参结果", key="clear_decision_env_batch"):
st.session_state.pop(_DECISION_ENV_BATCH_RESULTS_KEY, None)
st.session_state.pop("decision_env_batch_select", None)
st.success("已清除批量调参结果缓存。")
def render_settings() -> None:
@ -1673,7 +1771,7 @@ def render_tests() -> None:
]
)
candle_fig.update_layout(height=420, margin=dict(l=10, r=10, t=40, b=10))
st.plotly_chart(candle_fig, width='stretch')
st.plotly_chart(candle_fig, use_container_width=True)
vol_fig = px.bar(
df_reset,
@ -1683,7 +1781,7 @@ def render_tests() -> None:
title="成交量",
)
vol_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10))
st.plotly_chart(vol_fig, width='stretch')
st.plotly_chart(vol_fig, use_container_width=True)
amt_fig = px.bar(
df_reset,
@ -1693,7 +1791,7 @@ def render_tests() -> None:
title="成交额",
)
amt_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10))
st.plotly_chart(amt_fig, width='stretch')
st.plotly_chart(amt_fig, use_container_width=True)
df_reset["月份"] = df_reset["交易日"].dt.to_period("M").astype(str)
box_fig = px.box(
@ -1704,7 +1802,7 @@ def render_tests() -> None:
title="月度收盘价分布",
)
box_fig.update_layout(height=320, margin=dict(l=10, r=10, t=40, b=10))
st.plotly_chart(box_fig, width='stretch')
st.plotly_chart(box_fig, use_container_width=True)
st.caption("提示:成交量单位为手,成交额以千元显示。箱线图按月展示收盘价分布。")
st.dataframe(df_reset.tail(20), width='stretch')