diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py index 874a133..7fb6aa8 100644 --- a/app/ui/streamlit_app.py +++ b/app/ui/streamlit_app.py @@ -2,7 +2,6 @@ from __future__ import annotations import sys -import time from dataclasses import asdict from datetime import date, datetime, timedelta from pathlib import Path @@ -62,58 +61,42 @@ from app.utils.tuning import log_tuning_result LOGGER = get_logger(__name__) LOG_EXTRA = {"stage": "ui"} -_SIDEBAR_THROTTLE_SECONDS = 0.75 - - -def _sidebar_metrics_listener(metrics: Dict[str, object]) -> None: - _update_dashboard_sidebar(metrics, throttled=True) +_DECISION_ENV_SINGLE_RESULT_KEY = "decision_env_single_result" +_DECISION_ENV_BATCH_RESULTS_KEY = "decision_env_batch_results" +_DASHBOARD_CONTAINERS: Optional[tuple[object, object]] = None +_DASHBOARD_ELEMENTS: Optional[Dict[str, object]] = None def render_global_dashboard() -> None: """Render a persistent sidebar with realtime LLM stats and recent decisions.""" + global _DASHBOARD_CONTAINERS + global _DASHBOARD_ELEMENTS + metrics_container = st.sidebar.container() decisions_container = st.sidebar.container() - st.session_state["dashboard_containers"] = (metrics_container, decisions_container) - _ensure_dashboard_elements(metrics_container, decisions_container) - if not st.session_state.get("dashboard_listener_registered"): - register_llm_metrics_listener(_sidebar_metrics_listener) - st.session_state["dashboard_listener_registered"] = True + _DASHBOARD_CONTAINERS = (metrics_container, decisions_container) + _DASHBOARD_ELEMENTS = _ensure_dashboard_elements(metrics_container, decisions_container) _update_dashboard_sidebar() def _update_dashboard_sidebar( metrics: Optional[Dict[str, object]] = None, - *, - throttled: bool = False, ) -> None: - containers = st.session_state.get("dashboard_containers") + global _DASHBOARD_CONTAINERS + global _DASHBOARD_ELEMENTS + + containers = _DASHBOARD_CONTAINERS if not containers: return metrics_container, decisions_container = containers - elements = st.session_state.get("dashboard_elements") + elements = _DASHBOARD_ELEMENTS if elements is None: elements = _ensure_dashboard_elements(metrics_container, decisions_container) - - if throttled: - now = time.monotonic() - last_update = st.session_state.get("dashboard_last_update_ts", 0.0) - if now - last_update < _SIDEBAR_THROTTLE_SECONDS: - if metrics is not None: - st.session_state["dashboard_pending_metrics"] = metrics - return - st.session_state["dashboard_last_update_ts"] = now - else: - st.session_state["dashboard_last_update_ts"] = time.monotonic() + _DASHBOARD_ELEMENTS = elements if metrics is None: - metrics = st.session_state.pop("dashboard_pending_metrics", None) - if metrics is None: - metrics = snapshot_llm_metrics() - else: - st.session_state.pop("dashboard_pending_metrics", None) - - metrics = metrics or snapshot_llm_metrics() + metrics = snapshot_llm_metrics() elements["metrics_calls"].metric("LLM 调用", metrics.get("total_calls", 0)) elements["metrics_prompt"].metric("Prompt Tokens", metrics.get("total_prompt_tokens", 0)) @@ -160,10 +143,6 @@ def _update_dashboard_sidebar( def _ensure_dashboard_elements(metrics_container, decisions_container) -> Dict[str, object]: - elements = st.session_state.get("dashboard_elements") - if elements: - return elements - metrics_container.header("系统监控") col_a, col_b, col_c = metrics_container.columns(3) metrics_calls = col_a.empty() @@ -184,7 +163,6 @@ def _ensure_dashboard_elements(metrics_container, decisions_container) -> Dict[s "model_distribution": model_distribution, "decisions_list": decisions_list, } - st.session_state["dashboard_elements"] = elements return elements def _discover_provider_models(provider: LLMProvider, base_override: str = "", api_override: Optional[str] = None) -> tuple[list[str], Optional[str]]: @@ -835,12 +813,20 @@ def render_backtest() -> None: action_values.append(action_val) run_decision_env = st.button("执行单次调参", key="run_decision_env_button") + just_finished_single = False if run_decision_env: if not selected_agents: st.warning("请至少选择一个代理进行调参。") elif not range_valid: st.error("请确保所有代理的最大权重大于最小权重。") else: + LOGGER.info( + "离线调参(单次)按钮点击,已选择代理=%s 动作=%s disable_departments=%s", + selected_agents, + action_values, + disable_departments, + extra=LOG_EXTRA, + ) baseline_weights = cfg.agent_weights.as_dict() for agent in agent_objects: baseline_weights.setdefault(agent.name, 1.0) @@ -869,74 +855,126 @@ def render_backtest() -> None: disable_departments=disable_departments, ) env.reset() + LOGGER.debug( + "离线调参(单次)启动 DecisionEnv:cfg=%s 参数维度=%s", + bt_cfg_env, + len(specs), + extra=LOG_EXTRA, + ) with st.spinner("正在执行离线调参……"): try: observation, reward, done, info = env.step(action_values) + LOGGER.info( + "离线调参(单次)完成,obs=%s reward=%.4f done=%s", + observation, + reward, + done, + extra=LOG_EXTRA, + ) except Exception as exc: # noqa: BLE001 LOGGER.exception("DecisionEnv 调用失败", extra=LOG_EXTRA) st.error(f"离线调参失败:{exc}") + st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None) else: if observation.get("failure"): st.error("调参失败:回测执行未完成,可能是 LLM 网络不可用或参数异常。") st.json(observation) + st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None) else: - st.success("离线调参完成") - col_metrics = st.columns(4) - col_metrics[0].metric("总收益", f"{observation.get('total_return', 0.0):+.2%}") - col_metrics[1].metric("最大回撤", f"{observation.get('max_drawdown', 0.0):+.2%}") - col_metrics[2].metric("波动率", f"{observation.get('volatility', 0.0):+.2%}") - col_metrics[3].metric("奖励", f"{reward:+.4f}") - - st.write("调参后权重:") - weights_dict = info.get("weights", {}) - st.json(weights_dict) + resolved_experiment_id = experiment_id or str(uuid.uuid4()) + resolved_strategy = strategy_label or "DecisionEnv" action_payload = { name: value for name, value in zip(selected_agents, action_values) } metrics_payload = dict(observation) metrics_payload["reward"] = reward + log_success = False try: log_tuning_result( - experiment_id=experiment_id or str(uuid.uuid4()), - strategy=strategy_label or "DecisionEnv", + experiment_id=resolved_experiment_id, + strategy=resolved_strategy, action=action_payload, reward=reward, metrics=metrics_payload, - weights=weights_dict, + weights=info.get("weights", {}), ) - st.caption("调参结果已写入 tuning_results 表。") except Exception: # noqa: BLE001 LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA) + else: + log_success = True + LOGGER.info( + "离线调参(单次)日志写入成功:experiment=%s strategy=%s", + resolved_experiment_id, + resolved_strategy, + extra=LOG_EXTRA, + ) + st.session_state[_DECISION_ENV_SINGLE_RESULT_KEY] = { + "observation": dict(observation), + "reward": float(reward), + "weights": info.get("weights", {}), + "nav_series": info.get("nav_series"), + "trades": info.get("trades"), + "selected_agents": list(selected_agents), + "action_values": list(action_values), + "experiment_id": resolved_experiment_id, + "strategy_label": resolved_strategy, + "logged": log_success, + } + just_finished_single = True + single_result = st.session_state.get(_DECISION_ENV_SINGLE_RESULT_KEY) + if single_result: + if just_finished_single: + st.success("离线调参完成") + else: + st.success("离线调参结果(最近一次运行)") + st.caption( + f"实验 ID:{single_result.get('experiment_id', '-') } | 策略:{single_result.get('strategy_label', 'DecisionEnv')}" + ) + observation = single_result.get("observation", {}) + reward = float(single_result.get("reward", 0.0)) + col_metrics = st.columns(4) + col_metrics[0].metric("总收益", f"{observation.get('total_return', 0.0):+.2%}") + col_metrics[1].metric("最大回撤", f"{observation.get('max_drawdown', 0.0):+.2%}") + col_metrics[2].metric("波动率", f"{observation.get('volatility', 0.0):+.2%}") + col_metrics[3].metric("奖励", f"{reward:+.4f}") - if weights_dict: - if st.button( - "保存这些权重为默认配置", - key="save_decision_env_weights_single", - ): - try: - cfg.agent_weights.update_from_dict(weights_dict) - save_config(cfg) - except Exception as exc: # noqa: BLE001 - LOGGER.exception("保存权重失败", extra={**LOG_EXTRA, "error": str(exc)}) - st.error(f"写入配置失败:{exc}") - else: - st.success("代理权重已写入 config.json") + weights_dict = single_result.get("weights") or {} + if weights_dict: + st.write("调参后权重:") + st.json(weights_dict) + if st.button("保存这些权重为默认配置", key="save_decision_env_weights_single"): + try: + cfg.agent_weights.update_from_dict(weights_dict) + save_config(cfg) + except Exception as exc: # noqa: BLE001 + LOGGER.exception("保存权重失败", extra={**LOG_EXTRA, "error": str(exc)}) + st.error(f"写入配置失败:{exc}") + else: + st.success("代理权重已写入 config.json") - nav_series = info.get("nav_series") - if nav_series: - try: - nav_df = pd.DataFrame(nav_series) - if {"trade_date", "nav"}.issubset(nav_df.columns): - nav_df = nav_df.sort_values("trade_date") - nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"]) - st.line_chart(nav_df.set_index("trade_date")["nav"], height=220) - except Exception: # noqa: BLE001 - LOGGER.debug("导航曲线绘制失败", extra=LOG_EXTRA) - trades = info.get("trades") - if trades: - st.write("成交记录:") - st.dataframe(pd.DataFrame(trades), hide_index=True, width='stretch') + if single_result.get("logged"): + st.caption("调参结果已写入 tuning_results 表。") + + nav_series = single_result.get("nav_series") or [] + if nav_series: + try: + nav_df = pd.DataFrame(nav_series) + if {"trade_date", "nav"}.issubset(nav_df.columns): + nav_df = nav_df.sort_values("trade_date") + nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"]) + st.line_chart(nav_df.set_index("trade_date")["nav"], height=220) + except Exception: # noqa: BLE001 + LOGGER.debug("导航曲线绘制失败", extra=LOG_EXTRA) + + trades = single_result.get("trades") or [] + if trades: + st.write("成交记录:") + st.dataframe(pd.DataFrame(trades), hide_index=True, width='stretch') + + if st.button("清除单次调参结果", key="clear_decision_env_single"): + st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None) + st.success("已清除单次调参结果缓存。") st.divider() st.caption("批量调参:在下方输入多组动作,每行表示一组 0-1 之间的值,用逗号分隔。") @@ -954,16 +992,28 @@ def render_backtest() -> None: key="decision_env_batch_actions", ) run_batch = st.button("批量执行调参", key="run_decision_env_batch") + batch_just_ran = False if run_batch: if not selected_agents: st.warning("请先选择调参代理。") elif not range_valid: st.error("请确保所有代理的最大权重大于最小权重。") else: + LOGGER.info( + "离线调参(批量)按钮点击,已选择代理=%s disable_departments=%s", + selected_agents, + disable_departments, + extra=LOG_EXTRA, + ) lines = [line.strip() for line in action_grid_raw.splitlines() if line.strip()] if not lines: st.warning("请在文本框中输入至少一组动作。") else: + LOGGER.debug( + "离线调参(批量)原始输入=%s", + lines, + extra=LOG_EXTRA, + ) parsed_actions: List[List[float]] = [] for line in lines: try: @@ -978,6 +1028,11 @@ def render_backtest() -> None: break parsed_actions.append(values) if parsed_actions: + LOGGER.info( + "离线调参(批量)解析动作成功,数量=%s", + len(parsed_actions), + extra=LOG_EXTRA, + ) baseline_weights = cfg.agent_weights.as_dict() for agent in agent_objects: baseline_weights.setdefault(agent.name, 1.0) @@ -1006,6 +1061,14 @@ def render_backtest() -> None: disable_departments=disable_departments, ) results: List[Dict[str, object]] = [] + resolved_experiment_id = experiment_id or str(uuid.uuid4()) + resolved_strategy = strategy_label or "DecisionEnv" + LOGGER.debug( + "离线调参(批量)启动 DecisionEnv:cfg=%s 动作组=%s", + bt_cfg_env, + len(parsed_actions), + extra=LOG_EXTRA, + ) with st.spinner("正在批量执行调参……"): for idx, action_vals in enumerate(parsed_actions, start=1): env.reset() @@ -1032,6 +1095,13 @@ def render_backtest() -> None: } ) else: + LOGGER.info( + "离线调参(批量)第 %s 组完成,reward=%.4f obs=%s", + idx, + reward, + observation, + extra=LOG_EXTRA, + ) action_payload = { name: value for name, value in zip(selected_agents, action_vals) @@ -1041,8 +1111,8 @@ def render_backtest() -> None: weights_payload = info.get("weights", {}) try: log_tuning_result( - experiment_id=experiment_id or str(uuid.uuid4()), - strategy=strategy_label or "DecisionEnv", + experiment_id=resolved_experiment_id, + strategy=resolved_strategy, action=action_payload, reward=reward, metrics=metrics_payload, @@ -1062,46 +1132,74 @@ def render_backtest() -> None: "权重": weights_payload, } ) - if results: - st.write("批量调参结果:") - results_df = pd.DataFrame(results) - st.dataframe(results_df, hide_index=True, width='stretch') - selectable = [ + st.session_state[_DECISION_ENV_BATCH_RESULTS_KEY] = { + "results": results, + "selectable": [ row for row in results if row.get("状态") == "ok" and row.get("权重") - ] - if selectable: - option_labels = [ - f"序号 {row['序号']} | 奖励 {row.get('奖励', 0.0):+.4f}" - for row in selectable - ] - selected_label = st.selectbox( - "选择要保存的记录", - option_labels, - key="decision_env_batch_select", - ) - selected_row = None - for label, row in zip(option_labels, selectable): - if label == selected_label: - selected_row = row - break - if selected_row and st.button( - "保存所选权重为默认配置", - key="save_decision_env_weights_batch", - ): - try: - cfg.agent_weights.update_from_dict(selected_row.get("权重", {})) - save_config(cfg) - except Exception as exc: # noqa: BLE001 - LOGGER.exception("批量保存权重失败", extra={**LOG_EXTRA, "error": str(exc)}) - st.error(f"写入配置失败:{exc}") - else: - st.success( - f"已将序号 {selected_row['序号']} 的权重写入 config.json" - ) - else: - st.caption("暂无成功的结果可供保存。") + ], + "experiment_id": resolved_experiment_id, + "strategy_label": resolved_strategy, + } + batch_just_ran = True + LOGGER.info( + "离线调参(批量)执行结束,总结果条数=%s", + len(results), + extra=LOG_EXTRA, + ) + batch_state = st.session_state.get(_DECISION_ENV_BATCH_RESULTS_KEY) + if batch_state: + results = batch_state.get("results") or [] + if results: + if batch_just_ran: + st.success("批量调参完成") + else: + st.success("批量调参结果(最近一次运行)") + st.caption( + f"实验 ID:{batch_state.get('experiment_id', '-') } | 策略:{batch_state.get('strategy_label', 'DecisionEnv')}" + ) + results_df = pd.DataFrame(results) + st.write("批量调参结果:") + st.dataframe(results_df, hide_index=True, width='stretch') + selectable = batch_state.get("selectable") or [] + if selectable: + option_labels = [ + f"序号 {row['序号']} | 奖励 {row.get('奖励', 0.0):+.4f}" + for row in selectable + ] + selected_label = st.selectbox( + "选择要保存的记录", + option_labels, + key="decision_env_batch_select", + ) + selected_row = None + for label, row in zip(option_labels, selectable): + if label == selected_label: + selected_row = row + break + if selected_row and st.button( + "保存所选权重为默认配置", + key="save_decision_env_weights_batch", + ): + try: + cfg.agent_weights.update_from_dict(selected_row.get("权重", {})) + save_config(cfg) + except Exception as exc: # noqa: BLE001 + LOGGER.exception("批量保存权重失败", extra={**LOG_EXTRA, "error": str(exc)}) + st.error(f"写入配置失败:{exc}") + else: + st.success( + f"已将序号 {selected_row['序号']} 的权重写入 config.json" + ) + else: + st.caption("暂无成功的结果可供保存。") + else: + st.caption("批量调参在最近一次执行中未产生结果。") + if st.button("清除批量调参结果", key="clear_decision_env_batch"): + st.session_state.pop(_DECISION_ENV_BATCH_RESULTS_KEY, None) + st.session_state.pop("decision_env_batch_select", None) + st.success("已清除批量调参结果缓存。") def render_settings() -> None: @@ -1673,7 +1771,7 @@ def render_tests() -> None: ] ) candle_fig.update_layout(height=420, margin=dict(l=10, r=10, t=40, b=10)) - st.plotly_chart(candle_fig, width='stretch') + st.plotly_chart(candle_fig, use_container_width=True) vol_fig = px.bar( df_reset, @@ -1683,7 +1781,7 @@ def render_tests() -> None: title="成交量", ) vol_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10)) - st.plotly_chart(vol_fig, width='stretch') + st.plotly_chart(vol_fig, use_container_width=True) amt_fig = px.bar( df_reset, @@ -1693,7 +1791,7 @@ def render_tests() -> None: title="成交额", ) amt_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10)) - st.plotly_chart(amt_fig, width='stretch') + st.plotly_chart(amt_fig, use_container_width=True) df_reset["月份"] = df_reset["交易日"].dt.to_period("M").astype(str) box_fig = px.box( @@ -1704,7 +1802,7 @@ def render_tests() -> None: title="月度收盘价分布", ) box_fig.update_layout(height=320, margin=dict(l=10, r=10, t=40, b=10)) - st.plotly_chart(box_fig, width='stretch') + st.plotly_chart(box_fig, use_container_width=True) st.caption("提示:成交量单位为手,成交额以千元显示。箱线图按月展示收盘价分布。") st.dataframe(df_reset.tail(20), width='stretch')