diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py index fba67d3..342d1d4 100644 --- a/app/ui/streamlit_app.py +++ b/app/ui/streamlit_app.py @@ -58,6 +58,7 @@ from app.utils.portfolio import ( ) from app.agents.registry import default_agents from app.utils.tuning import log_tuning_result +from app.backtest.engine import BacktestEngine, PortfolioState LOGGER = get_logger(__name__) @@ -67,6 +68,9 @@ _DECISION_ENV_BATCH_RESULTS_KEY = "decision_env_batch_results" _DASHBOARD_CONTAINERS: Optional[tuple[object, object]] = None _DASHBOARD_ELEMENTS: Optional[Dict[str, object]] = None _SIDEBAR_LISTENER_ATTACHED = False +# ADD: simple in-memory cache for provider model discovery +_MODEL_CACHE: Dict[str, Dict[str, object]] = {} +_CACHE_TTL_SECONDS = 300 def _sidebar_metrics_listener(metrics: Dict[str, object]) -> None: @@ -210,6 +214,16 @@ def _discover_provider_models(provider: LLMProvider, base_override: str = "", ap timeout = float(provider.default_timeout or 30.0) mode = provider.mode or ("ollama" if provider.key == "ollama" else "openai") + # ADD: simple cache by provider+base URL + cache_key = f"{provider.key}|{base_url}" + now = datetime.now() + cached = _MODEL_CACHE.get(cache_key) + if cached: + ts = cached.get("ts") + if isinstance(ts, float) and (now.timestamp() - ts) < _CACHE_TTL_SECONDS: + models = list(cached.get("models") or []) + return models, None + try: if mode == "ollama": url = base_url.rstrip('/') + "/api/tags" @@ -221,6 +235,7 @@ def _discover_provider_models(provider: LLMProvider, base_override: str = "", ap name = item.get("name") or item.get("model") or item.get("tag") if name: models.append(str(name).strip()) + _MODEL_CACHE[cache_key] = {"ts": now.timestamp(), "models": sorted(set(models))} return sorted(set(models)), None api_key = (api_override or provider.api_key or "").strip() @@ -239,6 +254,7 @@ def _discover_provider_models(provider: LLMProvider, base_override: str = "", ap for item in payload.get("data", []) if item.get("id") ] + _MODEL_CACHE[cache_key] = {"ts": now.timestamp(), "models": sorted(set(models))} return sorted(set(models)), None except RequestException as exc: # noqa: BLE001 return [], f"HTTP 错误:{exc}" @@ -345,7 +361,7 @@ def render_today_plan() -> None: if latest_trade_date: st.caption(f"最新交易日:{latest_trade_date.isoformat()}(统计数据请见左侧系统监控)") else: - st.caption("统计与决策概览现已移至左侧“系统监控”侧栏。") + st.caption("统计与决策概览现已移至左侧'系统监控'侧栏。") try: with db_session(read_only=True) as conn: date_rows = conn.execute( @@ -384,6 +400,8 @@ def render_today_plan() -> None: return ts_code = st.selectbox("标的", symbols, index=0) + # ADD: batch selection for re-evaluation + batch_symbols = st.multiselect("批量重评估(可多选)", symbols, default=[]) with db_session(read_only=True) as conn: rows = conn.execute( @@ -523,7 +541,16 @@ def render_today_plan() -> None: st.subheader("部门意见") if dept_records: - dept_df = pd.DataFrame(dept_records) + # ADD: keyword filter for department summaries + keyword = st.text_input("筛选摘要/信号关键词", value="") + filtered = dept_records + if keyword.strip(): + kw = keyword.strip() + filtered = [ + item for item in dept_records + if kw in str(item.get("摘要", "")) or kw in str(item.get("核心信号", "")) + ] + dept_df = pd.DataFrame(filtered) st.dataframe(dept_df, width='stretch', hide_index=True) for code, details in dept_details.items(): with st.expander(f"{code} 补充详情", expanded=False): @@ -636,6 +663,122 @@ def render_today_plan() -> None: st.caption("数据来源:agent_utils、investment_pool、portfolio_positions、portfolio_trades、portfolio_snapshots。") + st.divider() + st.subheader("策略重评估") + st.caption("对当前选中的交易日与标的,立即触发一次策略评估并回写 agent_utils。") + cols_re = st.columns([1,1]) + if cols_re[0].button("对该标的重评估", key="reevaluate_current_symbol"): + with st.spinner("正在重评估..."): + try: + trade_date_obj = None + try: + trade_date_obj = date.fromisoformat(str(trade_date)) + except Exception: + try: + trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date() + except Exception: + pass + if trade_date_obj is None: + raise ValueError(f"无法解析交易日:{trade_date}") + # snapshot before + with db_session(read_only=True) as conn: + before_rows = conn.execute( + """ + SELECT agent, action, utils FROM agent_utils + WHERE trade_date = ? AND ts_code = ? + """, + (trade_date, ts_code), + ).fetchall() + before_map = {row["agent"]: (row["action"], row["utils"]) for row in before_rows} + cfg = BtConfig( + id="reeval_ui", + name="UI Re-evaluation", + start_date=trade_date_obj, + end_date=trade_date_obj, + universe=[ts_code], + params={}, + ) + engine = BacktestEngine(cfg) + state = PortfolioState() + _ = engine.simulate_day(trade_date_obj, state) + # compare after + with db_session(read_only=True) as conn: + after_rows = conn.execute( + """ + SELECT agent, action, utils FROM agent_utils + WHERE trade_date = ? AND ts_code = ? + """, + (trade_date, ts_code), + ).fetchall() + changes = [] + for row in after_rows: + agent = row["agent"] + new_action = row["action"] + old_action, _old_utils = before_map.get(agent, (None, None)) + if new_action != old_action: + changes.append({"代理": agent, "原动作": old_action, "新动作": new_action}) + if changes: + st.success("重评估完成,检测到动作变更:") + st.dataframe(pd.DataFrame(changes), hide_index=True, width='stretch') + else: + st.success("重评估完成,无动作变更。") + st.rerun() + except Exception as exc: # noqa: BLE001 + LOGGER.exception("重评估失败", extra=LOG_EXTRA) + st.error(f"重评估失败:{exc}") + if cols_re[1].button("批量重评估(所选)", key="reevaluate_batch", disabled=not batch_symbols): + with st.spinner("批量重评估执行中..."): + try: + trade_date_obj = None + try: + trade_date_obj = date.fromisoformat(str(trade_date)) + except Exception: + try: + trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date() + except Exception: + pass + if trade_date_obj is None: + raise ValueError(f"无法解析交易日:{trade_date}") + progress = st.progress(0.0) + changes_all: List[Dict[str, object]] = [] + for idx, code in enumerate(batch_symbols, start=1): + with db_session(read_only=True) as conn: + before_rows = conn.execute( + "SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?", + (trade_date, code), + ).fetchall() + before_map = {row["agent"]: row["action"] for row in before_rows} + cfg = BtConfig( + id="reeval_ui_batch", + name="UI Batch Re-eval", + start_date=trade_date_obj, + end_date=trade_date_obj, + universe=[code], + params={}, + ) + engine = BacktestEngine(cfg) + state = PortfolioState() + _ = engine.simulate_day(trade_date_obj, state) + with db_session(read_only=True) as conn: + after_rows = conn.execute( + "SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?", + (trade_date, code), + ).fetchall() + for row in after_rows: + agent = row["agent"] + new_action = row["action"] + old_action = before_map.get(agent) + if new_action != old_action: + changes_all.append({"代码": code, "代理": agent, "原动作": old_action, "新动作": new_action}) + progress.progress(idx / max(1, len(batch_symbols))) + st.success("批量重评估完成。") + if changes_all: + st.dataframe(pd.DataFrame(changes_all), hide_index=True, width='stretch') + st.rerun() + except Exception as exc: # noqa: BLE001 + LOGGER.exception("批量重评估失败", extra=LOG_EXTRA) + st.error(f"批量重评估失败:{exc}") + def render_backtest() -> None: LOGGER.info("渲染回测页面", extra=LOG_EXTRA) @@ -1237,6 +1380,84 @@ def render_backtest() -> None: st.session_state.pop("decision_env_batch_select", None) st.success("已清除批量调参结果缓存。") + # ADD: Comparison view for multiple backtest configurations + with st.expander("回测结果对比", expanded=False): + st.caption("从历史回测配置中选择多个进行净值曲线与指标对比。") + normalize_to_one = st.checkbox("归一化到 1 起点", value=True) + use_log_y = st.checkbox("对数坐标", value=False) + metric_options = ["总收益", "最大回撤", "交易数", "平均换手", "风险事件"] + selected_metrics = st.multiselect("显示指标列", metric_options, default=metric_options) + try: + with db_session(read_only=True) as conn: + cfg_rows = conn.execute( + "SELECT id, name FROM bt_config ORDER BY rowid DESC LIMIT 50" + ).fetchall() + except Exception: # noqa: BLE001 + LOGGER.exception("读取 bt_config 失败", extra=LOG_EXTRA) + cfg_rows = [] + cfg_options = [f"{row['id']} | {row['name']}" for row in cfg_rows] + selected_labels = st.multiselect("选择配置", cfg_options, default=cfg_options[:2]) + selected_ids = [label.split(" | ")[0].strip() for label in selected_labels] + if selected_ids: + try: + with db_session(read_only=True) as conn: + nav_df = pd.read_sql_query( + "SELECT cfg_id, trade_date, nav FROM bt_nav WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))), + conn, + params=tuple(selected_ids), + ) + rpt_df = pd.read_sql_query( + "SELECT cfg_id, summary FROM bt_report WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))), + conn, + params=tuple(selected_ids), + ) + except Exception: # noqa: BLE001 + LOGGER.exception("读取回测结果失败", extra=LOG_EXTRA) + st.error("读取回测结果失败") + nav_df = pd.DataFrame() + rpt_df = pd.DataFrame() + if not nav_df.empty: + try: + nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"], errors="coerce") + pivot = nav_df.pivot_table(index="trade_date", columns="cfg_id", values="nav") + if normalize_to_one: + pivot = pivot.apply(lambda s: s / s.dropna().iloc[0] if s.dropna().size else s) + import plotly.graph_objects as go + fig = go.Figure() + for col in pivot.columns: + fig.add_trace(go.Scatter(x=pivot.index, y=pivot[col], mode="lines", name=str(col))) + fig.update_layout(height=300, margin=dict(l=10, r=10, t=30, b=10)) + if use_log_y: + fig.update_yaxes(type="log") + st.plotly_chart(fig, use_container_width=True) + except Exception: # noqa: BLE001 + LOGGER.debug("绘制对比曲线失败", extra=LOG_EXTRA) + if not rpt_df.empty: + try: + metrics_rows: List[Dict[str, object]] = [] + for _, row in rpt_df.iterrows(): + cfg_id = row["cfg_id"] + try: + summary = json.loads(row["summary"]) if isinstance(row["summary"], str) else (row["summary"] or {}) + except json.JSONDecodeError: + summary = {} + record = { + "cfg_id": cfg_id, + "总收益": summary.get("total_return"), + "最大回撤": summary.get("max_drawdown"), + "交易数": summary.get("trade_count"), + "平均换手": summary.get("avg_turnover"), + "风险事件": summary.get("risk_events"), + } + metrics_rows.append({k: v for k, v in record.items() if (k == "cfg_id" or k in selected_metrics)}) + if metrics_rows: + dfm = pd.DataFrame(metrics_rows) + st.dataframe(dfm, hide_index=True, width='stretch') + except Exception: # noqa: BLE001 + LOGGER.debug("渲染指标表失败", extra=LOG_EXTRA) + else: + st.info("请选择至少一个配置进行对比。") + def render_settings() -> None: LOGGER.info("渲染设置页面", extra=LOG_EXTRA)