"""Streamlit UI scaffold for the investment assistant.""" from __future__ import annotations import sys from dataclasses import asdict from datetime import date, datetime, timedelta from pathlib import Path from typing import Dict, List, Optional ROOT = Path(__file__).resolve().parents[2] if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) import json from datetime import datetime import uuid import pandas as pd import plotly.express as px import plotly.graph_objects as go import requests from requests.exceptions import RequestException import streamlit as st from app.agents.base import AgentContext from app.agents.game import Decision from app.backtest.engine import BtConfig, run_backtest from app.backtest.decision_env import DecisionEnv, ParameterSpec from app.data.schema import initialize_database from app.ingest.checker import run_boot_check from app.ingest.tushare import FetchJob, run_ingestion from app.llm.client import llm_config_snapshot, run_llm from app.llm.metrics import ( recent_decisions as llm_recent_decisions, register_listener as register_llm_metrics_listener, reset as reset_llm_metrics, snapshot as snapshot_llm_metrics, ) from app.utils import alerts from app.utils.config import ( ALLOWED_LLM_STRATEGIES, DEFAULT_LLM_BASE_URLS, DEFAULT_LLM_MODEL_OPTIONS, DEFAULT_LLM_MODELS, DepartmentSettings, LLMEndpoint, LLMProvider, get_config, save_config, ) from app.utils.db import db_session from app.utils.logging import get_logger from app.utils.portfolio import ( get_latest_snapshot, list_investment_pool, list_positions, list_recent_trades, ) from app.agents.registry import default_agents from app.utils.tuning import log_tuning_result from app.backtest.engine import BacktestEngine, PortfolioState LOGGER = get_logger(__name__) LOG_EXTRA = {"stage": "ui"} _DECISION_ENV_SINGLE_RESULT_KEY = "decision_env_single_result" _DECISION_ENV_BATCH_RESULTS_KEY = "decision_env_batch_results" _DASHBOARD_CONTAINERS: Optional[tuple[object, object]] = None _DASHBOARD_ELEMENTS: Optional[Dict[str, object]] = None _SIDEBAR_LISTENER_ATTACHED = False # ADD: simple in-memory cache for provider model discovery _MODEL_CACHE: Dict[str, Dict[str, object]] = {} _CACHE_TTL_SECONDS = 300 _WARNINGS_CONTAINER = None _WARNINGS_PLACEHOLDER = None # ADD: query param helpers def _get_query_params() -> Dict[str, List[str]]: try: return dict(st.query_params) except Exception: return {} def _set_query_params(**kwargs: object) -> None: try: payload = {k: v for k, v in kwargs.items() if v is not None} if payload: st.query_params.update(payload) except Exception: pass def _sidebar_metrics_listener(metrics: Dict[str, object]) -> None: try: _update_dashboard_sidebar(metrics) except Exception: # noqa: BLE001 LOGGER.debug("侧边栏监听器刷新失败", exc_info=True, extra=LOG_EXTRA) def render_global_dashboard() -> None: """Render a persistent sidebar with realtime LLM stats and recent decisions.""" global _DASHBOARD_CONTAINERS global _DASHBOARD_ELEMENTS global _SIDEBAR_LISTENER_ATTACHED global _WARNINGS_CONTAINER global _WARNINGS_PLACEHOLDER # ADD: warning badge on top warnings = alerts.get_warnings() badge = f" ({len(warnings)})" if warnings else "" st.sidebar.header(f"系统监控{badge}") metrics_container = st.sidebar.container() decisions_container = st.sidebar.container() _WARNINGS_CONTAINER = st.sidebar.container() _WARNINGS_PLACEHOLDER = st.sidebar.empty() _DASHBOARD_CONTAINERS = (metrics_container, decisions_container) _DASHBOARD_ELEMENTS = _ensure_dashboard_elements(metrics_container, decisions_container) if not _SIDEBAR_LISTENER_ATTACHED: register_llm_metrics_listener(_sidebar_metrics_listener) _SIDEBAR_LISTENER_ATTACHED = True _update_dashboard_sidebar() def _update_dashboard_sidebar( metrics: Optional[Dict[str, object]] = None, ) -> None: global _DASHBOARD_CONTAINERS global _DASHBOARD_ELEMENTS global _WARNINGS_CONTAINER global _WARNINGS_PLACEHOLDER containers = _DASHBOARD_CONTAINERS if not containers: return metrics_container, decisions_container = containers elements = _DASHBOARD_ELEMENTS if elements is None: elements = _ensure_dashboard_elements(metrics_container, decisions_container) _DASHBOARD_ELEMENTS = elements if metrics is None: metrics = snapshot_llm_metrics() elements["metrics_calls"].metric("LLM 调用", metrics.get("total_calls", 0)) elements["metrics_prompt"].metric("Prompt Tokens", metrics.get("total_prompt_tokens", 0)) elements["metrics_completion"].metric( "Completion Tokens", metrics.get("total_completion_tokens", 0) ) provider_calls = metrics.get("provider_calls", {}) model_calls = metrics.get("model_calls", {}) provider_placeholder = elements["provider_distribution"] provider_placeholder.empty() if provider_calls: provider_placeholder.json(provider_calls) else: provider_placeholder.info("暂无 Provider 分布数据。") model_placeholder = elements["model_distribution"] model_placeholder.empty() if model_calls: model_placeholder.json(model_calls) else: model_placeholder.info("暂无模型分布数据。") decisions = metrics.get("recent_decisions") or llm_recent_decisions(10) if decisions: lines = [] for record in reversed(decisions[-10:]): ts_code = record.get("ts_code") trade_date = record.get("trade_date") action = record.get("action") confidence = record.get("confidence", 0.0) summary = record.get("summary") line = f"**{trade_date} {ts_code}** → {action} (置信度 {confidence:.2f})" if summary: line += f"\n{summary}" lines.append(line) decisions_placeholder = elements["decisions_list"] decisions_placeholder.empty() decisions_placeholder.markdown("\n\n".join(lines), unsafe_allow_html=True) else: decisions_placeholder = elements["decisions_list"] decisions_placeholder.empty() decisions_placeholder.info("暂无决策记录。执行回测或实时评估后可在此查看。") # Render warnings section in-place (clear then write) if _WARNINGS_PLACEHOLDER is not None: _WARNINGS_PLACEHOLDER.empty() with _WARNINGS_PLACEHOLDER.container(): st.subheader("数据告警") warn_list = alerts.get_warnings() if warn_list: lines = [] for warning in warn_list[-10:]: detail = warning.get("detail") appendix = f" {detail}" if detail else "" lines.append( f"- **{warning['source']}** {warning['message']}{appendix}\n{warning['timestamp']}" ) st.markdown("\n".join(lines), unsafe_allow_html=True) btn_cols = st.columns([1,1]) if btn_cols[0].button("清除数据告警", key="clear_data_alerts_sibling"): alerts.clear_warnings() _update_dashboard_sidebar() try: st.download_button( "导出告警(JSON)", data=json.dumps(warn_list, ensure_ascii=False, indent=2), file_name="data_warnings.json", mime="application/json", key="dl_warnings_json_sibling", ) except Exception: pass else: st.info("暂无数据告警。") def _ensure_dashboard_elements(metrics_container, decisions_container) -> Dict[str, object]: metrics_container.header("系统监控") col_a, col_b, col_c = metrics_container.columns(3) metrics_calls = col_a.empty() metrics_prompt = col_b.empty() metrics_completion = col_c.empty() distribution_expander = metrics_container.expander("调用分布", expanded=False) provider_distribution = distribution_expander.empty() model_distribution = distribution_expander.empty() decisions_container.subheader("最新决策") decisions_list = decisions_container.empty() elements = { "metrics_calls": metrics_calls, "metrics_prompt": metrics_prompt, "metrics_completion": metrics_completion, "provider_distribution": provider_distribution, "model_distribution": model_distribution, "decisions_list": decisions_list, } return elements def _discover_provider_models(provider: LLMProvider, base_override: str = "", api_override: Optional[str] = None) -> tuple[list[str], Optional[str]]: """Attempt to query provider API and return available model ids.""" base_url = (base_override or provider.base_url or DEFAULT_LLM_BASE_URLS.get(provider.key, "")).strip() if not base_url: return [], "请先填写 Base URL" timeout = float(provider.default_timeout or 30.0) mode = provider.mode or ("ollama" if provider.key == "ollama" else "openai") # ADD: simple cache by provider+base URL cache_key = f"{provider.key}|{base_url}" now = datetime.now() cached = _MODEL_CACHE.get(cache_key) if cached: ts = cached.get("ts") if isinstance(ts, float) and (now.timestamp() - ts) < _CACHE_TTL_SECONDS: models = list(cached.get("models") or []) return models, None try: if mode == "ollama": url = base_url.rstrip('/') + "/api/tags" response = requests.get(url, timeout=timeout) response.raise_for_status() data = response.json() models = [] for item in data.get("models", []) or data.get("data", []): name = item.get("name") or item.get("model") or item.get("tag") if name: models.append(str(name).strip()) _MODEL_CACHE[cache_key] = {"ts": now.timestamp(), "models": sorted(set(models))} return sorted(set(models)), None api_key = (api_override or provider.api_key or "").strip() if not api_key: return [], "缺少 API Key" url = base_url.rstrip('/') + "/v1/models" headers = { "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", } response = requests.get(url, headers=headers, timeout=timeout) response.raise_for_status() payload = response.json() models = [ str(item.get("id")).strip() for item in payload.get("data", []) if item.get("id") ] _MODEL_CACHE[cache_key] = {"ts": now.timestamp(), "models": sorted(set(models))} return sorted(set(models)), None except RequestException as exc: # noqa: BLE001 return [], f"HTTP 错误:{exc}" except Exception as exc: # noqa: BLE001 return [], f"解析失败:{exc}" def _load_stock_options(limit: int = 500) -> list[str]: try: with db_session(read_only=True) as conn: rows = conn.execute( "SELECT ts_code, name FROM stock_basic WHERE list_status = 'L' ORDER BY ts_code" ).fetchall() except Exception: LOGGER.exception("加载股票列表失败", extra=LOG_EXTRA) return [] options: list[str] = [] for row in rows[:limit]: code = row["ts_code"] name = row["name"] or "" label = f"{code} | {name}" if name else code options.append(label) LOGGER.info("加载股票选项完成,数量=%s", len(options), extra=LOG_EXTRA) return options def _parse_ts_code(selection: str) -> str: return selection.split(' | ')[0].strip().upper() def _load_daily_frame(ts_code: str, start: date, end: date) -> pd.DataFrame: LOGGER.info( "加载行情数据:ts_code=%s start=%s end=%s", ts_code, start, end, extra=LOG_EXTRA, ) start_str = start.strftime('%Y%m%d') end_str = end.strftime('%Y%m%d') range_query = ( "SELECT trade_date, open, high, low, close, vol, amount " "FROM daily WHERE ts_code = ? AND trade_date BETWEEN ? AND ? ORDER BY trade_date" ) fallback_query = ( "SELECT trade_date, open, high, low, close, vol, amount " "FROM daily WHERE ts_code = ? ORDER BY trade_date DESC LIMIT 200" ) with db_session(read_only=True) as conn: df = pd.read_sql_query(range_query, conn, params=(ts_code, start_str, end_str)) if df.empty: df = pd.read_sql_query(fallback_query, conn, params=(ts_code,)) if df.empty: LOGGER.warning( "行情数据为空:ts_code=%s start=%s end=%s", ts_code, start, end, extra=LOG_EXTRA, ) return df df = df.sort_values('trade_date') df['trade_date'] = pd.to_datetime(df['trade_date']) df.set_index('trade_date', inplace=True) LOGGER.info("行情数据加载完成:条数=%s", len(df), extra=LOG_EXTRA) return df def _get_latest_trade_date() -> Optional[date]: try: with db_session(read_only=True) as conn: row = conn.execute( "SELECT trade_date FROM daily ORDER BY trade_date DESC LIMIT 1" ).fetchone() except Exception: # noqa: BLE001 LOGGER.exception("查询最新交易日失败", extra=LOG_EXTRA) return None if not row: return None raw_value = row["trade_date"] if not raw_value: return None try: return datetime.strptime(str(raw_value), "%Y%m%d").date() except ValueError: try: return datetime.fromisoformat(str(raw_value)).date() except ValueError: LOGGER.warning("无法解析交易日:%s", raw_value, extra=LOG_EXTRA) return None def _default_backtest_range(window_days: int = 60) -> tuple[date, date]: latest = _get_latest_trade_date() or date.today() start = latest - timedelta(days=window_days) if start > latest: start = latest return start, latest def render_today_plan() -> None: LOGGER.info("渲染今日计划页面", extra=LOG_EXTRA) st.header("今日计划") latest_trade_date = _get_latest_trade_date() if latest_trade_date: st.caption(f"最新交易日:{latest_trade_date.isoformat()}(统计数据请见左侧系统监控)") else: st.caption("统计与决策概览现已移至左侧'系统监控'侧栏。") try: with db_session(read_only=True) as conn: date_rows = conn.execute( """ SELECT DISTINCT trade_date FROM agent_utils ORDER BY trade_date DESC LIMIT 30 """ ).fetchall() except Exception: # noqa: BLE001 LOGGER.exception("加载 agent_utils 失败", extra=LOG_EXTRA) st.warning("暂未写入部门/代理决策,请先运行回测或策略评估流程。") return trade_dates = [row["trade_date"] for row in date_rows] if not trade_dates: st.info("暂无决策记录,完成一次回测后即可在此查看部门意见与投票结果。") return # ADD: read default selection from URL q = _get_query_params() default_trade_date = q.get("date", [trade_dates[0]])[0] try: default_idx = trade_dates.index(default_trade_date) except ValueError: default_idx = 0 trade_date = st.selectbox("交易日", trade_dates, index=default_idx) with db_session(read_only=True) as conn: code_rows = conn.execute( """ SELECT DISTINCT ts_code FROM agent_utils WHERE trade_date = ? ORDER BY ts_code """, (trade_date,), ).fetchall() symbols = [row["ts_code"] for row in code_rows] if not symbols: st.info("所选交易日暂无 agent_utils 记录。") return default_ts = q.get("code", [symbols[0]])[0] try: default_ts_idx = symbols.index(default_ts) except ValueError: default_ts_idx = 0 ts_code = st.selectbox("标的", symbols, index=default_ts_idx) # ADD: batch selection for re-evaluation batch_symbols = st.multiselect("批量重评估(可多选)", symbols, default=[]) # sync URL params _set_query_params(date=str(trade_date), code=str(ts_code)) with db_session(read_only=True) as conn: rows = conn.execute( """ SELECT agent, action, utils, feasible, weight FROM agent_utils WHERE trade_date = ? AND ts_code = ? ORDER BY CASE WHEN agent = 'global' THEN 1 ELSE 0 END, agent """, (trade_date, ts_code), ).fetchall() if not rows: st.info("未查询到详细决策记录,稍后再试。") return try: feasible_actions = json.loads(rows[0]["feasible"] or "[]") except (KeyError, TypeError, json.JSONDecodeError): feasible_actions = [] global_info = None dept_records: List[Dict[str, object]] = [] dept_details: Dict[str, Dict[str, object]] = {} agent_records: List[Dict[str, object]] = [] for item in rows: agent_name = item["agent"] action = item["action"] weight = float(item["weight"] or 0.0) try: utils = json.loads(item["utils"] or "{}") except json.JSONDecodeError: utils = {} if agent_name == "global": global_info = { "action": action, "confidence": float(utils.get("_confidence", 0.0)), "target_weight": float(utils.get("_target_weight", 0.0)), "department_votes": utils.get("_department_votes", {}), "requires_review": bool(utils.get("_requires_review", False)), "scope_values": utils.get("_scope_values", {}), "close_series": utils.get("_close_series", []), "turnover_series": utils.get("_turnover_series", []), "department_supplements": utils.get("_department_supplements", {}), "department_dialogue": utils.get("_department_dialogue", {}), "department_telemetry": utils.get("_department_telemetry", {}), } continue if agent_name.startswith("dept_"): code = agent_name.split("dept_", 1)[-1] signals = utils.get("_signals", []) risks = utils.get("_risks", []) supplements = utils.get("_supplements", []) dialogue = utils.get("_dialogue", []) telemetry = utils.get("_telemetry", {}) dept_records.append( { "部门": code, "行动": action, "信心": float(utils.get("_confidence", 0.0)), "权重": weight, "摘要": utils.get("_summary", ""), "核心信号": ";".join(signals) if isinstance(signals, list) else signals, "风险提示": ";".join(risks) if isinstance(risks, list) else risks, "补充次数": len(supplements) if isinstance(supplements, list) else 0, } ) dept_details[code] = { "supplements": supplements if isinstance(supplements, list) else [], "dialogue": dialogue if isinstance(dialogue, list) else [], "summary": utils.get("_summary", ""), "signals": signals, "risks": risks, "telemetry": telemetry if isinstance(telemetry, dict) else {}, } else: score_map = { key: float(val) for key, val in utils.items() if not str(key).startswith("_") } agent_records.append( { "代理": agent_name, "建议动作": action, "权重": weight, "SELL": score_map.get("SELL", 0.0), "HOLD": score_map.get("HOLD", 0.0), "BUY_S": score_map.get("BUY_S", 0.0), "BUY_M": score_map.get("BUY_M", 0.0), "BUY_L": score_map.get("BUY_L", 0.0), } ) if feasible_actions: st.caption(f"可行操作集合:{', '.join(feasible_actions)}") st.subheader("全局策略") if global_info: col1, col2, col3 = st.columns(3) col1.metric("最终行动", global_info["action"]) col2.metric("信心", f"{global_info['confidence']:.2f}") col3.metric("目标权重", f"{global_info['target_weight']:+.2%}") if global_info["department_votes"]: st.json(global_info["department_votes"]) if global_info["requires_review"]: st.warning("部门分歧较大,已标记为需人工复核。") with st.expander("基础上下文数据", expanded=False): # ADD: export buttons scope = global_info.get("scope_values") or {} close_series = global_info.get("close_series") or [] turnover_series = global_info.get("turnover_series") or [] st.write("最新字段:") if scope: st.json(scope) st.download_button( "下载字段(JSON)", data=json.dumps(scope, ensure_ascii=False, indent=2), file_name=f"{ts_code}_{trade_date}_scope.json", mime="application/json", key="dl_scope_json", ) if close_series: st.write("收盘价时间序列 (最近窗口):") st.json(close_series) try: import io, csv buf = io.StringIO() writer = csv.writer(buf) writer.writerow(["trade_date", "close"]) for dt, val in close_series: writer.writerow([dt, val]) st.download_button( "下载收盘价(CSV)", data=buf.getvalue(), file_name=f"{ts_code}_{trade_date}_close_series.csv", mime="text/csv", key="dl_close_csv", ) except Exception: pass if turnover_series: st.write("换手率时间序列 (最近窗口):") st.json(turnover_series) dept_sup = global_info.get("department_supplements") or {} dept_dialogue = global_info.get("department_dialogue") or {} dept_telemetry = global_info.get("department_telemetry") or {} if dept_sup or dept_dialogue: with st.expander("部门补数与对话记录", expanded=False): if dept_sup: st.write("补充数据:") st.json(dept_sup) if dept_dialogue: st.write("对话片段:") st.json(dept_dialogue) if dept_telemetry: with st.expander("部门 LLM 元数据", expanded=False): st.json(dept_telemetry) else: st.info("暂未写入全局策略摘要。") st.subheader("部门意见") if dept_records: # ADD: keyword filter for department summaries keyword = st.text_input("筛选摘要/信号关键词", value="") filtered = dept_records if keyword.strip(): kw = keyword.strip() filtered = [ item for item in dept_records if kw in str(item.get("摘要", "")) or kw in str(item.get("核心信号", "")) ] # ADD: confidence filter and sort min_conf = st.slider("最低信心过滤", 0.0, 1.0, 0.0, 0.05) sort_col = st.selectbox("排序列", ["信心", "权重"], index=0) filtered = [row for row in filtered if float(row.get("信心", 0.0)) >= min_conf] filtered = sorted(filtered, key=lambda r: float(r.get(sort_col, 0.0)), reverse=True) dept_df = pd.DataFrame(filtered) st.dataframe(dept_df, width='stretch', hide_index=True) try: st.download_button( "下载部门意见(CSV)", data=dept_df.to_csv(index=False), file_name=f"{trade_date}_{ts_code}_departments.csv", mime="text/csv", key="dl_dept_csv", ) except Exception: pass for code, details in dept_details.items(): with st.expander(f"{code} 补充详情", expanded=False): supplements = details.get("supplements", []) dialogue = details.get("dialogue", []) if supplements: st.write("补充数据:") st.json(supplements) else: st.caption("无补充数据请求。") if dialogue: st.write("对话记录:") for idx, line in enumerate(dialogue, start=1): st.markdown(f"**回合 {idx}:** {line}") else: st.caption("无额外对话。") telemetry = details.get("telemetry") or {} if telemetry: st.write("LLM 元数据:") st.json(telemetry) else: st.info("暂无部门记录。") st.subheader("代理评分") if agent_records: # ADD: sorting and CSV export for agents sort_agent_by = st.selectbox( "代理排序", ["权重", "SELL", "HOLD", "BUY_S", "BUY_M", "BUY_L"], index=1, ) agent_df = pd.DataFrame(agent_records) if sort_agent_by in agent_df.columns: agent_df = agent_df.sort_values(sort_agent_by, ascending=False) st.dataframe(agent_df, width='stretch', hide_index=True) try: st.download_button( "下载代理评分(CSV)", data=agent_df.to_csv(index=False), file_name=f"{trade_date}_{ts_code}_agents.csv", mime="text/csv", key="dl_agent_csv", ) except Exception: pass else: st.info("暂无基础代理评分。") st.divider() st.subheader("投资池与仓位概览") snapshot = get_latest_snapshot() if snapshot: col_a, col_b, col_c = st.columns(3) if snapshot.total_value is not None: col_a.metric("组合净值", f"{snapshot.total_value:,.2f}") if snapshot.cash is not None: col_b.metric("现金余额", f"{snapshot.cash:,.2f}") if snapshot.invested_value is not None: col_c.metric("持仓市值", f"{snapshot.invested_value:,.2f}") detail_cols = st.columns(4) if snapshot.unrealized_pnl is not None: detail_cols[0].metric("浮盈", f"{snapshot.unrealized_pnl:,.2f}") if snapshot.realized_pnl is not None: detail_cols[1].metric("已实现盈亏", f"{snapshot.realized_pnl:,.2f}") if snapshot.net_flow is not None: detail_cols[2].metric("净流入", f"{snapshot.net_flow:,.2f}") if snapshot.exposure is not None: detail_cols[3].metric("风险敞口", f"{snapshot.exposure:.2%}") if snapshot.notes: st.caption(f"备注:{snapshot.notes}") else: st.info("暂无组合快照,请在执行回测或实盘同步后写入 portfolio_snapshots。") candidates = list_investment_pool(trade_date=trade_date) if candidates: candidate_df = pd.DataFrame( [ { "交易日": item.trade_date, "代码": item.ts_code, "评分": item.score, "状态": item.status, "标签": "、".join(item.tags) if item.tags else "-", "理由": item.rationale or "", } for item in candidates ] ) st.write("候选投资池:") st.dataframe(candidate_df, width='stretch', hide_index=True) else: st.caption("候选投资池暂无数据。") positions = list_positions(active_only=False) if positions: position_df = pd.DataFrame( [ { "ID": pos.id, "代码": pos.ts_code, "开仓日": pos.opened_date, "平仓日": pos.closed_date or "-", "状态": pos.status, "数量": pos.quantity, "成本": pos.cost_price, "现价": pos.market_price, "市值": pos.market_value, "浮盈": pos.unrealized_pnl, "已实现": pos.realized_pnl, "目标权重": pos.target_weight, } for pos in positions ] ) st.write("组合持仓:") st.dataframe(position_df, width='stretch', hide_index=True) else: st.caption("组合持仓暂无记录。") trades = list_recent_trades(limit=20) if trades: trades_df = pd.DataFrame(trades) st.write("近期成交:") st.dataframe(trades_df, width='stretch', hide_index=True) else: st.caption("近期成交暂无记录。") st.caption("数据来源:agent_utils、investment_pool、portfolio_positions、portfolio_trades、portfolio_snapshots。") st.divider() st.subheader("策略重评估") st.caption("对当前选中的交易日与标的,立即触发一次策略评估并回写 agent_utils。") cols_re = st.columns([1,1]) if cols_re[0].button("对该标的重评估", key="reevaluate_current_symbol"): with st.spinner("正在重评估..."): try: trade_date_obj = None try: trade_date_obj = date.fromisoformat(str(trade_date)) except Exception: try: trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date() except Exception: pass if trade_date_obj is None: raise ValueError(f"无法解析交易日:{trade_date}") # snapshot before with db_session(read_only=True) as conn: before_rows = conn.execute( """ SELECT agent, action, utils FROM agent_utils WHERE trade_date = ? AND ts_code = ? """, (trade_date, ts_code), ).fetchall() before_map = {row["agent"]: (row["action"], row["utils"]) for row in before_rows} cfg = BtConfig( id="reeval_ui", name="UI Re-evaluation", start_date=trade_date_obj, end_date=trade_date_obj, universe=[ts_code], params={}, ) engine = BacktestEngine(cfg) state = PortfolioState() _ = engine.simulate_day(trade_date_obj, state) # compare after with db_session(read_only=True) as conn: after_rows = conn.execute( """ SELECT agent, action, utils FROM agent_utils WHERE trade_date = ? AND ts_code = ? """, (trade_date, ts_code), ).fetchall() changes = [] for row in after_rows: agent = row["agent"] new_action = row["action"] old_action, _old_utils = before_map.get(agent, (None, None)) if new_action != old_action: changes.append({"代理": agent, "原动作": old_action, "新动作": new_action}) if changes: st.success("重评估完成,检测到动作变更:") st.dataframe(pd.DataFrame(changes), hide_index=True, width='stretch') else: st.success("重评估完成,无动作变更。") st.rerun() except Exception as exc: # noqa: BLE001 LOGGER.exception("重评估失败", extra=LOG_EXTRA) st.error(f"重评估失败:{exc}") if cols_re[1].button("批量重评估(所选)", key="reevaluate_batch", disabled=not batch_symbols): with st.spinner("批量重评估执行中..."): try: trade_date_obj = None try: trade_date_obj = date.fromisoformat(str(trade_date)) except Exception: try: trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date() except Exception: pass if trade_date_obj is None: raise ValueError(f"无法解析交易日:{trade_date}") progress = st.progress(0.0) changes_all: List[Dict[str, object]] = [] for idx, code in enumerate(batch_symbols, start=1): with db_session(read_only=True) as conn: before_rows = conn.execute( "SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?", (trade_date, code), ).fetchall() before_map = {row["agent"]: row["action"] for row in before_rows} cfg = BtConfig( id="reeval_ui_batch", name="UI Batch Re-eval", start_date=trade_date_obj, end_date=trade_date_obj, universe=[code], params={}, ) engine = BacktestEngine(cfg) state = PortfolioState() _ = engine.simulate_day(trade_date_obj, state) with db_session(read_only=True) as conn: after_rows = conn.execute( "SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?", (trade_date, code), ).fetchall() for row in after_rows: agent = row["agent"] new_action = row["action"] old_action = before_map.get(agent) if new_action != old_action: changes_all.append({"代码": code, "代理": agent, "原动作": old_action, "新动作": new_action}) progress.progress(idx / max(1, len(batch_symbols))) st.success("批量重评估完成。") if changes_all: st.dataframe(pd.DataFrame(changes_all), hide_index=True, width='stretch') st.rerun() except Exception as exc: # noqa: BLE001 LOGGER.exception("批量重评估失败", extra=LOG_EXTRA) st.error(f"批量重评估失败:{exc}") def render_backtest() -> None: LOGGER.info("渲染回测页面", extra=LOG_EXTRA) st.header("回测与复盘") st.write("在此运行回测、展示净值曲线与代理贡献。") cfg = get_config() default_start, default_end = _default_backtest_range(window_days=60) LOGGER.debug( "回测默认参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s", default_start, default_end, "000001.SZ", 0.035, -0.015, 10, extra=LOG_EXTRA, ) col1, col2 = st.columns(2) start_date = col1.date_input("开始日期", value=default_start) end_date = col2.date_input("结束日期", value=default_end) universe_text = st.text_input("股票列表(逗号分隔)", value="000001.SZ") target = st.number_input("目标收益(例:0.035 表示 3.5%)", value=0.035, step=0.005, format="%.3f") stop = st.number_input("止损收益(例:-0.015 表示 -1.5%)", value=-0.015, step=0.005, format="%.3f") hold_days = st.number_input("持有期(交易日)", value=10, step=1) LOGGER.debug( "当前回测表单输入:start=%s end=%s universe_text=%s target=%.3f stop=%.3f hold_days=%s", start_date, end_date, universe_text, target, stop, hold_days, extra=LOG_EXTRA, ) if st.button("运行回测"): LOGGER.info("用户点击运行回测按钮", extra=LOG_EXTRA) decision_log_container = st.container() status_box = st.status("准备执行回测...", expanded=True) llm_stats_placeholder = st.empty() decision_entries: List[str] = [] def _decision_callback(ts_code: str, trade_dt: date, ctx: AgentContext, decision: Decision) -> None: ts_label = trade_dt.isoformat() summary = "" for dept_decision in decision.department_decisions.values(): if getattr(dept_decision, "summary", ""): summary = str(dept_decision.summary) break entry_lines = [ f"**{ts_label} {ts_code}** → {decision.action.value} (信心 {decision.confidence:.2f})", ] if summary: entry_lines.append(f"摘要:{summary}") dep_highlights = [] for dept_code, dept_decision in decision.department_decisions.items(): dep_highlights.append( f"{dept_code}:{dept_decision.action.value}({dept_decision.confidence:.2f})" ) if dep_highlights: entry_lines.append("部门意见:" + ";".join(dep_highlights)) decision_entries.append(" \n".join(entry_lines)) decision_log_container.markdown("\n\n".join(decision_entries[-200:])) status_box.write(f"{ts_label} {ts_code} → {decision.action.value} (信心 {decision.confidence:.2f})") stats = snapshot_llm_metrics() llm_stats_placeholder.json( { "LLM 调用次数": stats.get("total_calls", 0), "Prompt Tokens": stats.get("total_prompt_tokens", 0), "Completion Tokens": stats.get("total_completion_tokens", 0), "按 Provider": stats.get("provider_calls", {}), "按模型": stats.get("model_calls", {}), } ) _update_dashboard_sidebar(stats) reset_llm_metrics() status_box.update(label="执行回测中...", state="running") try: universe = [code.strip() for code in universe_text.split(',') if code.strip()] LOGGER.info( "回测参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s", start_date, end_date, universe, target, stop, hold_days, extra=LOG_EXTRA, ) cfg = BtConfig( id="streamlit_demo", name="Streamlit Demo Strategy", start_date=start_date, end_date=end_date, universe=universe, params={ "target": target, "stop": stop, "hold_days": int(hold_days), }, ) result = run_backtest(cfg, decision_callback=_decision_callback) LOGGER.info( "回测完成:nav_records=%s trades=%s", len(result.nav_series), len(result.trades), extra=LOG_EXTRA, ) status_box.update(label="回测执行完成", state="complete") st.success("回测执行完成,详见下方结果与统计。") metrics = snapshot_llm_metrics() llm_stats_placeholder.json( { "LLM 调用次数": metrics.get("total_calls", 0), "Prompt Tokens": metrics.get("total_prompt_tokens", 0), "Completion Tokens": metrics.get("total_completion_tokens", 0), "按 Provider": metrics.get("provider_calls", {}), "按模型": metrics.get("model_calls", {}), } ) _update_dashboard_sidebar(metrics) st.json({"nav_records": result.nav_series, "trades": result.trades}) except Exception as exc: # noqa: BLE001 LOGGER.exception("回测执行失败", extra=LOG_EXTRA) status_box.update(label="回测执行失败", state="error") st.error(f"回测执行失败:{exc}") with st.expander("离线调参实验 (DecisionEnv)", expanded=False): st.caption( "使用 DecisionEnv 对代理权重做离线调参。请选择需要优化的代理并设定权重范围," "系统会运行一次回测并返回收益、回撤等指标。若 LLM 网络不可用,将返回失败标记。" ) disable_departments = st.checkbox( "禁用部门 LLM(仅规则代理,适合离线快速评估)", value=True, help="关闭部门调用后不依赖外部 LLM 网络,仅根据规则代理权重模拟。", ) default_experiment_id = f"streamlit_{datetime.now().strftime('%Y%m%d_%H%M%S')}" experiment_id = st.text_input( "实验 ID", value=default_experiment_id, help="用于在 tuning_results 表中区分不同实验。", ) strategy_label = st.text_input( "策略说明", value="DecisionEnv", help="可选:为本次调参记录一个策略名称或备注。", ) agent_objects = default_agents() agent_names = [agent.name for agent in agent_objects] if not agent_names: st.info("暂无可调整的代理。") else: selected_agents = st.multiselect( "选择调参的代理权重", agent_names, default=agent_names[:2], key="decision_env_agents", ) specs: List[ParameterSpec] = [] action_values: List[float] = [] range_valid = True for idx, agent_name in enumerate(selected_agents): col_min, col_max, col_action = st.columns([1, 1, 2]) min_key = f"decision_env_min_{agent_name}" max_key = f"decision_env_max_{agent_name}" action_key = f"decision_env_action_{agent_name}" default_min = 0.0 default_max = 1.0 min_val = col_min.number_input( f"{agent_name} 最小权重", min_value=0.0, max_value=1.0, value=default_min, step=0.05, key=min_key, ) max_val = col_max.number_input( f"{agent_name} 最大权重", min_value=0.0, max_value=1.0, value=default_max, step=0.05, key=max_key, ) if max_val <= min_val: range_valid = False action_val = col_action.slider( f"{agent_name} 动作 (0-1)", min_value=0.0, max_value=1.0, value=0.5, step=0.01, key=action_key, ) specs.append( ParameterSpec( name=f"weight_{agent_name}", target=f"agent_weights.{agent_name}", minimum=min_val, maximum=max_val, ) ) action_values.append(action_val) run_decision_env = st.button("执行单次调参", key="run_decision_env_button") just_finished_single = False if run_decision_env: if not selected_agents: st.warning("请至少选择一个代理进行调参。") elif not range_valid: st.error("请确保所有代理的最大权重大于最小权重。") else: LOGGER.info( "离线调参(单次)按钮点击,已选择代理=%s 动作=%s disable_departments=%s", selected_agents, action_values, disable_departments, extra=LOG_EXTRA, ) baseline_weights = cfg.agent_weights.as_dict() for agent in agent_objects: baseline_weights.setdefault(agent.name, 1.0) universe_env = [code.strip() for code in universe_text.split(',') if code.strip()] if not universe_env: st.error("请先指定至少一个股票代码。") else: bt_cfg_env = BtConfig( id="decision_env_streamlit", name="DecisionEnv Streamlit", start_date=start_date, end_date=end_date, universe=universe_env, params={ "target": target, "stop": stop, "hold_days": int(hold_days), }, method=cfg.decision_method, ) env = DecisionEnv( bt_config=bt_cfg_env, parameter_specs=specs, baseline_weights=baseline_weights, disable_departments=disable_departments, ) env.reset() LOGGER.debug( "离线调参(单次)启动 DecisionEnv:cfg=%s 参数维度=%s", bt_cfg_env, len(specs), extra=LOG_EXTRA, ) with st.spinner("正在执行离线调参……"): try: observation, reward, done, info = env.step(action_values) LOGGER.info( "离线调参(单次)完成,obs=%s reward=%.4f done=%s", observation, reward, done, extra=LOG_EXTRA, ) except Exception as exc: # noqa: BLE001 LOGGER.exception("DecisionEnv 调用失败", extra=LOG_EXTRA) st.error(f"离线调参失败:{exc}") st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None) else: if observation.get("failure"): st.error("调参失败:回测执行未完成,可能是 LLM 网络不可用或参数异常。") st.json(observation) st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None) else: resolved_experiment_id = experiment_id or str(uuid.uuid4()) resolved_strategy = strategy_label or "DecisionEnv" action_payload = { name: value for name, value in zip(selected_agents, action_values) } metrics_payload = dict(observation) metrics_payload["reward"] = reward log_success = False try: log_tuning_result( experiment_id=resolved_experiment_id, strategy=resolved_strategy, action=action_payload, reward=reward, metrics=metrics_payload, weights=info.get("weights", {}), ) except Exception: # noqa: BLE001 LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA) else: log_success = True LOGGER.info( "离线调参(单次)日志写入成功:experiment=%s strategy=%s", resolved_experiment_id, resolved_strategy, extra=LOG_EXTRA, ) st.session_state[_DECISION_ENV_SINGLE_RESULT_KEY] = { "observation": dict(observation), "reward": float(reward), "weights": info.get("weights", {}), "nav_series": info.get("nav_series"), "trades": info.get("trades"), "selected_agents": list(selected_agents), "action_values": list(action_values), "experiment_id": resolved_experiment_id, "strategy_label": resolved_strategy, "logged": log_success, } just_finished_single = True single_result = st.session_state.get(_DECISION_ENV_SINGLE_RESULT_KEY) if single_result: if just_finished_single: st.success("离线调参完成") else: st.success("离线调参结果(最近一次运行)") st.caption( f"实验 ID:{single_result.get('experiment_id', '-') } | 策略:{single_result.get('strategy_label', 'DecisionEnv')}" ) observation = single_result.get("observation", {}) reward = float(single_result.get("reward", 0.0)) col_metrics = st.columns(4) col_metrics[0].metric("总收益", f"{observation.get('total_return', 0.0):+.2%}") col_metrics[1].metric("最大回撤", f"{observation.get('max_drawdown', 0.0):+.2%}") col_metrics[2].metric("波动率", f"{observation.get('volatility', 0.0):+.2%}") col_metrics[3].metric("奖励", f"{reward:+.4f}") weights_dict = single_result.get("weights") or {} if weights_dict: st.write("调参后权重:") st.json(weights_dict) if st.button("保存这些权重为默认配置", key="save_decision_env_weights_single"): try: cfg.agent_weights.update_from_dict(weights_dict) save_config(cfg) except Exception as exc: # noqa: BLE001 LOGGER.exception("保存权重失败", extra={**LOG_EXTRA, "error": str(exc)}) st.error(f"写入配置失败:{exc}") else: st.success("代理权重已写入 config.json") if single_result.get("logged"): st.caption("调参结果已写入 tuning_results 表。") nav_series = single_result.get("nav_series") or [] if nav_series: try: nav_df = pd.DataFrame(nav_series) if {"trade_date", "nav"}.issubset(nav_df.columns): nav_df = nav_df.sort_values("trade_date") nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"]) st.line_chart(nav_df.set_index("trade_date")["nav"], height=220) except Exception: # noqa: BLE001 LOGGER.debug("导航曲线绘制失败", extra=LOG_EXTRA) trades = single_result.get("trades") or [] if trades: st.write("成交记录:") st.dataframe(pd.DataFrame(trades), hide_index=True, width='stretch') if st.button("清除单次调参结果", key="clear_decision_env_single"): st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None) st.success("已清除单次调参结果缓存。") st.divider() st.caption("批量调参:在下方输入多组动作,每行表示一组 0-1 之间的值,用逗号分隔。") default_grid = "\n".join( [ ",".join(["0.2" for _ in specs]), ",".join(["0.5" for _ in specs]), ",".join(["0.8" for _ in specs]), ] ) if specs else "" action_grid_raw = st.text_area( "动作列表", value=default_grid, height=120, key="decision_env_batch_actions", ) run_batch = st.button("批量执行调参", key="run_decision_env_batch") batch_just_ran = False if run_batch: if not selected_agents: st.warning("请先选择调参代理。") elif not range_valid: st.error("请确保所有代理的最大权重大于最小权重。") else: LOGGER.info( "离线调参(批量)按钮点击,已选择代理=%s disable_departments=%s", selected_agents, disable_departments, extra=LOG_EXTRA, ) lines = [line.strip() for line in action_grid_raw.splitlines() if line.strip()] if not lines: st.warning("请在文本框中输入至少一组动作。") else: LOGGER.debug( "离线调参(批量)原始输入=%s", lines, extra=LOG_EXTRA, ) parsed_actions: List[List[float]] = [] for line in lines: try: values = [float(val.strip()) for val in line.split(',') if val.strip()] except ValueError: st.error(f"无法解析动作行:{line}") parsed_actions = [] break if len(values) != len(specs): st.error(f"动作维度不匹配(期望 {len(specs)} 个值):{line}") parsed_actions = [] break parsed_actions.append(values) if parsed_actions: LOGGER.info( "离线调参(批量)解析动作成功,数量=%s", len(parsed_actions), extra=LOG_EXTRA, ) baseline_weights = cfg.agent_weights.as_dict() for agent in agent_objects: baseline_weights.setdefault(agent.name, 1.0) universe_env = [code.strip() for code in universe_text.split(',') if code.strip()] if not universe_env: st.error("请先指定至少一个股票代码。") else: bt_cfg_env = BtConfig( id="decision_env_streamlit_batch", name="DecisionEnv Batch", start_date=start_date, end_date=end_date, universe=universe_env, params={ "target": target, "stop": stop, "hold_days": int(hold_days), }, method=cfg.decision_method, ) env = DecisionEnv( bt_config=bt_cfg_env, parameter_specs=specs, baseline_weights=baseline_weights, disable_departments=disable_departments, ) results: List[Dict[str, object]] = [] resolved_experiment_id = experiment_id or str(uuid.uuid4()) resolved_strategy = strategy_label or "DecisionEnv" LOGGER.debug( "离线调参(批量)启动 DecisionEnv:cfg=%s 动作组=%s", bt_cfg_env, len(parsed_actions), extra=LOG_EXTRA, ) with st.spinner("正在批量执行调参……"): for idx, action_vals in enumerate(parsed_actions, start=1): env.reset() try: observation, reward, done, info = env.step(action_vals) except Exception as exc: # noqa: BLE001 LOGGER.exception("批量调参失败", extra=LOG_EXTRA) results.append( { "序号": idx, "动作": action_vals, "状态": "error", "错误": str(exc), } ) continue if observation.get("failure"): results.append( { "序号": idx, "动作": action_vals, "状态": "failure", "奖励": -1.0, } ) else: LOGGER.info( "离线调参(批量)第 %s 组完成,reward=%.4f obs=%s", idx, reward, observation, extra=LOG_EXTRA, ) action_payload = { name: value for name, value in zip(selected_agents, action_vals) } metrics_payload = dict(observation) metrics_payload["reward"] = reward weights_payload = info.get("weights", {}) try: log_tuning_result( experiment_id=resolved_experiment_id, strategy=resolved_strategy, action=action_payload, reward=reward, metrics=metrics_payload, weights=weights_payload, ) except Exception: # noqa: BLE001 LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA) results.append( { "序号": idx, "动作": action_vals, "状态": "ok", "总收益": observation.get("total_return", 0.0), "最大回撤": observation.get("max_drawdown", 0.0), "波动率": observation.get("volatility", 0.0), "奖励": reward, "权重": weights_payload, } ) st.session_state[_DECISION_ENV_BATCH_RESULTS_KEY] = { "results": results, "selectable": [ row for row in results if row.get("状态") == "ok" and row.get("权重") ], "experiment_id": resolved_experiment_id, "strategy_label": resolved_strategy, } batch_just_ran = True LOGGER.info( "离线调参(批量)执行结束,总结果条数=%s", len(results), extra=LOG_EXTRA, ) batch_state = st.session_state.get(_DECISION_ENV_BATCH_RESULTS_KEY) if batch_state: results = batch_state.get("results") or [] if results: if batch_just_ran: st.success("批量调参完成") else: st.success("批量调参结果(最近一次运行)") st.caption( f"实验 ID:{batch_state.get('experiment_id', '-') } | 策略:{batch_state.get('strategy_label', 'DecisionEnv')}" ) results_df = pd.DataFrame(results) st.write("批量调参结果:") st.dataframe(results_df, hide_index=True, width='stretch') selectable = batch_state.get("selectable") or [] if selectable: option_labels = [ f"序号 {row['序号']} | 奖励 {row.get('奖励', 0.0):+.4f}" for row in selectable ] selected_label = st.selectbox( "选择要保存的记录", option_labels, key="decision_env_batch_select", ) selected_row = None for label, row in zip(option_labels, selectable): if label == selected_label: selected_row = row break if selected_row and st.button( "保存所选权重为默认配置", key="save_decision_env_weights_batch", ): try: cfg.agent_weights.update_from_dict(selected_row.get("权重", {})) save_config(cfg) except Exception as exc: # noqa: BLE001 LOGGER.exception("批量保存权重失败", extra={**LOG_EXTRA, "error": str(exc)}) st.error(f"写入配置失败:{exc}") else: st.success( f"已将序号 {selected_row['序号']} 的权重写入 config.json" ) else: st.caption("暂无成功的结果可供保存。") else: st.caption("批量调参在最近一次执行中未产生结果。") if st.button("清除批量调参结果", key="clear_decision_env_batch"): st.session_state.pop(_DECISION_ENV_BATCH_RESULTS_KEY, None) st.session_state.pop("decision_env_batch_select", None) st.success("已清除批量调参结果缓存。") # ADD: Comparison view for multiple backtest configurations with st.expander("回测结果对比", expanded=False): st.caption("从历史回测配置中选择多个进行净值曲线与指标对比。") normalize_to_one = st.checkbox("归一化到 1 起点", value=True) use_log_y = st.checkbox("对数坐标", value=False) metric_options = ["总收益", "最大回撤", "交易数", "平均换手", "风险事件"] selected_metrics = st.multiselect("显示指标列", metric_options, default=metric_options) try: with db_session(read_only=True) as conn: cfg_rows = conn.execute( "SELECT id, name FROM bt_config ORDER BY rowid DESC LIMIT 50" ).fetchall() except Exception: # noqa: BLE001 LOGGER.exception("读取 bt_config 失败", extra=LOG_EXTRA) cfg_rows = [] cfg_options = [f"{row['id']} | {row['name']}" for row in cfg_rows] selected_labels = st.multiselect("选择配置", cfg_options, default=cfg_options[:2]) selected_ids = [label.split(" | ")[0].strip() for label in selected_labels] nav_df = pd.DataFrame() rpt_df = pd.DataFrame() if selected_ids: try: with db_session(read_only=True) as conn: nav_df = pd.read_sql_query( "SELECT cfg_id, trade_date, nav FROM bt_nav WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))), conn, params=tuple(selected_ids), ) rpt_df = pd.read_sql_query( "SELECT cfg_id, summary FROM bt_report WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))), conn, params=tuple(selected_ids), ) except Exception: # noqa: BLE001 LOGGER.exception("读取回测结果失败", extra=LOG_EXTRA) st.error("读取回测结果失败") nav_df = pd.DataFrame() rpt_df = pd.DataFrame() if not nav_df.empty: try: nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"], errors="coerce") # ADD: date window filter overall_min = pd.to_datetime(nav_df["trade_date"].min()).date() overall_max = pd.to_datetime(nav_df["trade_date"].max()).date() col_d1, col_d2 = st.columns(2) start_filter = col_d1.date_input("起始日期", value=overall_min) end_filter = col_d2.date_input("结束日期", value=overall_max) if start_filter > end_filter: start_filter, end_filter = end_filter, start_filter mask = (nav_df["trade_date"].dt.date >= start_filter) & (nav_df["trade_date"].dt.date <= end_filter) nav_df = nav_df.loc[mask] pivot = nav_df.pivot_table(index="trade_date", columns="cfg_id", values="nav") if normalize_to_one: pivot = pivot.apply(lambda s: s / s.dropna().iloc[0] if s.dropna().size else s) import plotly.graph_objects as go fig = go.Figure() for col in pivot.columns: fig.add_trace(go.Scatter(x=pivot.index, y=pivot[col], mode="lines", name=str(col))) fig.update_layout(height=300, margin=dict(l=10, r=10, t=30, b=10)) if use_log_y: fig.update_yaxes(type="log") st.plotly_chart(fig, use_container_width=True) # ADD: export pivot try: csv_buf = pivot.reset_index() csv_buf.columns = ["trade_date"] + [str(c) for c in pivot.columns] st.download_button( "下载曲线(CSV)", data=csv_buf.to_csv(index=False), file_name="bt_nav_compare.csv", mime="text/csv", key="dl_nav_compare", ) except Exception: pass except Exception: # noqa: BLE001 LOGGER.debug("绘制对比曲线失败", extra=LOG_EXTRA) if not rpt_df.empty: try: metrics_rows: List[Dict[str, object]] = [] for _, row in rpt_df.iterrows(): cfg_id = row["cfg_id"] try: summary = json.loads(row["summary"]) if isinstance(row["summary"], str) else (row["summary"] or {}) except json.JSONDecodeError: summary = {} record = { "cfg_id": cfg_id, "总收益": summary.get("total_return"), "最大回撤": summary.get("max_drawdown"), "交易数": summary.get("trade_count"), "平均换手": summary.get("avg_turnover"), "风险事件": summary.get("risk_events"), } metrics_rows.append({k: v for k, v in record.items() if (k == "cfg_id" or k in selected_metrics)}) if metrics_rows: dfm = pd.DataFrame(metrics_rows) st.dataframe(dfm, hide_index=True, width='stretch') try: st.download_button( "下载指标(CSV)", data=dfm.to_csv(index=False), file_name="bt_metrics_compare.csv", mime="text/csv", key="dl_metrics_compare", ) except Exception: pass except Exception: # noqa: BLE001 LOGGER.debug("渲染指标表失败", extra=LOG_EXTRA) else: st.info("请选择至少一个配置进行对比。") def render_settings() -> None: LOGGER.info("渲染设置页面", extra=LOG_EXTRA) st.header("数据与设置") cfg = get_config() LOGGER.debug("当前 TuShare Token 是否已配置=%s", bool(cfg.tushare_token), extra=LOG_EXTRA) token = st.text_input("TuShare Token", value=cfg.tushare_token or "", type="password") if st.button("保存设置"): LOGGER.info("保存设置按钮被点击", extra=LOG_EXTRA) cfg.tushare_token = token.strip() or None LOGGER.info("TuShare Token 更新,是否为空=%s", cfg.tushare_token is None, extra=LOG_EXTRA) save_config() st.success("设置已保存,仅在当前会话生效。") st.write("新闻源开关与数据库备份将在此配置。") st.divider() st.subheader("LLM 设置") providers = cfg.llm_providers provider_keys = sorted(providers.keys()) st.caption("先在 Provider 中维护基础连接(URL、Key、模型),再为全局与各部门设置个性化参数。") # Provider management ------------------------------------------------- provider_select_col, provider_manage_col = st.columns([3, 1]) if provider_keys: try: default_provider = cfg.llm.primary.provider or provider_keys[0] provider_index = provider_keys.index(default_provider) except ValueError: provider_index = 0 selected_provider = provider_select_col.selectbox( "选择 Provider", provider_keys, index=provider_index, key="llm_provider_select", ) else: selected_provider = None provider_select_col.info("尚未配置 Provider,请先创建。") new_provider_name = provider_manage_col.text_input("新增 Provider", key="new_provider_name") if provider_manage_col.button("创建 Provider", key="create_provider_btn"): key = (new_provider_name or "").strip().lower() if not key: st.warning("请输入有效的 Provider 名称。") elif key in providers: st.warning(f"Provider {key} 已存在。") else: providers[key] = LLMProvider(key=key) cfg.llm_providers = providers save_config() st.success(f"已创建 Provider {key}。") st.rerun() if selected_provider: provider_cfg = providers.get(selected_provider, LLMProvider(key=selected_provider)) title_key = f"provider_title_{selected_provider}" base_key = f"provider_base_{selected_provider}" api_key_key = f"provider_api_{selected_provider}" default_model_key = f"provider_default_model_{selected_provider}" mode_key = f"provider_mode_{selected_provider}" temp_key = f"provider_temp_{selected_provider}" timeout_key = f"provider_timeout_{selected_provider}" prompt_key = f"provider_prompt_{selected_provider}" enabled_key = f"provider_enabled_{selected_provider}" title_val = st.text_input("备注名称", value=provider_cfg.title or "", key=title_key) base_val = st.text_input("Base URL", value=provider_cfg.base_url or "", key=base_key, help="调用地址,例如:https://api.openai.com") api_val = st.text_input("API Key", value=provider_cfg.api_key or "", key=api_key_key, type="password") st.markdown("可用模型:") if provider_cfg.models: st.code("\n".join(provider_cfg.models), language="text") else: st.info("尚未获取模型列表,可点击下方按钮自动拉取。") # ADD: show cache last updated if available try: cache_key = f"{selected_provider}|{(base_val or '').strip()}" entry = _MODEL_CACHE.get(cache_key) if entry and isinstance(entry.get("ts"), float): ts = datetime.fromtimestamp(entry["ts"]).strftime("%Y-%m-%d %H:%M:%S") st.caption(f"最近拉取时间:{ts}") except Exception: pass fetch_key = f"fetch_models_{selected_provider}" if st.button("获取模型列表", key=fetch_key): with st.spinner("正在获取模型列表..."): models, error = _discover_provider_models(provider_cfg, base_val, api_val) if error: st.error(error) else: provider_cfg.models = models if models and (not provider_cfg.default_model or provider_cfg.default_model not in models): provider_cfg.default_model = models[0] providers[selected_provider] = provider_cfg cfg.llm_providers = providers cfg.sync_runtime_llm() save_config() st.success(f"共获取 {len(models)} 个模型。") st.rerun() if st.button("保存 Provider", key=f"save_provider_{selected_provider}"): provider_cfg.title = title_val.strip() provider_cfg.base_url = base_val.strip() provider_cfg.api_key = api_val.strip() or None if provider_cfg.models and default_model_val in provider_cfg.models: provider_cfg.default_model = default_model_val else: provider_cfg.default_model = default_model_val provider_cfg.default_temperature = float(temp_val) provider_cfg.default_timeout = float(timeout_val) provider_cfg.prompt_template = prompt_template_val.strip() provider_cfg.enabled = enabled_val provider_cfg.mode = mode_val providers[selected_provider] = provider_cfg cfg.llm_providers = providers cfg.sync_runtime_llm() save_config() st.success("Provider 已保存。") st.session_state[title_key] = provider_cfg.title or "" st.session_state[default_model_key] = provider_cfg.default_model or "" provider_in_use = (cfg.llm.primary.provider == selected_provider) or any( ep.provider == selected_provider for ep in cfg.llm.ensemble ) if not provider_in_use: for dept in cfg.departments.values(): if dept.llm.primary.provider == selected_provider or any(ep.provider == selected_provider for ep in dept.llm.ensemble): provider_in_use = True break if st.button( "删除 Provider", key=f"delete_provider_{selected_provider}", disabled=provider_in_use or len(providers) <= 1, ): providers.pop(selected_provider, None) cfg.llm_providers = providers cfg.sync_runtime_llm() save_config() st.success("Provider 已删除。") st.rerun() st.markdown("##### 全局推理配置") if not provider_keys: st.warning("请先配置至少一个 Provider。") else: global_cfg = cfg.llm primary = global_cfg.primary try: provider_index = provider_keys.index(primary.provider or provider_keys[0]) except ValueError: provider_index = 0 selected_global_provider = st.selectbox( "主模型 Provider", provider_keys, index=provider_index, key="global_provider_select", ) provider_cfg = providers.get(selected_global_provider) available_models = provider_cfg.models if provider_cfg else [] default_model = primary.model or (provider_cfg.default_model if provider_cfg else None) if available_models: options = available_models + ["自定义"] try: model_index = available_models.index(default_model) model_choice = st.selectbox("主模型", options, index=model_index, key="global_model_choice") except ValueError: model_choice = st.selectbox("主模型", options, index=len(options) - 1, key="global_model_choice") if model_choice == "自定义": model_val = st.text_input("自定义模型", value=default_model or "", key="global_model_custom").strip() else: model_val = model_choice else: model_val = st.text_input("主模型", value=default_model or "", key="global_model_custom").strip() temp_default = primary.temperature if primary.temperature is not None else (provider_cfg.default_temperature if provider_cfg else 0.2) temp_val = st.slider("主模型温度", min_value=0.0, max_value=2.0, value=float(temp_default), step=0.05, key="global_temp") timeout_default = primary.timeout if primary.timeout is not None else (provider_cfg.default_timeout if provider_cfg else 30.0) timeout_val = st.number_input("主模型超时(秒)", min_value=5, max_value=300, value=int(timeout_default), step=5, key="global_timeout") prompt_template_val = st.text_area( "主模型 Prompt 模板(可选)", value=primary.prompt_template or provider_cfg.prompt_template if provider_cfg else "", height=120, key="global_prompt_template", ) strategy_val = st.selectbox("推理策略", sorted(ALLOWED_LLM_STRATEGIES), index=sorted(ALLOWED_LLM_STRATEGIES).index(global_cfg.strategy) if global_cfg.strategy in ALLOWED_LLM_STRATEGIES else 0, key="global_strategy") show_ensemble = strategy_val != "single" majority_threshold_val = st.number_input( "多数投票门槛", min_value=1, max_value=10, value=int(global_cfg.majority_threshold), step=1, key="global_majority", disabled=not show_ensemble, ) if not show_ensemble: majority_threshold_val = 1 ensemble_rows: List[Dict[str, str]] = [] if show_ensemble: ensemble_rows = [ { "provider": ep.provider, "model": ep.model or "", "temperature": "" if ep.temperature is None else f"{ep.temperature:.3f}", "timeout": "" if ep.timeout is None else str(int(ep.timeout)), "prompt_template": ep.prompt_template or "", } for ep in global_cfg.ensemble ] or [{"provider": primary.provider or selected_global_provider, "model": "", "temperature": "", "timeout": "", "prompt_template": ""}] ensemble_editor = st.data_editor( ensemble_rows, num_rows="dynamic", key="global_ensemble_editor", width='stretch', hide_index=True, column_config={ "provider": st.column_config.SelectboxColumn("Provider", options=provider_keys), "model": st.column_config.TextColumn("模型"), "temperature": st.column_config.TextColumn("温度"), "timeout": st.column_config.TextColumn("超时(秒)"), "prompt_template": st.column_config.TextColumn("Prompt 模板"), }, ) if hasattr(ensemble_editor, "to_dict"): ensemble_rows = ensemble_editor.to_dict("records") else: ensemble_rows = ensemble_editor else: st.info("当前策略为单模型,未启用协作模型。") if st.button("保存全局配置", key="save_global_llm"): primary.provider = selected_global_provider primary.model = model_val or None primary.temperature = float(temp_val) primary.timeout = float(timeout_val) primary.prompt_template = prompt_template_val.strip() or None primary.base_url = None primary.api_key = None new_ensemble: List[LLMEndpoint] = [] if show_ensemble: for row in ensemble_rows: provider_val = (row.get("provider") or "").strip().lower() if not provider_val: continue model_raw = (row.get("model") or "").strip() or None temp_raw = (row.get("temperature") or "").strip() timeout_raw = (row.get("timeout") or "").strip() prompt_raw = (row.get("prompt_template") or "").strip() new_ensemble.append( LLMEndpoint( provider=provider_val, model=model_raw, temperature=float(temp_raw) if temp_raw else None, timeout=float(timeout_raw) if timeout_raw else None, prompt_template=prompt_raw or None, ) ) cfg.llm.ensemble = new_ensemble cfg.llm.strategy = strategy_val cfg.llm.majority_threshold = int(majority_threshold_val) cfg.sync_runtime_llm() save_config() st.success("全局 LLM 配置已保存。") st.json(llm_config_snapshot()) # Department configuration ------------------------------------------- st.markdown("##### 部门配置") dept_settings = cfg.departments or {} dept_rows = [ { "code": code, "title": dept.title, "description": dept.description, "weight": float(dept.weight), "strategy": dept.llm.strategy, "majority_threshold": dept.llm.majority_threshold, "provider": dept.llm.primary.provider or (provider_keys[0] if provider_keys else ""), "model": dept.llm.primary.model or "", "temperature": "" if dept.llm.primary.temperature is None else f"{dept.llm.primary.temperature:.3f}", "timeout": "" if dept.llm.primary.timeout is None else str(int(dept.llm.primary.timeout)), "prompt_template": dept.llm.primary.prompt_template or "", } for code, dept in sorted(dept_settings.items()) ] if not dept_rows: st.info("当前未配置部门,可在 config.json 中添加。") dept_rows = [] dept_editor = st.data_editor( dept_rows, num_rows="fixed", key="department_editor", width='stretch', hide_index=True, column_config={ "code": st.column_config.TextColumn("编码", disabled=True), "title": st.column_config.TextColumn("名称"), "description": st.column_config.TextColumn("说明"), "weight": st.column_config.NumberColumn("权重", min_value=0.0, max_value=10.0, step=0.1), "strategy": st.column_config.SelectboxColumn("策略", options=sorted(ALLOWED_LLM_STRATEGIES)), "majority_threshold": st.column_config.NumberColumn("投票阈值", min_value=1, max_value=10, step=1), "provider": st.column_config.SelectboxColumn("Provider", options=provider_keys or [""]), "model": st.column_config.TextColumn("模型"), "temperature": st.column_config.TextColumn("温度"), "timeout": st.column_config.TextColumn("超时(秒)"), "prompt_template": st.column_config.TextColumn("Prompt 模板"), }, ) if hasattr(dept_editor, "to_dict"): dept_rows = dept_editor.to_dict("records") else: dept_rows = dept_editor col_reset, col_save = st.columns([1, 1]) if col_save.button("保存部门配置"): updated_departments: Dict[str, DepartmentSettings] = {} for row in dept_rows: code = row.get("code") if not code: continue existing = dept_settings.get(code) or DepartmentSettings(code=code, title=code) existing.title = row.get("title") or existing.title existing.description = row.get("description") or "" try: existing.weight = max(0.0, float(row.get("weight", existing.weight))) except (TypeError, ValueError): pass strategy_val = (row.get("strategy") or existing.llm.strategy).lower() if strategy_val in ALLOWED_LLM_STRATEGIES: existing.llm.strategy = strategy_val if existing.llm.strategy == "single": existing.llm.majority_threshold = 1 existing.llm.ensemble = [] else: majority_raw = row.get("majority_threshold") try: majority_val = int(majority_raw) if majority_val > 0: existing.llm.majority_threshold = majority_val except (TypeError, ValueError): pass provider_val = (row.get("provider") or existing.llm.primary.provider or (provider_keys[0] if provider_keys else "ollama")).strip().lower() model_val = (row.get("model") or "").strip() or None temp_raw = (row.get("temperature") or "").strip() timeout_raw = (row.get("timeout") or "").strip() prompt_raw = (row.get("prompt_template") or "").strip() endpoint = existing.llm.primary or LLMEndpoint() endpoint.provider = provider_val endpoint.model = model_val endpoint.temperature = float(temp_raw) if temp_raw else None endpoint.timeout = float(timeout_raw) if timeout_raw else None endpoint.prompt_template = prompt_raw or None endpoint.base_url = None endpoint.api_key = None existing.llm.primary = endpoint if existing.llm.strategy != "single": existing.llm.ensemble = [] updated_departments[code] = existing if updated_departments: cfg.departments = updated_departments cfg.sync_runtime_llm() save_config() st.success("部门配置已更新。") else: st.warning("未能解析部门配置输入。") if col_reset.button("恢复默认部门"): from app.utils.config import _default_departments cfg.departments = _default_departments() cfg.sync_runtime_llm() save_config() st.success("已恢复默认部门配置。") st.rerun() st.caption("部门配置存储为独立 LLM 参数,执行时会自动套用对应 Provider 的连接信息。") def render_tests() -> None: LOGGER.info("渲染自检页面", extra=LOG_EXTRA) st.header("自检测试") st.write("用于快速检查数据库与数据拉取是否正常工作。") if st.button("测试数据库初始化"): LOGGER.info("点击测试数据库初始化按钮", extra=LOG_EXTRA) with st.spinner("正在检查数据库..."): result = initialize_database() if result.skipped: LOGGER.info("数据库已存在,无需初始化", extra=LOG_EXTRA) st.success("数据库已存在,检查通过。") else: LOGGER.info("数据库初始化完成,执行语句数=%s", result.executed, extra=LOG_EXTRA) st.success(f"数据库初始化完成,共执行 {result.executed} 条语句。") st.divider() if st.button("测试 TuShare 拉取(示例 2024-01-01 至 2024-01-03)"): LOGGER.info("点击示例 TuShare 拉取按钮", extra=LOG_EXTRA) with st.spinner("正在调用 TuShare 接口..."): try: run_ingestion( FetchJob( name="streamlit_self_test", start=date(2024, 1, 1), end=date(2024, 1, 3), ts_codes=("000001.SZ",), ), include_limits=False, ) LOGGER.info("示例 TuShare 拉取成功", extra=LOG_EXTRA) st.success("TuShare 示例拉取完成,数据已写入数据库。") except Exception as exc: # noqa: BLE001 LOGGER.exception("示例 TuShare 拉取失败", extra=LOG_EXTRA) st.error(f"拉取失败:{exc}") alerts.add_warning("TuShare", "示例拉取失败", str(exc)) _update_dashboard_sidebar() st.info("注意:TuShare 拉取依赖网络与 Token,若环境未配置将出现错误提示。") st.divider() st.subheader("RSS 数据测试") st.write("用于验证 RSS 配置是否能够正常抓取新闻并写入数据库。") rss_url = st.text_input( "测试 RSS 地址", value="https://rsshub.app/cls/depth/1000", help="留空则使用默认配置的全部 RSS 来源。", ).strip() rss_hours = int( st.number_input( "回溯窗口(小时)", min_value=1, max_value=168, value=24, step=6, help="仅抓取最近指定小时内的新闻。", ) ) rss_limit = int( st.number_input( "单源抓取条数", min_value=1, max_value=200, value=50, step=10, ) ) if st.button("运行 RSS 测试"): from app.ingest import rss as rss_ingest LOGGER.info( "点击 RSS 测试按钮 rss_url=%s hours=%s limit=%s", rss_url, rss_hours, rss_limit, extra=LOG_EXTRA, ) with st.spinner("正在抓取 RSS 新闻..."): try: if rss_url: items = rss_ingest.fetch_rss_feed( rss_url, hours_back=rss_hours, max_items=rss_limit, ) count = rss_ingest.save_news_items(items) else: count = rss_ingest.ingest_configured_rss( hours_back=rss_hours, max_items_per_feed=rss_limit, ) st.success(f"RSS 测试完成,新增 {count} 条新闻记录。") except Exception as exc: # noqa: BLE001 LOGGER.exception("RSS 测试失败", extra=LOG_EXTRA) st.error(f"RSS 测试失败:{exc}") alerts.add_warning("RSS", "RSS 测试执行失败", str(exc)) _update_dashboard_sidebar() st.divider() days = int( st.number_input( "检查窗口(天数)", min_value=30, max_value=10950, value=365, step=30, ) ) LOGGER.debug("检查窗口天数=%s", days, extra=LOG_EXTRA) cfg = get_config() force_refresh = st.checkbox( "强制刷新数据(关闭增量跳过)", value=cfg.force_refresh, help="勾选后将重新拉取所选区间全部数据", ) if force_refresh != cfg.force_refresh: cfg.force_refresh = force_refresh LOGGER.info("更新 force_refresh=%s", force_refresh, extra=LOG_EXTRA) save_config() if st.button("执行手动数据同步"): LOGGER.info("点击执行手动数据同步按钮", extra=LOG_EXTRA) progress_bar = st.progress(0.0) status_placeholder = st.empty() log_placeholder = st.empty() messages: list[str] = [] def hook(message: str, value: float) -> None: progress_bar.progress(min(max(value, 0.0), 1.0)) status_placeholder.write(message) messages.append(message) LOGGER.debug("手动数据同步进度:%s -> %.2f", message, value, extra=LOG_EXTRA) with st.spinner("正在执行手动数据同步..."): try: report = run_boot_check( days=days, progress_hook=hook, force_refresh=force_refresh, ) LOGGER.info("手动数据同步成功", extra=LOG_EXTRA) st.success("手动数据同步完成,以下为数据覆盖摘要。") st.json(report.to_dict()) if messages: log_placeholder.markdown("\n".join(f"- {msg}" for msg in messages)) except Exception as exc: # noqa: BLE001 LOGGER.exception("手动数据同步失败", extra=LOG_EXTRA) st.error(f"手动数据同步失败:{exc}") alerts.add_warning("数据同步", "手动数据同步失败", str(exc)) _update_dashboard_sidebar() if messages: log_placeholder.markdown("\n".join(f"- {msg}" for msg in messages)) finally: progress_bar.progress(1.0) st.divider() st.subheader("股票行情可视化") options = _load_stock_options() default_code = options[0] if options else "000001.SZ" if options: selection = st.selectbox("选择股票", options, index=0) ts_code = _parse_ts_code(selection) LOGGER.debug("选择股票:%s", ts_code, extra=LOG_EXTRA) else: ts_code = st.text_input("输入股票代码(如 000001.SZ)", value=default_code).strip().upper() LOGGER.debug("输入股票:%s", ts_code, extra=LOG_EXTRA) viz_col1, viz_col2 = st.columns(2) default_start = date.today() - timedelta(days=180) start_date = viz_col1.date_input("开始日期", value=default_start, key="viz_start") end_date = viz_col2.date_input("结束日期", value=date.today(), key="viz_end") LOGGER.debug("行情可视化日期范围:%s-%s", start_date, end_date, extra=LOG_EXTRA) if start_date > end_date: LOGGER.warning("无效日期范围:%s>%s", start_date, end_date, extra=LOG_EXTRA) st.error("开始日期不能晚于结束日期") return with st.spinner("正在加载行情数据..."): try: df = _load_daily_frame(ts_code, start_date, end_date) except Exception as exc: # noqa: BLE001 LOGGER.exception("加载行情数据失败", extra=LOG_EXTRA) st.error(f"读取数据失败:{exc}") return if df.empty: LOGGER.warning("指定区间无行情数据:%s %s-%s", ts_code, start_date, end_date, extra=LOG_EXTRA) st.warning("未查询到该区间的交易数据,请确认数据库已拉取对应日线。") return price_df = df[["close"]].rename(columns={"close": "收盘价"}) volume_df = df[["vol"]].rename(columns={"vol": "成交量(手)"}) if price_df.shape[0] > 180: sampled = price_df.resample('3D').last().dropna() else: sampled = price_df if volume_df.shape[0] > 180: volume_sampled = volume_df.resample('3D').mean().dropna() else: volume_sampled = volume_df first_close = sampled.iloc[0, 0] last_close = sampled.iloc[-1, 0] delta_abs = last_close - first_close delta_pct = (delta_abs / first_close * 100) if first_close else 0.0 metric_col1, metric_col2, metric_col3 = st.columns(3) metric_col1.metric("最新收盘价", f"{last_close:.2f}", delta=f"{delta_abs:+.2f}") metric_col2.metric("区间涨跌幅", f"{delta_pct:+.2f}%") metric_col3.metric("平均成交量", f"{volume_sampled['成交量(手)'].mean():.0f}") df_reset = df.reset_index().rename(columns={ "trade_date": "交易日", "open": "开盘价", "high": "最高价", "low": "最低价", "close": "收盘价", "vol": "成交量(手)", "amount": "成交额(千元)", }) df_reset["成交额(千元)"] = df_reset["成交额(千元)"] / 1000 candle_fig = go.Figure( data=[ go.Candlestick( x=df_reset["交易日"], open=df_reset["开盘价"], high=df_reset["最高价"], low=df_reset["最低价"], close=df_reset["收盘价"], name="K线", ) ] ) candle_fig.update_layout(height=420, margin=dict(l=10, r=10, t=40, b=10)) st.plotly_chart(candle_fig, use_container_width=True) vol_fig = px.bar( df_reset, x="交易日", y="成交量(手)", labels={"成交量(手)": "成交量(手)"}, title="成交量", ) vol_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10)) st.plotly_chart(vol_fig, use_container_width=True) amt_fig = px.bar( df_reset, x="交易日", y="成交额(千元)", labels={"成交额(千元)": "成交额(千元)"}, title="成交额", ) amt_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10)) st.plotly_chart(amt_fig, use_container_width=True) df_reset["月份"] = df_reset["交易日"].dt.to_period("M").astype(str) box_fig = px.box( df_reset, x="月份", y="收盘价", points="outliers", title="月度收盘价分布", ) box_fig.update_layout(height=320, margin=dict(l=10, r=10, t=40, b=10)) st.plotly_chart(box_fig, use_container_width=True) st.caption("提示:成交量单位为手,成交额以千元显示。箱线图按月展示收盘价分布。") st.dataframe(df_reset.tail(20), width='stretch') LOGGER.info("行情可视化完成,展示行数=%s", len(df_reset), extra=LOG_EXTRA) st.divider() st.subheader("LLM 接口测试") st.json(llm_config_snapshot()) llm_prompt = st.text_area("测试 Prompt", value="请概述今天的市场重点。", height=160) system_prompt = st.text_area( "System Prompt (可选)", value="你是一名量化策略研究助手,用简洁中文回答。", height=100, ) if st.button("执行 LLM 测试"): with st.spinner("正在调用 LLM..."): try: response = run_llm(llm_prompt, system=system_prompt or None) except Exception as exc: # noqa: BLE001 LOGGER.exception("LLM 测试失败", extra=LOG_EXTRA) st.error(f"LLM 调用失败:{exc}") else: LOGGER.info("LLM 测试成功", extra=LOG_EXTRA) st.success("LLM 调用成功,以下为返回内容:") st.write(response) def main() -> None: LOGGER.info("初始化 Streamlit UI", extra=LOG_EXTRA) st.set_page_config(page_title="多智能体个人投资助理", layout="wide") render_global_dashboard() tabs = st.tabs(["今日计划", "回测与复盘", "数据与设置", "自检测试"]) LOGGER.debug("Tabs 初始化完成:%s", ["今日计划", "回测与复盘", "数据与设置", "自检测试"], extra=LOG_EXTRA) with tabs[0]: render_today_plan() with tabs[1]: render_backtest() with tabs[2]: render_settings() with tabs[3]: render_tests() if __name__ == "__main__": main()