diff --git a/app/ui/shared.py b/app/ui/shared.py new file mode 100644 index 0000000..3986cd4 --- /dev/null +++ b/app/ui/shared.py @@ -0,0 +1,65 @@ +"""Shared utilities and constants for Streamlit UI views.""" +from __future__ import annotations + +from datetime import date, datetime, timedelta +from typing import Optional + +import streamlit as st + +from app.utils.db import db_session +from app.utils.logging import get_logger + +LOGGER = get_logger(__name__) +LOG_EXTRA = {"stage": "ui"} + + +def get_query_params() -> dict[str, list[str]]: + """Safely read URL query parameters from Streamlit.""" + try: + return dict(st.query_params) + except Exception: # noqa: BLE001 + return {} + + +def set_query_params(**kwargs: object) -> None: + """Update URL query parameters, ignoring failures in unsupported contexts.""" + try: + payload = {k: v for k, v in kwargs.items() if v is not None} + if payload: + st.query_params.update(payload) + except Exception: # noqa: BLE001 + pass + + +def get_latest_trade_date() -> Optional[date]: + """Fetch the most recent trade date from the database.""" + try: + with db_session(read_only=True) as conn: + row = conn.execute( + "SELECT trade_date FROM daily ORDER BY trade_date DESC LIMIT 1" + ).fetchone() + except Exception: # noqa: BLE001 + LOGGER.exception("查询最新交易日失败", extra=LOG_EXTRA) + return None + if not row: + return None + raw_value = row["trade_date"] + if not raw_value: + return None + try: + return datetime.strptime(str(raw_value), "%Y%m%d").date() + except ValueError: + try: + return datetime.fromisoformat(str(raw_value)).date() + except ValueError: + LOGGER.warning("无法解析交易日:%s", raw_value, extra=LOG_EXTRA) + return None + + +def default_backtest_range(window_days: int = 60) -> tuple[date, date]: + """Return a sensible (end, start) date range for backtests.""" + latest = get_latest_trade_date() or date.today() + start = latest - timedelta(days=window_days) + if start > latest: + start = latest + return start, latest diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py index 7027c3a..84b982c 100644 --- a/app/ui/streamlit_app.py +++ b/app/ui/streamlit_app.py @@ -2,2989 +2,76 @@ from __future__ import annotations import sys -from dataclasses import asdict -from datetime import date, datetime, timedelta from pathlib import Path -from typing import Dict, List, Optional + +import streamlit as st ROOT = Path(__file__).resolve().parents[2] if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) -import json -from datetime import datetime -import uuid - -import pandas as pd -import plotly.express as px -import plotly.graph_objects as go -import requests -from requests.exceptions import RequestException -import streamlit as st -import numpy as np -from collections import Counter - -from app.agents.base import AgentContext -from app.agents.game import Decision -from app.backtest.engine import BtConfig, run_backtest -from app.ui.portfolio_config import render_portfolio_config -from app.backtest.decision_env import DecisionEnv, ParameterSpec from app.data.schema import initialize_database from app.ingest.checker import run_boot_check -from app.ingest.tushare import FetchJob, run_ingestion -from app.llm.client import llm_config_snapshot, run_llm -from app.llm.metrics import ( - recent_decisions as llm_recent_decisions, - register_listener as register_llm_metrics_listener, - reset as reset_llm_metrics, - snapshot as snapshot_llm_metrics, +from app.ingest.rss import ingest_configured_rss +from app.ui.portfolio_config import render_portfolio_config +from app.ui.shared import LOGGER, LOG_EXTRA +from app.ui.views import ( + render_backtest_review, + render_config_overview, + render_data_settings, + render_global_dashboard, + render_llm_settings, + render_log_viewer, + render_market_visualization, + render_pool_overview, + render_tests, + render_today_plan, ) -from app.utils import alerts -from app.utils.config import ( - ALLOWED_LLM_STRATEGIES, - DEFAULT_LLM_BASE_URLS, - DEFAULT_LLM_MODEL_OPTIONS, - DEFAULT_LLM_MODELS, - DepartmentSettings, - LLMEndpoint, - LLMProvider, - get_config, - save_config, -) -from app.utils.db import db_session -from app.utils.logging import get_logger -from app.utils.portfolio import ( - get_latest_snapshot, - list_investment_pool, - list_positions, - list_recent_trades, -) -from app.agents.registry import default_agents -from app.utils.tuning import log_tuning_result -from app.backtest.engine import BacktestEngine, PortfolioState - - -LOGGER = get_logger(__name__) -LOG_EXTRA = {"stage": "ui"} -_DECISION_ENV_SINGLE_RESULT_KEY = "decision_env_single_result" -_DECISION_ENV_BATCH_RESULTS_KEY = "decision_env_batch_results" -_DASHBOARD_CONTAINERS: Optional[tuple[object, object]] = None -_DASHBOARD_ELEMENTS: Optional[Dict[str, object]] = None -_SIDEBAR_LISTENER_ATTACHED = False -# ADD: simple in-memory cache for provider model discovery -_MODEL_CACHE: Dict[str, Dict[str, object]] = {} -_CACHE_TTL_SECONDS = 300 -_WARNINGS_CONTAINER = None -_WARNINGS_PLACEHOLDER = None - -# ADD: query param helpers -def _get_query_params() -> Dict[str, List[str]]: - try: - return dict(st.query_params) - except Exception: - return {} - -def _set_query_params(**kwargs: object) -> None: - try: - payload = {k: v for k, v in kwargs.items() if v is not None} - if payload: - st.query_params.update(payload) - except Exception: - pass - - -def _sidebar_metrics_listener(metrics: Dict[str, object]) -> None: - try: - _update_dashboard_sidebar(metrics) - except Exception: # noqa: BLE001 - LOGGER.debug("侧边栏监听器刷新失败", exc_info=True, extra=LOG_EXTRA) - - -def render_global_dashboard() -> None: - """Render a persistent sidebar with realtime LLM stats and recent decisions.""" - - global _DASHBOARD_CONTAINERS - global _DASHBOARD_ELEMENTS - global _SIDEBAR_LISTENER_ATTACHED - global _WARNINGS_CONTAINER - global _WARNINGS_PLACEHOLDER - - # ADD: warning badge on top - warnings = alerts.get_warnings() - badge = f" ({len(warnings)})" if warnings else "" - st.sidebar.header(f"系统监控{badge}") - - metrics_container = st.sidebar.container() - decisions_container = st.sidebar.container() - _WARNINGS_CONTAINER = st.sidebar.container() - _WARNINGS_PLACEHOLDER = st.sidebar.empty() - _DASHBOARD_CONTAINERS = (metrics_container, decisions_container) - _DASHBOARD_ELEMENTS = _ensure_dashboard_elements(metrics_container, decisions_container) - if not _SIDEBAR_LISTENER_ATTACHED: - register_llm_metrics_listener(_sidebar_metrics_listener) - _SIDEBAR_LISTENER_ATTACHED = True - _update_dashboard_sidebar() - - -def _update_dashboard_sidebar( - metrics: Optional[Dict[str, object]] = None, -) -> None: - global _DASHBOARD_CONTAINERS - global _DASHBOARD_ELEMENTS - global _WARNINGS_CONTAINER - global _WARNINGS_PLACEHOLDER - - containers = _DASHBOARD_CONTAINERS - if not containers: - return - metrics_container, decisions_container = containers - elements = _DASHBOARD_ELEMENTS - if elements is None: - elements = _ensure_dashboard_elements(metrics_container, decisions_container) - _DASHBOARD_ELEMENTS = elements - - if metrics is None: - metrics = snapshot_llm_metrics() - - elements["metrics_calls"].metric("LLM 调用", metrics.get("total_calls", 0)) - elements["metrics_prompt"].metric("Prompt Tokens", metrics.get("total_prompt_tokens", 0)) - elements["metrics_completion"].metric( - "Completion Tokens", metrics.get("total_completion_tokens", 0) - ) - - provider_calls = metrics.get("provider_calls", {}) - model_calls = metrics.get("model_calls", {}) - provider_placeholder = elements["provider_distribution"] - provider_placeholder.empty() - if provider_calls: - provider_placeholder.json(provider_calls) - else: - provider_placeholder.info("暂无 Provider 分布数据。") - - model_placeholder = elements["model_distribution"] - model_placeholder.empty() - if model_calls: - model_placeholder.json(model_calls) - else: - model_placeholder.info("暂无模型分布数据。") - - decisions = metrics.get("recent_decisions") or llm_recent_decisions(10) - if decisions: - lines = [] - for record in reversed(decisions[-10:]): - ts_code = record.get("ts_code") - trade_date = record.get("trade_date") - action = record.get("action") - confidence = record.get("confidence", 0.0) - summary = record.get("summary") - line = f"**{trade_date} {ts_code}** → {action} (置信度 {confidence:.2f})" - if summary: - line += f"\n{summary}" - lines.append(line) - decisions_placeholder = elements["decisions_list"] - decisions_placeholder.empty() - decisions_placeholder.markdown("\n\n".join(lines), unsafe_allow_html=True) - else: - decisions_placeholder = elements["decisions_list"] - decisions_placeholder.empty() - decisions_placeholder.info("暂无决策记录。执行回测或实时评估后可在此查看。") - # Render warnings section in-place (clear then write) - if _WARNINGS_PLACEHOLDER is not None: - _WARNINGS_PLACEHOLDER.empty() - with _WARNINGS_PLACEHOLDER.container(): - st.subheader("数据告警") - warn_list = alerts.get_warnings() - if warn_list: - lines = [] - for warning in warn_list[-10:]: - detail = warning.get("detail") - appendix = f" {detail}" if detail else "" - lines.append( - f"- **{warning['source']}** {warning['message']}{appendix}\n{warning['timestamp']}" - ) - st.markdown("\n".join(lines), unsafe_allow_html=True) - btn_cols = st.columns([1,1]) - if btn_cols[0].button("清除数据告警", key="clear_data_alerts_sibling"): - alerts.clear_warnings() - _update_dashboard_sidebar() - try: - st.download_button( - "导出告警(JSON)", - data=json.dumps(warn_list, ensure_ascii=False, indent=2), - file_name="data_warnings.json", - mime="application/json", - key="dl_warnings_json_sibling", - ) - except Exception: - pass - else: - st.info("暂无数据告警。") - - -def _ensure_dashboard_elements(metrics_container, decisions_container) -> Dict[str, object]: - metrics_container.header("系统监控") - col_a, col_b, col_c = metrics_container.columns(3) - metrics_calls = col_a.empty() - metrics_prompt = col_b.empty() - metrics_completion = col_c.empty() - distribution_expander = metrics_container.expander("调用分布", expanded=False) - provider_distribution = distribution_expander.empty() - model_distribution = distribution_expander.empty() - - decisions_container.subheader("最新决策") - decisions_list = decisions_container.empty() - - elements = { - "metrics_calls": metrics_calls, - "metrics_prompt": metrics_prompt, - "metrics_completion": metrics_completion, - "provider_distribution": provider_distribution, - "model_distribution": model_distribution, - "decisions_list": decisions_list, - } - return elements - -def _discover_provider_models(provider: LLMProvider, base_override: str = "", api_override: Optional[str] = None) -> tuple[list[str], Optional[str]]: - """Attempt to query provider API and return available model ids.""" - - base_url = (base_override or provider.base_url or DEFAULT_LLM_BASE_URLS.get(provider.key, "")).strip() - if not base_url: - return [], "请先填写 Base URL" - timeout = float(provider.default_timeout or 30.0) - mode = provider.mode or ("ollama" if provider.key == "ollama" else "openai") - - # ADD: simple cache by provider+base URL - cache_key = f"{provider.key}|{base_url}" - now = datetime.now() - cached = _MODEL_CACHE.get(cache_key) - if cached: - ts = cached.get("ts") - if isinstance(ts, float) and (now.timestamp() - ts) < _CACHE_TTL_SECONDS: - models = list(cached.get("models") or []) - return models, None - - try: - if mode == "ollama": - url = base_url.rstrip('/') + "/api/tags" - response = requests.get(url, timeout=timeout) - response.raise_for_status() - data = response.json() - models = [] - for item in data.get("models", []) or data.get("data", []): - name = item.get("name") or item.get("model") or item.get("tag") - if name: - models.append(str(name).strip()) - _MODEL_CACHE[cache_key] = {"ts": now.timestamp(), "models": sorted(set(models))} - return sorted(set(models)), None - - api_key = (api_override or provider.api_key or "").strip() - if not api_key: - return [], "缺少 API Key" - url = base_url.rstrip('/') + "/v1/models" - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - } - response = requests.get(url, headers=headers, timeout=timeout) - response.raise_for_status() - payload = response.json() - models = [ - str(item.get("id")).strip() - for item in payload.get("data", []) - if item.get("id") - ] - _MODEL_CACHE[cache_key] = {"ts": now.timestamp(), "models": sorted(set(models))} - return sorted(set(models)), None - except RequestException as exc: # noqa: BLE001 - return [], f"HTTP 错误:{exc}" - except Exception as exc: # noqa: BLE001 - return [], f"解析失败:{exc}" - -def _load_stock_options(limit: int = 500) -> list[str]: - try: - with db_session(read_only=True) as conn: - rows = conn.execute( - "SELECT ts_code, name FROM stock_basic WHERE list_status = 'L' ORDER BY ts_code" - ).fetchall() - except Exception: - LOGGER.exception("加载股票列表失败", extra=LOG_EXTRA) - return [] - options: list[str] = [] - for row in rows[:limit]: - code = row["ts_code"] - name = row["name"] or "" - label = f"{code} | {name}" if name else code - options.append(label) - LOGGER.info("加载股票选项完成,数量=%s", len(options), extra=LOG_EXTRA) - return options - - -def _parse_ts_code(selection: str) -> str: - return selection.split(' | ')[0].strip().upper() - - -def _load_daily_frame(ts_code: str, start: date, end: date) -> pd.DataFrame: - LOGGER.info( - "加载行情数据:ts_code=%s start=%s end=%s", - ts_code, - start, - end, - extra=LOG_EXTRA, - ) - start_str = start.strftime('%Y%m%d') - end_str = end.strftime('%Y%m%d') - range_query = ( - "SELECT trade_date, open, high, low, close, vol, amount " - "FROM daily WHERE ts_code = ? AND trade_date BETWEEN ? AND ? ORDER BY trade_date" - ) - fallback_query = ( - "SELECT trade_date, open, high, low, close, vol, amount " - "FROM daily WHERE ts_code = ? ORDER BY trade_date DESC LIMIT 200" - ) - with db_session(read_only=True) as conn: - df = pd.read_sql_query(range_query, conn, params=(ts_code, start_str, end_str)) - if df.empty: - df = pd.read_sql_query(fallback_query, conn, params=(ts_code,)) - if df.empty: - LOGGER.warning( - "行情数据为空:ts_code=%s start=%s end=%s", - ts_code, - start, - end, - extra=LOG_EXTRA, - ) - return df - df = df.sort_values('trade_date') - df['trade_date'] = pd.to_datetime(df['trade_date']) - df.set_index('trade_date', inplace=True) - LOGGER.info("行情数据加载完成:条数=%s", len(df), extra=LOG_EXTRA) - return df - - -def _get_latest_trade_date() -> Optional[date]: - try: - with db_session(read_only=True) as conn: - row = conn.execute( - "SELECT trade_date FROM daily ORDER BY trade_date DESC LIMIT 1" - ).fetchone() - except Exception: # noqa: BLE001 - LOGGER.exception("查询最新交易日失败", extra=LOG_EXTRA) - return None - if not row: - return None - raw_value = row["trade_date"] - if not raw_value: - return None - try: - return datetime.strptime(str(raw_value), "%Y%m%d").date() - except ValueError: - try: - return datetime.fromisoformat(str(raw_value)).date() - except ValueError: - LOGGER.warning("无法解析交易日:%s", raw_value, extra=LOG_EXTRA) - return None - - -def _default_backtest_range(window_days: int = 60) -> tuple[date, date]: - latest = _get_latest_trade_date() or date.today() - start = latest - timedelta(days=window_days) - if start > latest: - start = latest - return start, latest - - -def render_today_plan() -> None: - LOGGER.info("渲染今日计划页面", extra=LOG_EXTRA) - st.header("今日计划") - latest_trade_date = _get_latest_trade_date() - if latest_trade_date: - st.caption(f"最新交易日:{latest_trade_date.isoformat()}(统计数据请见左侧系统监控)") - else: - st.caption("统计与决策概览现已移至左侧'系统监控'侧栏。") - try: - with db_session(read_only=True) as conn: - date_rows = conn.execute( - """ - SELECT DISTINCT trade_date - FROM agent_utils - ORDER BY trade_date DESC - LIMIT 30 - """ - ).fetchall() - except Exception: # noqa: BLE001 - LOGGER.exception("加载 agent_utils 失败", extra=LOG_EXTRA) - st.warning("暂未写入部门/代理决策,请先运行回测或策略评估流程。") - return - trade_dates = [row["trade_date"] for row in date_rows] - if not trade_dates: - st.info("暂无决策记录,完成一次回测后即可在此查看部门意见与投票结果。") - return - q = _get_query_params() - default_trade_date = q.get("date", [trade_dates[0]])[0] - try: - default_idx = trade_dates.index(default_trade_date) - except ValueError: - default_idx = 0 - trade_date = st.selectbox("交易日", trade_dates, index=default_idx) - with db_session(read_only=True) as conn: - code_rows = conn.execute( - """ - SELECT DISTINCT ts_code - FROM agent_utils - WHERE trade_date = ? - ORDER BY ts_code - """, - (trade_date,), - ).fetchall() - symbols = [row["ts_code"] for row in code_rows] - detail_tab, assistant_tab = st.tabs(["标的详情", "投资助理模式"]) - with assistant_tab: - st.info("已开启投资助理模式:以下内容为组合级(去标的)建议,不包含任何具体标的代码。") - try: - candidates = list_investment_pool(trade_date=trade_date) - if candidates: - scores = [float(item.score or 0.0) for item in candidates] - statuses = [item.status or "UNKNOWN" for item in candidates] - tags = [] - rationales = [] - for item in candidates: - if getattr(item, "tags", None): - tags.extend(item.tags) - if getattr(item, "rationale", None): - rationales.append(str(item.rationale)) - cnt = Counter(statuses) - tag_cnt = Counter(tags) - st.subheader("候选池聚合概览(已匿名化)") - col_a, col_b, col_c = st.columns(3) - col_a.metric("候选数", f"{len(candidates)}") - col_b.metric("平均评分", f"{np.mean(scores):.3f}" if scores else "-") - col_c.metric("中位评分", f"{np.median(scores):.3f}" if scores else "-") - - st.write("状态分布:") - st.json(dict(cnt)) - - if tag_cnt: - st.write("常见标签(示例):") - st.json(dict(tag_cnt.most_common(10))) - - if rationales: - st.write("汇总理由(节选,不含代码):") - # 显示前三条去重理由摘要 - seen = set() - excerpts = [] - for r in rationales: - s = r.strip() - if not s: - continue - if s in seen: - continue - seen.add(s) - excerpts.append(s) - if len(excerpts) >= 3: - break - for idx, ex in enumerate(excerpts, start=1): - st.markdown(f"**理由 {idx}:** {ex}") - - # 简单生成组合级建议(示例逻辑,可后续替换为更复杂策略) - avg_score = float(np.mean(scores)) if scores else 0.0 - # 建议使用现金比例:基线 10%,根据平均评分上下调整,限制在[0, 30%] - suggest_pct = max(0.0, min(0.3, 0.10 + (avg_score - 0.5) * 0.2)) - st.subheader("组合级建议(不指定标的)") - st.write( - f"基于候选池平均评分 {avg_score:.3f},建议今日用于新增买入的现金比例约为 {suggest_pct:.0%}。" - ) - st.write( - "建议分配思路:在候选池中挑选若干得分较高的标的按目标权重等比例分配,或以分批买入的方式分摊入场时点。" - ) - if st.button("生成组合级操作建议(仅输出,不执行)"): - st.success("已生成组合级建议(仅供参考)。") - st.write({ - "候选数": len(candidates), - "平均评分": avg_score, - "建议新增买入比例": f"{suggest_pct:.0%}", - }) - else: - st.info("所选交易日暂无候选投资池数据。") - except Exception: - LOGGER.exception("加载候选池聚合信息失败", extra=LOG_EXTRA) - st.error("加载候选池数据时发生错误。") - with detail_tab: - if not symbols: - st.info("所选交易日暂无 agent_utils 记录。") - else: - default_ts = q.get("code", [symbols[0]])[0] - try: - default_ts_idx = symbols.index(default_ts) - except ValueError: - default_ts_idx = 0 - ts_code = st.selectbox("标的", symbols, index=default_ts_idx) - batch_symbols = st.multiselect("批量重评估(可多选)", symbols, default=[]) - - if st.button("一键重评估所有标的", type="primary", width='stretch'): - with st.spinner("正在对所有标的进行重评估,请稍候..."): - try: - # 解析交易日 - trade_date_obj = None - try: - trade_date_obj = date.fromisoformat(str(trade_date)) - except Exception: - try: - trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date() - except Exception: - pass - if trade_date_obj is None: - raise ValueError(f"无法解析交易日:{trade_date}") - - progress = st.progress(0.0) - changes_all = [] - success_count = 0 - error_count = 0 - - # 遍历所有标的 - for idx, code in enumerate(symbols, start=1): - try: - # 保存重评估前的状态 - with db_session(read_only=True) as conn: - before_rows = conn.execute( - "SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?", - (trade_date, code), - ).fetchall() - before_map = {row["agent"]: row["action"] for row in before_rows} - - # 执行重评估 - cfg = BtConfig( - id="reeval_ui_all", - name="UI All Re-eval", - start_date=trade_date_obj, - end_date=trade_date_obj, - universe=[code], - params={}, - ) - engine = BacktestEngine(cfg) - state = PortfolioState() - _ = engine.simulate_day(trade_date_obj, state) - - # 检查变化 - with db_session(read_only=True) as conn: - after_rows = conn.execute( - "SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?", - (trade_date, code), - ).fetchall() - for row in after_rows: - agent = row["agent"] - new_action = row["action"] - old_action = before_map.get(agent) - if new_action != old_action: - changes_all.append({"代码": code, "代理": agent, "原动作": old_action, "新动作": new_action}) - success_count += 1 - except Exception as e: - LOGGER.exception(f"重评估 {code} 失败", extra=LOG_EXTRA) - error_count += 1 - - # 更新进度 - progress.progress(idx / len(symbols)) - - # 显示结果 - if error_count > 0: - st.error(f"一键重评估完成:成功 {success_count} 个,失败 {error_count} 个") - else: - st.success(f"一键重评估完成:所有 {success_count} 个标的重评估成功") - - # 显示变更记录 - if changes_all: - st.write("检测到以下动作变更:") - st.dataframe(pd.DataFrame(changes_all), hide_index=True, width='stretch') - - # 刷新页面数据 - st.rerun() - except Exception as exc: - LOGGER.exception("一键重评估失败", extra=LOG_EXTRA) - st.error(f"一键重评估执行过程中发生错误:{exc}") - _set_query_params(date=str(trade_date), code=str(ts_code)) - with db_session(read_only=True) as conn: - rows = conn.execute( - """ - SELECT agent, action, utils, feasible, weight - FROM agent_utils - WHERE trade_date = ? AND ts_code = ? - ORDER BY CASE WHEN agent = 'global' THEN 1 ELSE 0 END, agent - """, - (trade_date, ts_code), - ).fetchall() - if not rows: - st.info("未查询到详细决策记录,稍后再试。") - return - try: - feasible_actions = json.loads(rows[0]["feasible"] or "[]") - except (KeyError, TypeError, json.JSONDecodeError): - feasible_actions = [] - global_info = None - dept_records: List[Dict[str, object]] = [] - dept_details: Dict[str, Dict[str, object]] = {} - agent_records: List[Dict[str, object]] = [] - for item in rows: - agent_name = item["agent"] - action = item["action"] - weight = float(item["weight"] or 0.0) - try: - utils = json.loads(item["utils"] or "{}") - except json.JSONDecodeError: - utils = {} - - if agent_name == "global": - global_info = { - "action": action, - "confidence": float(utils.get("_confidence", 0.0)), - "target_weight": float(utils.get("_target_weight", 0.0)), - "department_votes": utils.get("_department_votes", {}), - "requires_review": bool(utils.get("_requires_review", False)), - "scope_values": utils.get("_scope_values", {}), - "close_series": utils.get("_close_series", []), - "turnover_series": utils.get("_turnover_series", []), - "department_supplements": utils.get("_department_supplements", {}), - "department_dialogue": utils.get("_department_dialogue", {}), - "department_telemetry": utils.get("_department_telemetry", {}), - } - continue - - if agent_name.startswith("dept_"): - code = agent_name.split("dept_", 1)[-1] - signals = utils.get("_signals", []) - risks = utils.get("_risks", []) - supplements = utils.get("_supplements", []) - dialogue = utils.get("_dialogue", []) - telemetry = utils.get("_telemetry", {}) - dept_records.append( - { - "部门": code, - "行动": action, - "信心": float(utils.get("_confidence", 0.0)), - "权重": weight, - "摘要": utils.get("_summary", ""), - "核心信号": ";".join(signals) if isinstance(signals, list) else signals, - "风险提示": ";".join(risks) if isinstance(risks, list) else risks, - "补充次数": len(supplements) if isinstance(supplements, list) else 0, - } - ) - dept_details[code] = { - "supplements": supplements if isinstance(supplements, list) else [], - "dialogue": dialogue if isinstance(dialogue, list) else [], - "summary": utils.get("_summary", ""), - "signals": signals, - "risks": risks, - "telemetry": telemetry if isinstance(telemetry, dict) else {}, - } - else: - score_map = { - key: float(val) - for key, val in utils.items() - if not str(key).startswith("_") - } - agent_records.append( - { - "代理": agent_name, - "建议动作": action, - "权重": weight, - "SELL": score_map.get("SELL", 0.0), - "HOLD": score_map.get("HOLD", 0.0), - "BUY_S": score_map.get("BUY_S", 0.0), - "BUY_M": score_map.get("BUY_M", 0.0), - "BUY_L": score_map.get("BUY_L", 0.0), - } - ) - if feasible_actions: - st.caption(f"可行操作集合:{', '.join(feasible_actions)}") - st.subheader("全局策略") - if global_info: - col1, col2, col3 = st.columns(3) - col1.metric("最终行动", global_info["action"]) - col2.metric("信心", f"{global_info['confidence']:.2f}") - col3.metric("目标权重", f"{global_info['target_weight']:+.2%}") - if global_info["department_votes"]: - st.json(global_info["department_votes"]) - if global_info["requires_review"]: - st.warning("部门分歧较大,已标记为需人工复核。") - with st.expander("基础上下文数据", expanded=False): - # ADD: export buttons - scope = global_info.get("scope_values") or {} - close_series = global_info.get("close_series") or [] - turnover_series = global_info.get("turnover_series") or [] - st.write("最新字段:") - if scope: - st.json(scope) - st.download_button( - "下载字段(JSON)", - data=json.dumps(scope, ensure_ascii=False, indent=2), - file_name=f"{ts_code}_{trade_date}_scope.json", - mime="application/json", - key="dl_scope_json", - ) - if close_series: - st.write("收盘价时间序列 (最近窗口):") - st.json(close_series) - try: - import io, csv - buf = io.StringIO() - writer = csv.writer(buf) - writer.writerow(["trade_date", "close"]) - for dt, val in close_series: - writer.writerow([dt, val]) - st.download_button( - "下载收盘价(CSV)", - data=buf.getvalue(), - file_name=f"{ts_code}_{trade_date}_close_series.csv", - mime="text/csv", - key="dl_close_csv", - ) - except Exception: - pass - if turnover_series: - st.write("换手率时间序列 (最近窗口):") - st.json(turnover_series) - dept_sup = global_info.get("department_supplements") or {} - dept_dialogue = global_info.get("department_dialogue") or {} - dept_telemetry = global_info.get("department_telemetry") or {} - if dept_sup or dept_dialogue: - with st.expander("部门补数与对话记录", expanded=False): - if dept_sup: - st.write("补充数据:") - st.json(dept_sup) - if dept_dialogue: - st.write("对话片段:") - st.json(dept_dialogue) - if dept_telemetry: - with st.expander("部门 LLM 元数据", expanded=False): - st.json(dept_telemetry) - else: - st.info("暂未写入全局策略摘要。") - st.subheader("部门意见") - if dept_records: - # ADD: keyword filter for department summaries - keyword = st.text_input("筛选摘要/信号关键词", value="") - filtered = dept_records - if keyword.strip(): - kw = keyword.strip() - filtered = [ - item for item in dept_records - if kw in str(item.get("摘要", "")) or kw in str(item.get("核心信号", "")) - ] - # ADD: confidence filter and sort - min_conf = st.slider("最低信心过滤", 0.0, 1.0, 0.0, 0.05) - sort_col = st.selectbox("排序列", ["信心", "权重"], index=0) - filtered = [row for row in filtered if float(row.get("信心", 0.0)) >= min_conf] - filtered = sorted(filtered, key=lambda r: float(r.get(sort_col, 0.0)), reverse=True) - dept_df = pd.DataFrame(filtered) - st.dataframe(dept_df, width='stretch', hide_index=True) - try: - st.download_button( - "下载部门意见(CSV)", - data=dept_df.to_csv(index=False), - file_name=f"{trade_date}_{ts_code}_departments.csv", - mime="text/csv", - key="dl_dept_csv", - ) - except Exception: - pass - for code, details in dept_details.items(): - with st.expander(f"{code} 补充详情", expanded=False): - supplements = details.get("supplements", []) - dialogue = details.get("dialogue", []) - if supplements: - st.write("补充数据:") - st.json(supplements) - else: - st.caption("无补充数据请求。") - if dialogue: - st.write("对话记录:") - for idx, line in enumerate(dialogue, start=1): - st.markdown(f"**回合 {idx}:** {line}") - else: - st.caption("无额外对话。") - telemetry = details.get("telemetry") or {} - if telemetry: - st.write("LLM 元数据:") - st.json(telemetry) - else: - st.info("暂无部门记录。") - st.subheader("代理评分") - if agent_records: - # ADD: sorting and CSV export for agents - sort_agent_by = st.selectbox( - "代理排序", - ["权重", "SELL", "HOLD", "BUY_S", "BUY_M", "BUY_L"], - index=1, - ) - agent_df = pd.DataFrame(agent_records) - if sort_agent_by in agent_df.columns: - agent_df = agent_df.sort_values(sort_agent_by, ascending=False) - st.dataframe(agent_df, width='stretch', hide_index=True) - try: - st.download_button( - "下载代理评分(CSV)", - data=agent_df.to_csv(index=False), - file_name=f"{trade_date}_{ts_code}_agents.csv", - mime="text/csv", - key="dl_agent_csv", - ) - except Exception: - pass - else: - st.info("暂无基础代理评分。") - st.divider() - st.subheader("相关新闻") - try: - with db_session(read_only=True) as conn: - # 解析当前trade_date为datetime对象 - try: - trade_date_obj = date.fromisoformat(str(trade_date)) - except: - try: - trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date() - except: - # 如果解析失败,使用当前日期向前推7天 - trade_date_obj = date.today() - timedelta(days=7) - - # 查询近7天内与当前标的相关的新闻,按发布时间降序排列 - news_query = """ - SELECT id, title, source, pub_time, sentiment, heat, entities - FROM news - WHERE ts_code = ? AND pub_time >= ? - ORDER BY pub_time DESC - LIMIT 10 - """ - # 计算7天前的日期字符串 - seven_days_ago = (trade_date_obj - timedelta(days=7)).strftime("%Y-%m-%d") - news_rows = conn.execute(news_query, (ts_code, seven_days_ago)).fetchall() - - if news_rows: - news_data = [] - for row in news_rows: - # 解析entities字段获取更多信息 - entities_info = {} - try: - if row["entities"]: - entities_info = json.loads(row["entities"]) - except (json.JSONDecodeError, TypeError): - pass - - # 准备新闻数据 - news_item = { - "标题": row["title"], - "来源": row["source"], - "发布时间": row["pub_time"], - "情感指数": f"{row['sentiment']:.2f}" if row["sentiment"] is not None else "-", - "热度评分": f"{row['heat']:.2f}" if row["heat"] is not None else "-" - } - - # 如果有行业信息,添加到展示数据中 - industries = entities_info.get("industries", []) - if industries: - news_item["相关行业"] = "、".join(industries[:3]) # 只显示前3个行业 - - news_data.append(news_item) - - # 显示新闻表格 - news_df = pd.DataFrame(news_data) - # 确保所有列都是字符串类型,避免PyArrow序列化错误 - for col in news_df.columns: - news_df[col] = news_df[col].astype(str) - st.dataframe(news_df, width='stretch', hide_index=True) - - # 添加新闻详情展开视图 - st.write("详细新闻内容:") - for idx, row in enumerate(news_rows): - with st.expander(f"{idx+1}. {row['title']}", expanded=False): - st.write(f"**来源:** {row['source']}") - st.write(f"**发布时间:** {row['pub_time']}") - - # 解析entities获取更多详细信息 - entities_info = {} - try: - if row["entities"]: - entities_info = json.loads(row["entities"]) - except (json.JSONDecodeError, TypeError): - pass - - # 显示情感和热度信息 - sentiment_display = f"{row['sentiment']:.2f}" if row["sentiment"] is not None else "-" - heat_display = f"{row['heat']:.2f}" if row["heat"] is not None else "-" - st.write(f"**情感指数:** {sentiment_display} | **热度评分:** {heat_display}") - - # 显示行业信息 - industries = entities_info.get("industries", []) - if industries: - st.write(f"**相关行业:** {'、'.join(industries)}") - - # 显示重要关键词 - important_keywords = entities_info.get("important_keywords", []) - if important_keywords: - st.write(f"**重要关键词:** {'、'.join(important_keywords)}") - - # 显示URL链接(如果有) - url = entities_info.get("source_url", "") - if url: - st.markdown(f"[查看原文]({url})", unsafe_allow_html=True) - else: - st.info(f"近7天内暂无关于 {ts_code} 的新闻。") - except Exception as e: - LOGGER.exception("获取新闻数据失败", extra=LOG_EXTRA) - st.error(f"获取新闻数据时发生错误:{e}") - st.divider() - st.info("投资池与仓位概览已移至单独页面。请在侧边或页面导航中选择“投资池/仓位”以查看详细信息。") - st.divider() - st.subheader("策略重评估") - st.caption("对当前选中的交易日与标的,立即触发一次策略评估并回写 agent_utils。") - cols_re = st.columns([1,1]) - if cols_re[0].button("对该标的重评估", key="reevaluate_current_symbol"): - with st.spinner("正在重评估..."): - try: - trade_date_obj = None - try: - trade_date_obj = date.fromisoformat(str(trade_date)) - except Exception: - try: - trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date() - except Exception: - pass - if trade_date_obj is None: - raise ValueError(f"无法解析交易日:{trade_date}") - # snapshot before - with db_session(read_only=True) as conn: - before_rows = conn.execute( - """ - SELECT agent, action, utils FROM agent_utils - WHERE trade_date = ? AND ts_code = ? - """, - (trade_date, ts_code), - ).fetchall() - before_map = {row["agent"]: (row["action"], row["utils"]) for row in before_rows} - cfg = BtConfig( - id="reeval_ui", - name="UI Re-evaluation", - start_date=trade_date_obj, - end_date=trade_date_obj, - universe=[ts_code], - params={}, - ) - engine = BacktestEngine(cfg) - state = PortfolioState() - _ = engine.simulate_day(trade_date_obj, state) - # compare after - with db_session(read_only=True) as conn: - after_rows = conn.execute( - """ - SELECT agent, action, utils FROM agent_utils - WHERE trade_date = ? AND ts_code = ? - """, - (trade_date, ts_code), - ).fetchall() - changes = [] - for row in after_rows: - agent = row["agent"] - new_action = row["action"] - old_action, _old_utils = before_map.get(agent, (None, None)) - if new_action != old_action: - changes.append({"代理": agent, "原动作": old_action, "新动作": new_action}) - if changes: - st.success("重评估完成,检测到动作变更:") - st.dataframe(pd.DataFrame(changes), hide_index=True, width='stretch') - else: - st.success("重评估完成,无动作变更。") - st.rerun() - except Exception as exc: # noqa: BLE001 - LOGGER.exception("重评估失败", extra=LOG_EXTRA) - st.error(f"重评估失败:{exc}") - if cols_re[1].button("批量重评估(所选)", key="reevaluate_batch", disabled=not batch_symbols): - with st.spinner("批量重评估执行中..."): - try: - trade_date_obj = None - try: - trade_date_obj = date.fromisoformat(str(trade_date)) - except Exception: - try: - trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date() - except Exception: - pass - if trade_date_obj is None: - raise ValueError(f"无法解析交易日:{trade_date}") - progress = st.progress(0.0) - changes_all: List[Dict[str, object]] = [] - for idx, code in enumerate(batch_symbols, start=1): - with db_session(read_only=True) as conn: - before_rows = conn.execute( - "SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?", - (trade_date, code), - ).fetchall() - before_map = {row["agent"]: row["action"] for row in before_rows} - cfg = BtConfig( - id="reeval_ui_batch", - name="UI Batch Re-eval", - start_date=trade_date_obj, - end_date=trade_date_obj, - universe=[code], - params={}, - ) - engine = BacktestEngine(cfg) - state = PortfolioState() - _ = engine.simulate_day(trade_date_obj, state) - with db_session(read_only=True) as conn: - after_rows = conn.execute( - "SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?", - (trade_date, code), - ).fetchall() - for row in after_rows: - agent = row["agent"] - new_action = row["action"] - old_action = before_map.get(agent) - if new_action != old_action: - changes_all.append({"代码": code, "代理": agent, "原动作": old_action, "新动作": new_action}) - progress.progress(idx / max(1, len(batch_symbols))) - st.success("批量重评估完成。") - if changes_all: - st.dataframe(pd.DataFrame(changes_all), hide_index=True, width='stretch') - st.rerun() - except Exception as exc: # noqa: BLE001 - LOGGER.exception("批量重评估失败", extra=LOG_EXTRA) - st.error(f"批量重评估失败:{exc}") - - - -def render_pool_overview() -> None: - """单独的投资池与仓位概览页面(从今日计划中提取)。""" - LOGGER.info("渲染投资池与仓位概览页面", extra=LOG_EXTRA) - st.header("投资池与仓位概览") - - snapshot = get_latest_snapshot() - if snapshot: - col_a, col_b, col_c = st.columns(3) - if snapshot.total_value is not None: - col_a.metric("组合净值", f"{snapshot.total_value:,.2f}") - if snapshot.cash is not None: - col_b.metric("现金余额", f"{snapshot.cash:,.2f}") - if snapshot.invested_value is not None: - col_c.metric("持仓市值", f"{snapshot.invested_value:,.2f}") - detail_cols = st.columns(4) - if snapshot.unrealized_pnl is not None: - detail_cols[0].metric("浮盈", f"{snapshot.unrealized_pnl:,.2f}") - if snapshot.realized_pnl is not None: - detail_cols[1].metric("已实现盈亏", f"{snapshot.realized_pnl:,.2f}") - if snapshot.net_flow is not None: - detail_cols[2].metric("净流入", f"{snapshot.net_flow:,.2f}") - if snapshot.exposure is not None: - detail_cols[3].metric("风险敞口", f"{snapshot.exposure:.2%}") - if snapshot.notes: - st.caption(f"备注:{snapshot.notes}") - else: - st.info("暂无组合快照,请在执行回测或实盘同步后写入 portfolio_snapshots。") - - # 候选池 - try: - # trade_date param optional; use latest available if not provided - latest_date = _get_latest_trade_date() - candidates = list_investment_pool(trade_date=latest_date) - except Exception: - LOGGER.exception("加载候选池失败", extra=LOG_EXTRA) - candidates = [] - - if candidates: - candidate_df = pd.DataFrame( - [ - { - "交易日": item.trade_date, - "代码": item.ts_code, - "评分": item.score, - "状态": item.status, - "标签": "、".join(item.tags) if item.tags else "-", - "理由": item.rationale or "", - } - for item in candidates - ] - ) - st.write("候选投资池:") - st.dataframe(candidate_df, width='stretch', hide_index=True) - else: - st.caption("候选投资池暂无数据。") - - # 持仓 - positions = list_positions(active_only=False) - if positions: - position_df = pd.DataFrame( - [ - { - "ID": pos.id, - "代码": pos.ts_code, - "开仓日": pos.opened_date, - "平仓日": pos.closed_date or "-", - "状态": pos.status, - "数量": pos.quantity, - "成本": pos.cost_price, - "现价": pos.market_price, - "市值": pos.market_value, - "浮盈": pos.unrealized_pnl, - "已实现": pos.realized_pnl, - "目标权重": pos.target_weight, - } - for pos in positions - ] - ) - st.write("组合持仓:") - st.dataframe(position_df, width='stretch', hide_index=True) - else: - st.caption("组合持仓暂无记录。") - - trades = list_recent_trades(limit=20) - if trades: - trades_df = pd.DataFrame(trades) - st.write("近期成交:") - st.dataframe(trades_df, width='stretch', hide_index=True) - else: - st.caption("近期成交暂无记录。") - - st.caption("数据来源:agent_utils、investment_pool、portfolio_positions、portfolio_trades、portfolio_snapshots。") - - if st.button("执行对比", type="secondary"): - with st.spinner("执行日志对比分析中..."): - try: - with db_session(read_only=True) as conn: - # 获取两个日期的日志统计 - query_date1 = f"{compare_date1.isoformat()}T00:00:00Z" - query_date2 = f"{compare_date1.isoformat()}T23:59:59Z" - logs1 = conn.execute( - "SELECT level, COUNT(*) as count FROM run_log WHERE ts BETWEEN ? AND ? GROUP BY level", - (query_date1, query_date2) - ).fetchall() - - query_date3 = f"{compare_date2.isoformat()}T00:00:00Z" - query_date4 = f"{compare_date2.isoformat()}T23:59:59Z" - logs2 = conn.execute( - "SELECT level, COUNT(*) as count FROM run_log WHERE ts BETWEEN ? AND ? GROUP BY level", - (query_date3, query_date4) - ).fetchall() - - # 转换为DataFrame并可视化 - df1 = pd.DataFrame(logs1, columns=["level", "count"]) - df1["date"] = compare_date1.strftime("%Y-%m-%d") - df2 = pd.DataFrame(logs2, columns=["level", "count"]) - df2["date"] = compare_date2.strftime("%Y-%m-%d") - - # 确保所有列的数据类型一致,避免PyArrow序列化错误 - for df in [df1, df2]: - for col in df.columns: - if col != "level": # level列保持字符串类型 - df[col] = df[col].astype(object) - - compare_df = pd.concat([df1, df2]) - - # 绘制对比图表 - fig = px.bar( - compare_df, - x="level", - y="count", - color="date", - barmode="group", - title=f"日志级别分布对比 ({compare_date1} vs {compare_date2})" - ) - st.plotly_chart(fig, width='stretch') - - # 显示详细对比表格 - st.write("日志统计对比:") - # 使用不含连字符的日期格式作为列名后缀,避免Arrow类型转换错误 - date1_str = compare_date1.strftime("%Y%m%d") - date2_str = compare_date2.strftime("%Y%m%d") - merged_df = df1.merge(df2, on="level", suffixes=(f"_{date1_str}", f"_{date2_str}"), how="outer").fillna(0) - st.dataframe(merged_df, hide_index=True, width="stretch") - except Exception as e: - LOGGER.exception("日志对比失败", extra=LOG_EXTRA) - st.error(f"日志对比分析失败:{e}") - - return - - -def render_backtest_review() -> None: - """渲染回测执行、调参与结果复盘页面。""" - st.header("回测与复盘") - st.caption("1. 基于历史数据复盘当前策略;2. 借助强化学习/调参探索更优参数组合。") - app_cfg = get_config() - default_start, default_end = _default_backtest_range(window_days=60) - LOGGER.debug( - "回测默认参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s", - default_start, - default_end, - "000001.SZ", - 0.035, - -0.015, - 10, - extra=LOG_EXTRA, - ) - - st.markdown("### 回测参数") - col1, col2 = st.columns(2) - start_date = col1.date_input("开始日期", value=default_start, key="bt_start_date") - end_date = col2.date_input("结束日期", value=default_end, key="bt_end_date") - universe_text = st.text_input("股票列表(逗号分隔)", value="000001.SZ", key="bt_universe") - col_target, col_stop, col_hold = st.columns(3) - target = col_target.number_input("目标收益(例:0.035 表示 3.5%)", value=0.035, step=0.005, format="%.3f", key="bt_target") - stop = col_stop.number_input("止损收益(例:-0.015 表示 -1.5%)", value=-0.015, step=0.005, format="%.3f", key="bt_stop") - hold_days = col_hold.number_input("持有期(交易日)", value=10, step=1, key="bt_hold_days") - LOGGER.debug( - "当前回测表单输入:start=%s end=%s universe_text=%s target=%.3f stop=%.3f hold_days=%s", - start_date, - end_date, - universe_text, - target, - stop, - hold_days, - extra=LOG_EXTRA, - ) - - tab_backtest, tab_rl = st.tabs(["回测验证", "强化学习调参"]) - - with tab_backtest: - st.markdown("#### 回测执行") - if st.button("运行回测", key="bt_run_button"): - LOGGER.info("用户点击运行回测按钮", extra=LOG_EXTRA) - decision_log_container = st.container() - status_box = st.status("准备执行回测...", expanded=True) - llm_stats_placeholder = st.empty() - decision_entries: List[str] = [] - - def _decision_callback(ts_code: str, trade_dt: date, ctx: AgentContext, decision: Decision) -> None: - ts_label = trade_dt.isoformat() - summary = "" - for dept_decision in decision.department_decisions.values(): - if getattr(dept_decision, "summary", ""): - summary = str(dept_decision.summary) - break - entry_lines = [ - f"**{ts_label} {ts_code}** → {decision.action.value} (信心 {decision.confidence:.2f})", - ] - if summary: - entry_lines.append(f"摘要:{summary}") - dep_highlights = [] - for dept_code, dept_decision in decision.department_decisions.items(): - dep_highlights.append( - f"{dept_code}:{dept_decision.action.value}({dept_decision.confidence:.2f})" - ) - if dep_highlights: - entry_lines.append("部门意见:" + ";".join(dep_highlights)) - decision_entries.append(" \n".join(entry_lines)) - decision_log_container.markdown("\n\n".join(decision_entries[-200:])) - status_box.write(f"{ts_label} {ts_code} → {decision.action.value} (信心 {decision.confidence:.2f})") - stats = snapshot_llm_metrics() - llm_stats_placeholder.json( - { - "LLM 调用次数": stats.get("total_calls", 0), - "Prompt Tokens": stats.get("total_prompt_tokens", 0), - "Completion Tokens": stats.get("total_completion_tokens", 0), - "按 Provider": stats.get("provider_calls", {}), - "按模型": stats.get("model_calls", {}), - } - ) - _update_dashboard_sidebar(stats) - - reset_llm_metrics() - status_box.update(label="执行回测中...", state="running") - try: - universe = [code.strip() for code in universe_text.split(',') if code.strip()] - LOGGER.info( - "回测参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s", - start_date, - end_date, - universe, - target, - stop, - hold_days, - extra=LOG_EXTRA, - ) - backtest_cfg = BtConfig( - id="streamlit_demo", - name="Streamlit Demo Strategy", - start_date=start_date, - end_date=end_date, - universe=universe, - params={ - "target": target, - "stop": stop, - "hold_days": int(hold_days), - }, - ) - result = run_backtest(backtest_cfg, decision_callback=_decision_callback) - LOGGER.info( - "回测完成:nav_records=%s trades=%s", - len(result.nav_series), - len(result.trades), - extra=LOG_EXTRA, - ) - status_box.update(label="回测执行完成", state="complete") - st.success("回测执行完成,详见下方结果与统计。") - metrics = snapshot_llm_metrics() - llm_stats_placeholder.json( - { - "LLM 调用次数": metrics.get("total_calls", 0), - "Prompt Tokens": metrics.get("total_prompt_tokens", 0), - "Completion Tokens": metrics.get("total_completion_tokens", 0), - "按 Provider": metrics.get("provider_calls", {}), - "按模型": metrics.get("model_calls", {}), - } - ) - _update_dashboard_sidebar(metrics) - st.session_state["backtest_last_result"] = {"nav_records": result.nav_series, "trades": result.trades} - st.json(st.session_state["backtest_last_result"]) - except Exception as exc: # noqa: BLE001 - LOGGER.exception("回测执行失败", extra=LOG_EXTRA) - status_box.update(label="回测执行失败", state="error") - st.error(f"回测执行失败:{exc}") - - last_result = st.session_state.get("backtest_last_result") - if last_result: - st.markdown("#### 最近回测输出") - st.json(last_result) - - st.divider() - # ADD: Comparison view for multiple backtest configurations - with st.expander("历史回测结果对比", expanded=False): - st.caption("从历史回测配置中选择多个进行净值曲线与指标对比。") - normalize_to_one = st.checkbox("归一化到 1 起点", value=True, key="bt_cmp_normalize") - use_log_y = st.checkbox("对数坐标", value=False, key="bt_cmp_log_y") - metric_options = ["总收益", "最大回撤", "交易数", "平均换手", "风险事件"] - selected_metrics = st.multiselect("显示指标列", metric_options, default=metric_options, key="bt_cmp_metrics") - try: - with db_session(read_only=True) as conn: - cfg_rows = conn.execute( - "SELECT id, name FROM bt_config ORDER BY rowid DESC LIMIT 50" - ).fetchall() - except Exception: # noqa: BLE001 - LOGGER.exception("读取 bt_config 失败", extra=LOG_EXTRA) - cfg_rows = [] - cfg_options = [f"{row['id']} | {row['name']}" for row in cfg_rows] - selected_labels = st.multiselect("选择配置", cfg_options, default=cfg_options[:2], key="bt_cmp_configs") - selected_ids = [label.split(" | ")[0].strip() for label in selected_labels] - nav_df = pd.DataFrame() - rpt_df = pd.DataFrame() - if selected_ids: - try: - with db_session(read_only=True) as conn: - nav_df = pd.read_sql_query( - "SELECT cfg_id, trade_date, nav FROM bt_nav WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))), - conn, - params=tuple(selected_ids), - ) - rpt_df = pd.read_sql_query( - "SELECT cfg_id, summary FROM bt_report WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))), - conn, - params=tuple(selected_ids), - ) - except Exception: # noqa: BLE001 - LOGGER.exception("读取回测结果失败", extra=LOG_EXTRA) - st.error("读取回测结果失败") - nav_df = pd.DataFrame() - rpt_df = pd.DataFrame() - if not nav_df.empty: - try: - nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"], errors="coerce") - # ADD: date window filter - overall_min = pd.to_datetime(nav_df["trade_date"].min()).date() - overall_max = pd.to_datetime(nav_df["trade_date"].max()).date() - col_d1, col_d2 = st.columns(2) - start_filter = col_d1.date_input("起始日期", value=overall_min, key="bt_cmp_start") - end_filter = col_d2.date_input("结束日期", value=overall_max, key="bt_cmp_end") - if start_filter > end_filter: - start_filter, end_filter = end_filter, start_filter - mask = (nav_df["trade_date"].dt.date >= start_filter) & (nav_df["trade_date"].dt.date <= end_filter) - nav_df = nav_df.loc[mask] - pivot = nav_df.pivot_table(index="trade_date", columns="cfg_id", values="nav") - if normalize_to_one: - pivot = pivot.apply(lambda s: s / s.dropna().iloc[0] if s.dropna().size else s) - import plotly.graph_objects as go - fig = go.Figure() - for col in pivot.columns: - fig.add_trace(go.Scatter(x=pivot.index, y=pivot[col], mode="lines", name=str(col))) - fig.update_layout(height=300, margin=dict(l=10, r=10, t=30, b=10)) - if use_log_y: - fig.update_yaxes(type="log") - st.plotly_chart(fig, width='stretch') - # ADD: export pivot - try: - csv_buf = pivot.reset_index() - csv_buf.columns = ["trade_date"] + [str(c) for c in pivot.columns] - st.download_button( - "下载曲线(CSV)", - data=csv_buf.to_csv(index=False), - file_name="bt_nav_compare.csv", - mime="text/csv", - key="dl_nav_compare", - ) - except Exception: - pass - except Exception: # noqa: BLE001 - LOGGER.debug("绘制对比曲线失败", extra=LOG_EXTRA) - if not rpt_df.empty: - try: - metrics_rows: List[Dict[str, object]] = [] - for _, row in rpt_df.iterrows(): - cfg_id = row["cfg_id"] - try: - summary = json.loads(row["summary"]) if isinstance(row["summary"], str) else (row["summary"] or {}) - except json.JSONDecodeError: - summary = {} - record = { - "cfg_id": cfg_id, - "总收益": summary.get("total_return"), - "最大回撤": summary.get("max_drawdown"), - "交易数": summary.get("trade_count"), - "平均换手": summary.get("avg_turnover"), - "风险事件": summary.get("risk_events"), - } - metrics_rows.append({k: v for k, v in record.items() if (k == "cfg_id" or k in selected_metrics)}) - if metrics_rows: - dfm = pd.DataFrame(metrics_rows) - st.dataframe(dfm, hide_index=True, width='stretch') - try: - st.download_button( - "下载指标(CSV)", - data=dfm.to_csv(index=False), - file_name="bt_metrics_compare.csv", - mime="text/csv", - key="dl_metrics_compare", - ) - except Exception: - pass - except Exception: # noqa: BLE001 - LOGGER.debug("渲染指标表失败", extra=LOG_EXTRA) - else: - st.info("请选择至少一个配置进行对比。") - - - - with tab_rl: - st.caption("使用 DecisionEnv 对代理权重进行强化学习调参,支持单次与批量实验。") - with st.expander("离线调参实验 (DecisionEnv)", expanded=False): - st.caption( - "使用 DecisionEnv 对代理权重做离线调参。请选择需要优化的代理并设定权重范围," - "系统会运行一次回测并返回收益、回撤等指标。若 LLM 网络不可用,将返回失败标记。" - ) - - disable_departments = st.checkbox( - "禁用部门 LLM(仅规则代理,适合离线快速评估)", - value=True, - help="关闭部门调用后不依赖外部 LLM 网络,仅根据规则代理权重模拟。", - ) - - default_experiment_id = f"streamlit_{datetime.now().strftime('%Y%m%d_%H%M%S')}" - experiment_id = st.text_input( - "实验 ID", - value=default_experiment_id, - help="用于在 tuning_results 表中区分不同实验。", - ) - strategy_label = st.text_input( - "策略说明", - value="DecisionEnv", - help="可选:为本次调参记录一个策略名称或备注。", - ) - - agent_objects = default_agents() - agent_names = [agent.name for agent in agent_objects] - if not agent_names: - st.info("暂无可调整的代理。") - else: - selected_agents = st.multiselect( - "选择调参的代理权重", - agent_names, - default=agent_names[:2], - key="decision_env_agents", - ) - - specs: List[ParameterSpec] = [] - action_values: List[float] = [] - range_valid = True - for idx, agent_name in enumerate(selected_agents): - col_min, col_max, col_action = st.columns([1, 1, 2]) - min_key = f"decision_env_min_{agent_name}" - max_key = f"decision_env_max_{agent_name}" - action_key = f"decision_env_action_{agent_name}" - default_min = 0.0 - default_max = 1.0 - min_val = col_min.number_input( - f"{agent_name} 最小权重", - min_value=0.0, - max_value=1.0, - value=default_min, - step=0.05, - key=min_key, - ) - max_val = col_max.number_input( - f"{agent_name} 最大权重", - min_value=0.0, - max_value=1.0, - value=default_max, - step=0.05, - key=max_key, - ) - if max_val <= min_val: - range_valid = False - action_val = col_action.slider( - f"{agent_name} 动作 (0-1)", - min_value=0.0, - max_value=1.0, - value=0.5, - step=0.01, - key=action_key, - ) - specs.append( - ParameterSpec( - name=f"weight_{agent_name}", - target=f"agent_weights.{agent_name}", - minimum=min_val, - maximum=max_val, - ) - ) - action_values.append(action_val) - - run_decision_env = st.button("执行单次调参", key="run_decision_env_button") - just_finished_single = False - if run_decision_env: - if not selected_agents: - st.warning("请至少选择一个代理进行调参。") - elif not range_valid: - st.error("请确保所有代理的最大权重大于最小权重。") - else: - LOGGER.info( - "离线调参(单次)按钮点击,已选择代理=%s 动作=%s disable_departments=%s", - selected_agents, - action_values, - disable_departments, - extra=LOG_EXTRA, - ) - baseline_weights = cfg.agent_weights.as_dict() - for agent in agent_objects: - baseline_weights.setdefault(agent.name, 1.0) - - universe_env = [code.strip() for code in universe_text.split(',') if code.strip()] - if not universe_env: - st.error("请先指定至少一个股票代码。") - else: - bt_cfg_env = BtConfig( - id="decision_env_streamlit", - name="DecisionEnv Streamlit", - start_date=start_date, - end_date=end_date, - universe=universe_env, - params={ - "target": target, - "stop": stop, - "hold_days": int(hold_days), - }, - method=cfg.decision_method, - ) - env = DecisionEnv( - bt_config=bt_cfg_env, - parameter_specs=specs, - baseline_weights=baseline_weights, - disable_departments=disable_departments, - ) - env.reset() - LOGGER.debug( - "离线调参(单次)启动 DecisionEnv:cfg=%s 参数维度=%s", - bt_cfg_env, - len(specs), - extra=LOG_EXTRA, - ) - with st.spinner("正在执行离线调参……"): - try: - observation, reward, done, info = env.step(action_values) - LOGGER.info( - "离线调参(单次)完成,obs=%s reward=%.4f done=%s", - observation, - reward, - done, - extra=LOG_EXTRA, - ) - except Exception as exc: # noqa: BLE001 - LOGGER.exception("DecisionEnv 调用失败", extra=LOG_EXTRA) - st.error(f"离线调参失败:{exc}") - st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None) - else: - if observation.get("failure"): - st.error("调参失败:回测执行未完成,可能是 LLM 网络不可用或参数异常。") - st.json(observation) - st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None) - else: - resolved_experiment_id = experiment_id or str(uuid.uuid4()) - resolved_strategy = strategy_label or "DecisionEnv" - action_payload = { - name: value - for name, value in zip(selected_agents, action_values) - } - metrics_payload = dict(observation) - metrics_payload["reward"] = reward - log_success = False - try: - log_tuning_result( - experiment_id=resolved_experiment_id, - strategy=resolved_strategy, - action=action_payload, - reward=reward, - metrics=metrics_payload, - weights=info.get("weights", {}), - ) - except Exception: # noqa: BLE001 - LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA) - else: - log_success = True - LOGGER.info( - "离线调参(单次)日志写入成功:experiment=%s strategy=%s", - resolved_experiment_id, - resolved_strategy, - extra=LOG_EXTRA, - ) - st.session_state[_DECISION_ENV_SINGLE_RESULT_KEY] = { - "observation": dict(observation), - "reward": float(reward), - "weights": info.get("weights", {}), - "nav_series": info.get("nav_series"), - "trades": info.get("trades"), - "portfolio_snapshots": info.get("portfolio_snapshots"), - "portfolio_trades": info.get("portfolio_trades"), - "risk_breakdown": info.get("risk_breakdown"), - "selected_agents": list(selected_agents), - "action_values": list(action_values), - "experiment_id": resolved_experiment_id, - "strategy_label": resolved_strategy, - "logged": log_success, - } - just_finished_single = True - single_result = st.session_state.get(_DECISION_ENV_SINGLE_RESULT_KEY) - if single_result: - if just_finished_single: - st.success("离线调参完成") - else: - st.success("离线调参结果(最近一次运行)") - st.caption( - f"实验 ID:{single_result.get('experiment_id', '-') } | 策略:{single_result.get('strategy_label', 'DecisionEnv')}" - ) - observation = single_result.get("observation", {}) - reward = float(single_result.get("reward", 0.0)) - col_metrics = st.columns(4) - col_metrics[0].metric("总收益", f"{observation.get('total_return', 0.0):+.2%}") - col_metrics[1].metric("最大回撤", f"{observation.get('max_drawdown', 0.0):+.2%}") - col_metrics[2].metric("波动率", f"{observation.get('volatility', 0.0):+.2%}") - col_metrics[3].metric("奖励", f"{reward:+.4f}") - - turnover_ratio = float(observation.get("turnover", 0.0) or 0.0) - turnover_value = float(observation.get("turnover_value", 0.0) or 0.0) - risk_count = float(observation.get("risk_count", 0.0) or 0.0) - col_metrics_extra = st.columns(3) - col_metrics_extra[0].metric("平均换手率", f"{turnover_ratio:.2%}") - col_metrics_extra[1].metric("成交额", f"{turnover_value:,.0f}") - col_metrics_extra[2].metric("风险事件数", f"{int(risk_count)}") - - weights_dict = single_result.get("weights") or {} - if weights_dict: - st.write("调参后权重:") - st.json(weights_dict) - if st.button("保存这些权重为默认配置", key="save_decision_env_weights_single"): - try: - cfg.agent_weights.update_from_dict(weights_dict) - save_config(cfg) - except Exception as exc: # noqa: BLE001 - LOGGER.exception("保存权重失败", extra={**LOG_EXTRA, "error": str(exc)}) - st.error(f"写入配置失败:{exc}") - else: - st.success("代理权重已写入 config.json") - - if single_result.get("logged"): - st.caption("调参结果已写入 tuning_results 表。") - - nav_series = single_result.get("nav_series") or [] - if nav_series: - try: - nav_df = pd.DataFrame(nav_series) - if {"trade_date", "nav"}.issubset(nav_df.columns): - nav_df = nav_df.sort_values("trade_date") - nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"]) - st.line_chart(nav_df.set_index("trade_date")["nav"], height=220) - except Exception: # noqa: BLE001 - LOGGER.debug("导航曲线绘制失败", extra=LOG_EXTRA) - - trades = single_result.get("trades") or [] - if trades: - st.write("成交记录:") - st.dataframe(pd.DataFrame(trades), hide_index=True, width='stretch') - - snapshots = single_result.get("portfolio_snapshots") or [] - if snapshots: - with st.expander("投资组合快照", expanded=False): - st.dataframe(pd.DataFrame(snapshots), hide_index=True, width='stretch') - - portfolio_trades = single_result.get("portfolio_trades") or [] - if portfolio_trades: - with st.expander("组合成交明细", expanded=False): - st.dataframe(pd.DataFrame(portfolio_trades), hide_index=True, width='stretch') - - risk_breakdown = single_result.get("risk_breakdown") or {} - if risk_breakdown: - with st.expander("风险事件统计", expanded=False): - st.json(risk_breakdown) - - if st.button("清除单次调参结果", key="clear_decision_env_single"): - st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None) - st.success("已清除单次调参结果缓存。") - - st.divider() - st.caption("批量调参:在下方输入多组动作,每行表示一组 0-1 之间的值,用逗号分隔。") - default_grid = "\n".join( - [ - ",".join(["0.2" for _ in specs]), - ",".join(["0.5" for _ in specs]), - ",".join(["0.8" for _ in specs]), - ] - ) if specs else "" - action_grid_raw = st.text_area( - "动作列表", - value=default_grid, - height=120, - key="decision_env_batch_actions", - ) - run_batch = st.button("批量执行调参", key="run_decision_env_batch") - batch_just_ran = False - if run_batch: - if not selected_agents: - st.warning("请先选择调参代理。") - elif not range_valid: - st.error("请确保所有代理的最大权重大于最小权重。") - else: - LOGGER.info( - "离线调参(批量)按钮点击,已选择代理=%s disable_departments=%s", - selected_agents, - disable_departments, - extra=LOG_EXTRA, - ) - lines = [line.strip() for line in action_grid_raw.splitlines() if line.strip()] - if not lines: - st.warning("请在文本框中输入至少一组动作。") - else: - LOGGER.debug( - "离线调参(批量)原始输入=%s", - lines, - extra=LOG_EXTRA, - ) - parsed_actions: List[List[float]] = [] - for line in lines: - try: - values = [float(val.strip()) for val in line.split(',') if val.strip()] - except ValueError: - st.error(f"无法解析动作行:{line}") - parsed_actions = [] - break - if len(values) != len(specs): - st.error(f"动作维度不匹配(期望 {len(specs)} 个值):{line}") - parsed_actions = [] - break - parsed_actions.append(values) - if parsed_actions: - LOGGER.info( - "离线调参(批量)解析动作成功,数量=%s", - len(parsed_actions), - extra=LOG_EXTRA, - ) - baseline_weights = cfg.agent_weights.as_dict() - for agent in agent_objects: - baseline_weights.setdefault(agent.name, 1.0) - - universe_env = [code.strip() for code in universe_text.split(',') if code.strip()] - if not universe_env: - st.error("请先指定至少一个股票代码。") - else: - bt_cfg_env = BtConfig( - id="decision_env_streamlit_batch", - name="DecisionEnv Batch", - start_date=start_date, - end_date=end_date, - universe=universe_env, - params={ - "target": target, - "stop": stop, - "hold_days": int(hold_days), - }, - method=cfg.decision_method, - ) - env = DecisionEnv( - bt_config=bt_cfg_env, - parameter_specs=specs, - baseline_weights=baseline_weights, - disable_departments=disable_departments, - ) - results: List[Dict[str, object]] = [] - resolved_experiment_id = experiment_id or str(uuid.uuid4()) - resolved_strategy = strategy_label or "DecisionEnv" - LOGGER.debug( - "离线调参(批量)启动 DecisionEnv:cfg=%s 动作组=%s", - bt_cfg_env, - len(parsed_actions), - extra=LOG_EXTRA, - ) - with st.spinner("正在批量执行调参……"): - for idx, action_vals in enumerate(parsed_actions, start=1): - env.reset() - try: - observation, reward, done, info = env.step(action_vals) - except Exception as exc: # noqa: BLE001 - LOGGER.exception("批量调参失败", extra=LOG_EXTRA) - results.append( - { - "序号": idx, - "动作": action_vals, - "状态": "error", - "错误": str(exc), - } - ) - continue - if observation.get("failure"): - results.append( - { - "序号": idx, - "动作": action_vals, - "状态": "failure", - "奖励": -1.0, - } - ) - else: - LOGGER.info( - "离线调参(批量)第 %s 组完成,reward=%.4f obs=%s", - idx, - reward, - observation, - extra=LOG_EXTRA, - ) - action_payload = { - name: value - for name, value in zip(selected_agents, action_vals) - } - metrics_payload = dict(observation) - metrics_payload["reward"] = reward - weights_payload = info.get("weights", {}) - try: - log_tuning_result( - experiment_id=resolved_experiment_id, - strategy=resolved_strategy, - action=action_payload, - reward=reward, - metrics=metrics_payload, - weights=weights_payload, - ) - except Exception: # noqa: BLE001 - LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA) - results.append( - { - "序号": idx, - "动作": action_vals, - "状态": "ok", - "总收益": observation.get("total_return", 0.0), - "最大回撤": observation.get("max_drawdown", 0.0), - "波动率": observation.get("volatility", 0.0), - "奖励": reward, - "权重": weights_payload, - } - ) - st.session_state[_DECISION_ENV_BATCH_RESULTS_KEY] = { - "results": results, - "selectable": [ - row - for row in results - if row.get("状态") == "ok" and row.get("权重") - ], - "experiment_id": resolved_experiment_id, - "strategy_label": resolved_strategy, - } - batch_just_ran = True - LOGGER.info( - "离线调参(批量)执行结束,总结果条数=%s", - len(results), - extra=LOG_EXTRA, - ) - batch_state = st.session_state.get(_DECISION_ENV_BATCH_RESULTS_KEY) - if batch_state: - results = batch_state.get("results") or [] - if results: - if batch_just_ran: - st.success("批量调参完成") - else: - st.success("批量调参结果(最近一次运行)") - st.caption( - f"实验 ID:{batch_state.get('experiment_id', '-') } | 策略:{batch_state.get('strategy_label', 'DecisionEnv')}" - ) - results_df = pd.DataFrame(results) - st.write("批量调参结果:") - st.dataframe(results_df, hide_index=True, width='stretch') - selectable = batch_state.get("selectable") or [] - if selectable: - option_labels = [ - f"序号 {row['序号']} | 奖励 {row.get('奖励', 0.0):+.4f}" - for row in selectable - ] - selected_label = st.selectbox( - "选择要保存的记录", - option_labels, - key="decision_env_batch_select", - ) - selected_row = None - for label, row in zip(option_labels, selectable): - if label == selected_label: - selected_row = row - break - if selected_row and st.button( - "保存所选权重为默认配置", - key="save_decision_env_weights_batch", - ): - try: - cfg.agent_weights.update_from_dict(selected_row.get("权重", {})) - save_config(cfg) - except Exception as exc: # noqa: BLE001 - LOGGER.exception("批量保存权重失败", extra={**LOG_EXTRA, "error": str(exc)}) - st.error(f"写入配置失败:{exc}") - else: - st.success( - f"已将序号 {selected_row['序号']} 的权重写入 config.json" - ) - else: - st.caption("暂无成功的结果可供保存。") - else: - st.caption("批量调参在最近一次执行中未产生结果。") - if st.button("清除批量调参结果", key="clear_decision_env_batch"): - st.session_state.pop(_DECISION_ENV_BATCH_RESULTS_KEY, None) - st.session_state.pop("decision_env_batch_select", None) - st.success("已清除批量调参结果缓存。") - - -def render_config_overview() -> None: - """Render a concise overview of persisted configuration values.""" - - LOGGER.info("渲染配置概览页", extra=LOG_EXTRA) - cfg = get_config() - - st.subheader("核心配置概览") - col1, col2, col3 = st.columns(3) - col1.metric("决策方式", cfg.decision_method.upper()) - col2.metric("自动更新数据", "启用" if cfg.auto_update_data else "关闭") - col3.metric("数据更新间隔(天)", cfg.data_update_interval) - - col4, col5, col6 = st.columns(3) - col4.metric("强制刷新", "开启" if cfg.force_refresh else "关闭") - col5.metric("TuShare Token", "已配置" if cfg.tushare_token else "未配置") - col6.metric("配置文件", cfg.data_paths.config_file.name) - st.caption(f"配置文件路径:{cfg.data_paths.config_file}") - - st.divider() - st.subheader("RSS 数据源状态") - rss_sources = cfg.rss_sources or {} - if rss_sources: - rows: List[Dict[str, object]] = [] - for name, payload in rss_sources.items(): - if isinstance(payload, dict): - rows.append( - { - "名称": name, - "启用": "是" if payload.get("enabled", True) else "否", - "URL": payload.get("url", "-"), - "关键词数": len(payload.get("keywords", []) or []), - } - ) - elif isinstance(payload, bool): - rows.append( - { - "名称": name, - "启用": "是" if payload else "否", - "URL": "-", - "关键词数": 0, - } - ) - if rows: - st.dataframe(pd.DataFrame(rows), hide_index=True, width="stretch") - else: - st.info("未配置 RSS 数据源。") - else: - st.info("未在配置文件中找到 RSS 数据源。") - - st.divider() - st.subheader("部门配置") - dept_rows: List[Dict[str, object]] = [] - for code, dept in cfg.departments.items(): - dept_rows.append( - { - "部门": dept.title or code, - "代码": code, - "权重": dept.weight, - "LLM 策略": dept.llm.strategy, - "模板": dept.prompt_template_id or f"{code}_dept", - "模板版本": dept.prompt_template_version or "(激活版本)", - } - ) - if dept_rows: - st.dataframe(pd.DataFrame(dept_rows), hide_index=True, width="stretch") - else: - st.info("尚未配置任何部门。") - - st.divider() - st.subheader("LLM 成本控制") - cost = cfg.llm_cost - col_a, col_b, col_c, col_d = st.columns(4) - col_a.metric("成本控制", "启用" if cost.enabled else "关闭") - col_b.metric("小时预算($)", f"{cost.hourly_budget:.2f}") - col_c.metric("日预算($)", f"{cost.daily_budget:.2f}") - col_d.metric("月预算($)", f"{cost.monthly_budget:.2f}") - - if cost.model_weights: - weight_rows = ( - pd.DataFrame( - [ - {"模型": model, "占比上限": f"{limit * 100:.0f}%"} - for model, limit in cost.model_weights.items() - ] - ) - ) - st.dataframe(weight_rows, hide_index=True, width="stretch") - else: - st.caption("未配置模型占比限制。") - - st.divider() - st.caption("提示:数据源、LLM 及投资组合设置可在对应标签页中调整。") - - -def render_llm_settings() -> None: - cfg = get_config() - st.subheader("LLM 设置") - providers = cfg.llm_providers - provider_keys = sorted(providers.keys()) - st.caption("先在 Provider 中维护基础连接(URL、Key、模型),再为全局与各部门设置个性化参数。") - - provider_select_col, provider_manage_col = st.columns([3, 1]) - if provider_keys: - try: - default_provider = cfg.llm.primary.provider or provider_keys[0] - provider_index = provider_keys.index(default_provider) - except ValueError: - provider_index = 0 - selected_provider = provider_select_col.selectbox( - "选择 Provider", - provider_keys, - index=provider_index, - key="llm_provider_select", - ) - else: - selected_provider = None - provider_select_col.info("尚未配置 Provider,请先创建。") - - new_provider_name = provider_manage_col.text_input("新增 Provider", key="new_provider_name") - if provider_manage_col.button("创建 Provider", key="create_provider_btn"): - key = (new_provider_name or "").strip().lower() - if not key: - st.warning("请输入有效的 Provider 名称。") - elif key in providers: - st.warning(f"Provider {key} 已存在。") - else: - providers[key] = LLMProvider(key=key) - cfg.llm_providers = providers - save_config() - st.success(f"已创建 Provider {key}。") - st.rerun() - - if selected_provider: - provider_cfg = providers.get(selected_provider, LLMProvider(key=selected_provider)) - title_key = f"provider_title_{selected_provider}" - base_key = f"provider_base_{selected_provider}" - api_key_key = f"provider_api_{selected_provider}" - mode_key = f"provider_mode_{selected_provider}" - enabled_key = f"provider_enabled_{selected_provider}" - - title_val = st.text_input("备注名称", value=provider_cfg.title or "", key=title_key) - base_val = st.text_input("Base URL", value=provider_cfg.base_url or "", key=base_key, help="调用地址,例如:https://api.openai.com") - api_val = st.text_input("API Key", value=provider_cfg.api_key or "", key=api_key_key, type="password") - - enabled_val = st.checkbox("启用", value=provider_cfg.enabled, key=enabled_key) - mode_val = st.selectbox("模式", options=["openai", "ollama"], index=0 if provider_cfg.mode == "openai" else 1, key=mode_key) - st.markdown("可用模型:") - if provider_cfg.models: - st.code("\n".join(provider_cfg.models), language="text") - else: - st.info("尚未获取模型列表,可点击下方按钮自动拉取。") - try: - cache_key = f"{selected_provider}|{(base_val or '').strip()}" - entry = _MODEL_CACHE.get(cache_key) - if entry and isinstance(entry.get("ts"), float): - ts = datetime.fromtimestamp(entry["ts"]).strftime("%Y-%m-%d %H:%M:%S") - st.caption(f"最近拉取时间:{ts}") - except Exception: - pass - - fetch_key = f"fetch_models_{selected_provider}" - if st.button("获取模型列表", key=fetch_key): - with st.spinner("正在获取模型列表..."): - models, error = _discover_provider_models(provider_cfg, base_val, api_val) - if error: - st.error(error) - else: - provider_cfg.models = models - if models and (not provider_cfg.default_model or provider_cfg.default_model not in models): - provider_cfg.default_model = models[0] - providers[selected_provider] = provider_cfg - cfg.llm_providers = providers - cfg.sync_runtime_llm() - save_config() - st.success(f"共获取 {len(models)} 个模型。") - st.rerun() - - if st.button("保存 Provider", key=f"save_provider_{selected_provider}"): - provider_cfg.title = title_val.strip() - provider_cfg.base_url = base_val.strip() - provider_cfg.api_key = api_val.strip() or None - provider_cfg.enabled = enabled_val - provider_cfg.mode = mode_val - providers[selected_provider] = provider_cfg - cfg.llm_providers = providers - cfg.sync_runtime_llm() - save_config() - st.success("Provider 已保存。") - st.session_state[title_key] = provider_cfg.title or "" - - provider_in_use = (cfg.llm.primary.provider == selected_provider) or any( - ep.provider == selected_provider for ep in cfg.llm.ensemble - ) - if not provider_in_use: - for dept in cfg.departments.values(): - if dept.llm.primary.provider == selected_provider or any(ep.provider == selected_provider for ep in dept.llm.ensemble): - provider_in_use = True - break - if st.button( - "删除 Provider", - key=f"delete_provider_{selected_provider}", - disabled=provider_in_use or len(providers) <= 1, - ): - providers.pop(selected_provider, None) - cfg.llm_providers = providers - cfg.sync_runtime_llm() - save_config() - st.success("Provider 已删除。") - st.rerun() - - st.markdown("##### 全局推理配置") - if not provider_keys: - st.warning("请先配置至少一个 Provider。") - else: - global_cfg = cfg.llm - primary = global_cfg.primary - try: - provider_index = provider_keys.index(primary.provider or provider_keys[0]) - except ValueError: - provider_index = 0 - selected_global_provider = st.selectbox( - "主模型 Provider", - provider_keys, - index=provider_index, - key="global_provider_select", - ) - provider_cfg = providers.get(selected_global_provider) - available_models = provider_cfg.models if provider_cfg else [] - default_model = primary.model or (provider_cfg.default_model if provider_cfg else None) - if available_models: - options = available_models + ["自定义"] - try: - model_index = available_models.index(default_model) - model_choice = st.selectbox("主模型", options, index=model_index, key="global_model_choice") - except ValueError: - model_choice = st.selectbox("主模型", options, index=len(options) - 1, key="global_model_choice") - if model_choice == "自定义": - model_val = st.text_input("自定义模型", value=default_model or "", key="global_model_custom").strip() - else: - model_val = model_choice - else: - model_val = st.text_input("主模型", value=default_model or "", key="global_model_custom").strip() - - temp_default = primary.temperature if primary.temperature is not None else (provider_cfg.default_temperature if provider_cfg else 0.2) - temp_val = st.slider("主模型温度", min_value=0.0, max_value=2.0, value=float(temp_default), step=0.05, key="global_temp") - timeout_default = primary.timeout if primary.timeout is not None else (provider_cfg.default_timeout if provider_cfg else 30.0) - timeout_val = st.number_input("主模型超时(秒)", min_value=5, max_value=300, value=int(timeout_default), step=5, key="global_timeout") - prompt_template_val = st.text_area( - "主模型 Prompt 模板(可选)", - value=primary.prompt_template or provider_cfg.prompt_template if provider_cfg else "", - height=120, - key="global_prompt_template", - ) - - strategy_val = st.selectbox("推理策略", sorted(ALLOWED_LLM_STRATEGIES), index=sorted(ALLOWED_LLM_STRATEGIES).index(global_cfg.strategy) if global_cfg.strategy in ALLOWED_LLM_STRATEGIES else 0, key="global_strategy") - show_ensemble = strategy_val != "single" - majority_threshold_val = st.number_input( - "多数投票门槛", - min_value=1, - max_value=10, - value=int(global_cfg.majority_threshold), - step=1, - key="global_majority", - disabled=not show_ensemble, - ) - if not show_ensemble: - majority_threshold_val = 1 - - ensemble_rows: List[Dict[str, str]] = [] - if show_ensemble: - ensemble_rows = [ - { - "provider": ep.provider, - "model": ep.model or "", - "temperature": "" if ep.temperature is None else f"{ep.temperature:.3f}", - "timeout": "" if ep.timeout is None else str(int(ep.timeout)), - "prompt_template": ep.prompt_template or "", - } - for ep in global_cfg.ensemble - ] or [{"provider": primary.provider or selected_global_provider, "model": "", "temperature": "", "timeout": "", "prompt_template": ""}] - - ensemble_editor = st.data_editor( - ensemble_rows, - num_rows="dynamic", - key="global_ensemble_editor", - width='stretch', - hide_index=True, - column_config={ - "provider": st.column_config.SelectboxColumn("Provider", options=provider_keys), - "model": st.column_config.TextColumn("模型"), - "temperature": st.column_config.TextColumn("温度"), - "timeout": st.column_config.TextColumn("超时(秒)"), - "prompt_template": st.column_config.TextColumn("Prompt 模板"), - }, - ) - if hasattr(ensemble_editor, "to_dict"): - ensemble_rows = ensemble_editor.to_dict("records") - else: - ensemble_rows = ensemble_editor - else: - st.info("当前策略为单模型,未启用协作模型。") - - if st.button("保存全局配置", key="save_global_llm"): - primary.provider = selected_global_provider - primary.model = model_val or None - primary.temperature = float(temp_val) - primary.timeout = float(timeout_val) - primary.prompt_template = prompt_template_val.strip() or None - primary.base_url = None - primary.api_key = None - - new_ensemble: List[LLMEndpoint] = [] - if show_ensemble: - for row in ensemble_rows: - provider_val = (row.get("provider") or "").strip().lower() - if not provider_val: - continue - model_raw = (row.get("model") or "").strip() or None - temp_raw = (row.get("temperature") or "").strip() - timeout_raw = (row.get("timeout") or "").strip() - prompt_raw = (row.get("prompt_template") or "").strip() - new_ensemble.append( - LLMEndpoint( - provider=provider_val, - model=model_raw, - temperature=float(temp_raw) if temp_raw else None, - timeout=float(timeout_raw) if timeout_raw else None, - prompt_template=prompt_raw or None, - ) - ) - cfg.llm.ensemble = new_ensemble - cfg.llm.strategy = strategy_val - cfg.llm.majority_threshold = int(majority_threshold_val) - cfg.sync_runtime_llm() - save_config() - st.success("全局 LLM 配置已保存。") - st.json(llm_config_snapshot()) - - st.markdown("##### 部门配置") - dept_settings = cfg.departments or {} - dept_rows = [ - { - "code": code, - "title": dept.title, - "description": dept.description, - "weight": float(dept.weight), - "strategy": dept.llm.strategy, - "majority_threshold": dept.llm.majority_threshold, - "provider": dept.llm.primary.provider or (provider_keys[0] if provider_keys else ""), - "model": dept.llm.primary.model or "", - "temperature": "" if dept.llm.primary.temperature is None else f"{dept.llm.primary.temperature:.3f}", - "timeout": "" if dept.llm.primary.timeout is None else str(int(dept.llm.primary.timeout)), - "prompt_template": dept.llm.primary.prompt_template or "", - } - for code, dept in sorted(dept_settings.items()) - ] - - if not dept_rows: - st.info("当前未配置部门,可在 config.json 中添加。") - dept_rows = [] - - dept_editor = st.data_editor( - dept_rows, - num_rows="fixed", - key="department_editor", - width='stretch', - hide_index=True, - column_config={ - "code": st.column_config.TextColumn("编码", disabled=True), - "title": st.column_config.TextColumn("名称"), - "description": st.column_config.TextColumn("说明"), - "weight": st.column_config.NumberColumn("权重", min_value=0.0, max_value=10.0, step=0.1), - "strategy": st.column_config.SelectboxColumn("策略", options=sorted(ALLOWED_LLM_STRATEGIES)), - "majority_threshold": st.column_config.NumberColumn("投票阈值", min_value=1, max_value=10, step=1), - "provider": st.column_config.SelectboxColumn("Provider", options=provider_keys or [""]), - "model": st.column_config.TextColumn("模型"), - "temperature": st.column_config.TextColumn("温度"), - "timeout": st.column_config.TextColumn("超时(秒)"), - "prompt_template": st.column_config.TextColumn("Prompt 模板"), - }, - ) - - if hasattr(dept_editor, "to_dict"): - dept_rows = dept_editor.to_dict("records") - else: - dept_rows = dept_editor - - col_reset, col_save = st.columns([1, 1]) - - if col_save.button("保存部门配置"): - updated_departments: Dict[str, DepartmentSettings] = {} - for row in dept_rows: - code = row.get("code") - if not code: - continue - existing = dept_settings.get(code) or DepartmentSettings(code=code, title=code) - existing.title = row.get("title") or existing.title - existing.description = row.get("description") or "" - try: - existing.weight = max(0.0, float(row.get("weight", existing.weight))) - except (TypeError, ValueError): - pass - - strategy_val = (row.get("strategy") or existing.llm.strategy).lower() - if strategy_val in ALLOWED_LLM_STRATEGIES: - existing.llm.strategy = strategy_val - if existing.llm.strategy == "single": - existing.llm.majority_threshold = 1 - existing.llm.ensemble = [] - else: - majority_raw = row.get("majority_threshold") - try: - majority_val = int(majority_raw) - if majority_val > 0: - existing.llm.majority_threshold = majority_val - except (TypeError, ValueError): - pass - - provider_val = (row.get("provider") or existing.llm.primary.provider or (provider_keys[0] if provider_keys else "ollama")).strip().lower() - model_val = (row.get("model") or "").strip() or None - temp_raw = (row.get("temperature") or "").strip() - timeout_raw = (row.get("timeout") or "").strip() - prompt_raw = (row.get("prompt_template") or "").strip() - - endpoint = existing.llm.primary or LLMEndpoint() - endpoint.provider = provider_val - endpoint.model = model_val - endpoint.temperature = float(temp_raw) if temp_raw else None - endpoint.timeout = float(timeout_raw) if timeout_raw else None - endpoint.prompt_template = prompt_raw or None - endpoint.base_url = None - endpoint.api_key = None - existing.llm.primary = endpoint - if existing.llm.strategy != "single": - existing.llm.ensemble = [] - - updated_departments[code] = existing - - if updated_departments: - cfg.departments = updated_departments - cfg.sync_runtime_llm() - save_config() - st.success("部门配置已更新。") - else: - st.warning("未能解析部门配置输入。") - - if col_reset.button("恢复默认部门"): - from app.utils.config import _default_departments - - cfg.departments = _default_departments() - cfg.sync_runtime_llm() - save_config() - st.success("已恢复默认部门配置。") - st.rerun() - - st.caption("部门配置存储为独立 LLM 参数,执行时会自动套用对应 Provider 的连接信息。") - -def render_market_visualization() -> None: - """Render a standalone market data visualization dashboard.""" - - st.header("股票行情可视化") - options = _load_stock_options() - default_code = options[0] if options else "000001.SZ" - - if options: - selection = st.selectbox("选择股票", options, index=0) - ts_code = _parse_ts_code(selection) - LOGGER.debug("选择股票:%s", ts_code, extra=LOG_EXTRA) - else: - ts_code = st.text_input("输入股票代码(如 000001.SZ)", value=default_code).strip().upper() - LOGGER.debug("输入股票:%s", ts_code, extra=LOG_EXTRA) - - col_start, col_end = st.columns(2) - default_start = date.today() - timedelta(days=180) - start_date = col_start.date_input("开始日期", value=default_start, key="viz_start") - end_date = col_end.date_input("结束日期", value=date.today(), key="viz_end") - LOGGER.debug("行情可视化日期范围:%s-%s", start_date, end_date, extra=LOG_EXTRA) - - if start_date > end_date: - LOGGER.warning("无效日期范围:%s>%s", start_date, end_date, extra=LOG_EXTRA) - st.error("开始日期不能晚于结束日期") - return - - with st.spinner("正在加载行情数据..."): - try: - df = _load_daily_frame(ts_code, start_date, end_date) - except Exception as exc: # noqa: BLE001 - LOGGER.exception("加载行情数据失败", extra=LOG_EXTRA) - st.error(f"读取数据失败:{exc}") - return - - if df.empty: - LOGGER.warning("指定区间无行情数据:%s %s-%s", ts_code, start_date, end_date, extra=LOG_EXTRA) - st.warning("未查询到该区间的交易数据,请确认数据库已拉取对应日线。") - return - - price_df = df[["close"]].rename(columns={"close": "收盘价"}) - volume_df = df[["vol"]].rename(columns={"vol": "成交量(手)"}) - - sampled = price_df.resample("3D").last().dropna() if price_df.shape[0] > 180 else price_df - volume_sampled = volume_df.resample("3D").mean().dropna() if volume_df.shape[0] > 180 else volume_df - - first_close = sampled.iloc[0, 0] - last_close = sampled.iloc[-1, 0] - delta_abs = last_close - first_close - delta_pct = (delta_abs / first_close * 100) if first_close else 0.0 - - metric_col1, metric_col2, metric_col3 = st.columns(3) - metric_col1.metric("最新收盘价", f"{last_close:.2f}", delta=f"{delta_abs:+.2f}") - metric_col2.metric("区间涨跌幅", f"{delta_pct:+.2f}%") - metric_col3.metric("平均成交量", f"{volume_sampled['成交量(手)'].mean():.0f}") - - df_reset = df.reset_index().rename(columns={ - "trade_date": "交易日", - "open": "开盘价", - "high": "最高价", - "low": "最低价", - "close": "收盘价", - "vol": "成交量(手)", - "amount": "成交额(千元)", - }) - df_reset["成交额(千元)"] = df_reset["成交额(千元)"] / 1000 - - numeric_columns = ["开盘价", "最高价", "最低价", "收盘价", "成交量(手)", "成交额(千元)"] - for col in numeric_columns: - if col in df_reset.columns: - df_reset[col] = pd.to_numeric(df_reset[col], errors="coerce") - df_reset["交易日"] = pd.to_datetime(df_reset["交易日"]) - - candle_fig = go.Figure( - data=[ - go.Candlestick( - x=df_reset["交易日"], - open=df_reset["开盘价"], - high=df_reset["最高价"], - low=df_reset["最低价"], - close=df_reset["收盘价"], - name="K线", - ) - ] - ) - candle_fig.update_layout(height=420, margin=dict(l=10, r=10, t=40, b=10)) - st.plotly_chart(candle_fig, width="stretch") - - vol_fig = px.bar( - df_reset, - x="交易日", - y="成交量(手)", - labels={"成交量(手)": "成交量(手)"}, - title="成交量", - ) - vol_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10)) - st.plotly_chart(vol_fig, width="stretch") - - amt_fig = px.bar( - df_reset, - x="交易日", - y="成交额(千元)", - labels={"成交额(千元)": "成交额(千元)"}, - title="成交额", - ) - amt_fig.update_layout(height=280, margin=dict(l=10, r=10, t=40, b=10)) - st.plotly_chart(amt_fig, width="stretch") - - df_reset["月份"] = df_reset["交易日"].dt.to_period("M").astype(str) - df_reset["收盘价"] = pd.to_numeric(df_reset["收盘价"], errors="coerce") - box_fig = px.box( - df_reset, - x="月份", - y="收盘价", - points="outliers", - title="月度收盘价分布", - ) - box_fig.update_layout(height=320, margin=dict(l=10, r=10, t=40, b=10)) - st.plotly_chart(box_fig, width="stretch") - - st.caption("提示:成交量单位为手,成交额以千元显示。箱线图按月展示收盘价分布。") - st.dataframe(df_reset.tail(20), width="stretch") - LOGGER.info("行情可视化完成,展示行数=%s", len(df_reset), extra=LOG_EXTRA) - - -def render_tests() -> None: - LOGGER.info("渲染自检页面", extra=LOG_EXTRA) - st.header("自检测试") - st.write("用于快速检查数据库与数据拉取是否正常工作。") - - if st.button("测试数据库初始化"): - LOGGER.info("点击测试数据库初始化按钮", extra=LOG_EXTRA) - with st.spinner("正在检查数据库..."): - result = initialize_database() - if result.skipped: - LOGGER.info("数据库已存在,无需初始化", extra=LOG_EXTRA) - st.success("数据库已存在,检查通过。") - else: - LOGGER.info("数据库初始化完成,执行语句数=%s", result.executed, extra=LOG_EXTRA) - st.success(f"数据库初始化完成,共执行 {result.executed} 条语句。") - - st.divider() - - if st.button("测试 TuShare 拉取(示例 2024-01-01 至 2024-01-03)"): - LOGGER.info("点击示例 TuShare 拉取按钮", extra=LOG_EXTRA) - with st.spinner("正在调用 TuShare 接口..."): - try: - run_ingestion( - FetchJob( - name="streamlit_self_test", - start=date(2024, 1, 1), - end=date(2024, 1, 3), - ts_codes=("000001.SZ",), - ), - include_limits=False, - ) - LOGGER.info("示例 TuShare 拉取成功", extra=LOG_EXTRA) - st.success("TuShare 示例拉取完成,数据已写入数据库。") - except Exception as exc: # noqa: BLE001 - LOGGER.exception("示例 TuShare 拉取失败", extra=LOG_EXTRA) - st.error(f"拉取失败:{exc}") - alerts.add_warning("TuShare", "示例拉取失败", str(exc)) - _update_dashboard_sidebar() - - st.info("注意:TuShare 拉取依赖网络与 Token,若环境未配置将出现错误提示。") - - st.divider() - - st.subheader("RSS 数据测试") - st.write("用于验证 RSS 配置是否能够正常抓取新闻并写入数据库。") - rss_url = st.text_input( - "测试 RSS 地址", - value="https://rsshub.app/cls/depth/1000", - help="留空则使用默认配置的全部 RSS 来源。", - ).strip() - rss_hours = int( - st.number_input( - "回溯窗口(小时)", - min_value=1, - max_value=168, - value=24, - step=6, - help="仅抓取最近指定小时内的新闻。", - ) - ) - rss_limit = int( - st.number_input( - "单源抓取条数", - min_value=1, - max_value=200, - value=50, - step=10, - ) - ) - if st.button("运行 RSS 测试"): - from app.ingest import rss as rss_ingest - - LOGGER.info( - "点击 RSS 测试按钮 rss_url=%s hours=%s limit=%s", - rss_url, - rss_hours, - rss_limit, - extra=LOG_EXTRA, - ) - with st.spinner("正在抓取 RSS 新闻..."): - try: - if rss_url: - items = rss_ingest.fetch_rss_feed( - rss_url, - hours_back=rss_hours, - max_items=rss_limit, - ) - count = rss_ingest.save_news_items(items) - else: - count = rss_ingest.ingest_configured_rss( - hours_back=rss_hours, - max_items_per_feed=rss_limit, - ) - st.success(f"RSS 测试完成,新增 {count} 条新闻记录。") - except Exception as exc: # noqa: BLE001 - LOGGER.exception("RSS 测试失败", extra=LOG_EXTRA) - st.error(f"RSS 测试失败:{exc}") - alerts.add_warning("RSS", "RSS 测试执行失败", str(exc)) - _update_dashboard_sidebar() - - st.divider() - days = int( - st.number_input( - "检查窗口(天数)", - min_value=30, - max_value=10950, - value=365, - step=30, - ) - ) - LOGGER.debug("检查窗口天数=%s", days, extra=LOG_EXTRA) - cfg = get_config() - force_refresh = st.checkbox( - "强制刷新数据(关闭增量跳过)", - value=cfg.force_refresh, - help="勾选后将重新拉取所选区间全部数据", - ) - if force_refresh != cfg.force_refresh: - cfg.force_refresh = force_refresh - LOGGER.info("更新 force_refresh=%s", force_refresh, extra=LOG_EXTRA) - save_config() - - if st.button("执行手动数据同步"): - LOGGER.info("点击执行手动数据同步按钮", extra=LOG_EXTRA) - progress_bar = st.progress(0.0) - status_placeholder = st.empty() - log_placeholder = st.empty() - messages: list[str] = [] - - def hook(message: str, value: float) -> None: - progress_bar.progress(min(max(value, 0.0), 1.0)) - status_placeholder.write(message) - messages.append(message) - LOGGER.debug("手动数据同步进度:%s -> %.2f", message, value, extra=LOG_EXTRA) - - with st.spinner("正在执行手动数据同步..."): - try: - report = run_boot_check( - days=days, - progress_hook=hook, - force_refresh=force_refresh, - ) - LOGGER.info("手动数据同步成功", extra=LOG_EXTRA) - st.success("手动数据同步完成,以下为数据覆盖摘要。") - st.json(report.to_dict()) - if messages: - log_placeholder.markdown("\n".join(f"- {msg}" for msg in messages)) - except Exception as exc: # noqa: BLE001 - LOGGER.exception("手动数据同步失败", extra=LOG_EXTRA) - st.error(f"手动数据同步失败:{exc}") - alerts.add_warning("数据同步", "手动数据同步失败", str(exc)) - _update_dashboard_sidebar() - if messages: - log_placeholder.markdown("\n".join(f"- {msg}" for msg in messages)) - finally: - progress_bar.progress(1.0) - - st.divider() - st.subheader("LLM 接口测试") - st.json(llm_config_snapshot()) - llm_prompt = st.text_area("测试 Prompt", value="请概述今天的市场重点。", height=160) - system_prompt = st.text_area( - "System Prompt (可选)", - value="你是一名量化策略研究助手,用简洁中文回答。", - height=100, - ) - if st.button("执行 LLM 测试"): - with st.spinner("正在调用 LLM..."): - try: - response = run_llm(llm_prompt, system=system_prompt or None) - except Exception as exc: # noqa: BLE001 - LOGGER.exception("LLM 测试失败", extra=LOG_EXTRA) - st.error(f"LLM 调用失败:{exc}") - else: - LOGGER.info("LLM 测试成功", extra=LOG_EXTRA) - st.success("LLM 调用成功,以下为返回内容:") - st.write(response) - - -def render_log_viewer() -> None: - """渲染日志钻取与历史对比视图页面。""" - LOGGER.info("渲染日志视图页面", extra=LOG_EXTRA) - st.header("日志钻取与历史对比") - st.write("查看系统运行日志,支持时间范围筛选、关键词搜索和历史对比功能。") - - col1, col2 = st.columns(2) - with col1: - start_date = st.date_input("开始日期", value=date.today() - timedelta(days=7)) - with col2: - end_date = st.date_input("结束日期", value=date.today()) - - log_levels = ["ALL", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] - selected_level = st.selectbox("日志级别", log_levels, index=1) - - search_query = st.text_input("搜索关键词") - - with db_session(read_only=True) as conn: - stages = [row["stage"] for row in conn.execute("SELECT DISTINCT stage FROM run_log").fetchall()] - stages = [s for s in stages if s] - stages.insert(0, "ALL") - selected_stage = st.selectbox("执行阶段", stages) - - with st.spinner("加载日志数据中..."): - try: - with db_session(read_only=True) as conn: - query_parts = ["SELECT ts, stage, level, msg FROM run_log WHERE 1=1"] - params: list[object] = [] - - start_ts = f"{start_date.isoformat()}T00:00:00Z" - end_ts = f"{end_date.isoformat()}T23:59:59Z" - query_parts.append("AND ts BETWEEN ? AND ?") - params.extend([start_ts, end_ts]) - - if selected_level != "ALL": - query_parts.append("AND level = ?") - params.append(selected_level) - - if search_query: - query_parts.append("AND msg LIKE ?") - params.append(f"%{search_query}%") - - if selected_stage != "ALL": - query_parts.append("AND stage = ?") - params.append(selected_stage) - - query_parts.append("ORDER BY ts DESC") - - query = " ".join(query_parts) - rows = conn.execute(query, params).fetchall() - - if rows: - rows_dict = [{key: row[key] for key in row.keys()} for row in rows] - log_df = pd.DataFrame(rows_dict) - log_df["ts"] = pd.to_datetime(log_df["ts"]).dt.strftime("%Y-%m-%d %H:%M:%S") - for col in log_df.columns: - log_df[col] = log_df[col].astype(str) - else: - log_df = pd.DataFrame(columns=["ts", "stage", "level", "msg"]) - - st.dataframe( - log_df, - hide_index=True, - width="stretch", - column_config={ - "ts": st.column_config.TextColumn("时间"), - "stage": st.column_config.TextColumn("执行阶段"), - "level": st.column_config.TextColumn("日志级别"), - "msg": st.column_config.TextColumn("日志消息", width="large"), - }, - ) - - if not log_df.empty: - csv_data = log_df.to_csv(index=False).encode("utf-8") - st.download_button( - label="下载日志CSV", - data=csv_data, - file_name=f"logs_{start_date}_{end_date}.csv", - mime="text/csv", - key="download_logs", - ) - - json_data = log_df.to_json(orient="records", force_ascii=False, indent=2) - st.download_button( - label="下载日志JSON", - data=json_data, - file_name=f"logs_{start_date}_{end_date}.json", - mime="application/json", - key="download_logs_json", - ) - except Exception as exc: # noqa: BLE001 - LOGGER.exception("加载日志失败", extra=LOG_EXTRA) - st.error(f"加载日志数据失败:{exc}") - - st.subheader("历史对比") - st.write("选择两个时间点的日志进行对比分析。") - - col3, col4 = st.columns(2) - with col3: - compare_date1 = st.date_input("对比日期1", value=date.today() - timedelta(days=1)) - with col4: - compare_date2 = st.date_input("对比日期2", value=date.today()) - - comparison_stage = st.selectbox("对比阶段", stages, key="compare_stage") - st.write("选择需要比较的日志数量。") - compare_limit = st.slider("对比日志数量", min_value=10, max_value=200, value=50, step=10) - - if st.button("生成历史对比报告"): - with st.spinner("生成对比报告中..."): - try: - with db_session(read_only=True) as conn: - def load_logs(d: date) -> pd.DataFrame: - start_ts = f"{d.isoformat()}T00:00:00Z" - end_ts = f"{d.isoformat()}T23:59:59Z" - query = ["SELECT ts, level, msg FROM run_log WHERE ts BETWEEN ? AND ?"] - params: list[object] = [start_ts, end_ts] - if comparison_stage != "ALL": - query.append("AND stage = ?") - params.append(comparison_stage) - query.append("ORDER BY ts DESC LIMIT ?") - params.append(compare_limit) - sql = " ".join(query) - rows = conn.execute(sql, params).fetchall() - if not rows: - return pd.DataFrame(columns=["ts", "level", "msg"]) - df = pd.DataFrame([{k: row[k] for k in row.keys()} for row in rows]) - df["ts"] = pd.to_datetime(df["ts"]).dt.strftime("%Y-%m-%d %H:%M:%S") - return df - - df1 = load_logs(compare_date1) - df2 = load_logs(compare_date2) - - if df1.empty and df2.empty: - st.info("选定日期暂无日志可对比。") - else: - st.write("### 对比结果") - col_a, col_b = st.columns(2) - with col_a: - st.write(f"{compare_date1} 日日志") - st.dataframe(df1, hide_index=True, width="stretch") - with col_b: - st.write(f"{compare_date2} 日日志") - st.dataframe(df2, hide_index=True, width="stretch") - - summary = { - "日期1日志条数": int(len(df1)), - "日期2日志条数": int(len(df2)), - "新增日志条数": max(len(df2) - len(df1), 0), - } - st.write("摘要:", summary) - except Exception as exc: # noqa: BLE001 - LOGGER.exception("历史对比生成失败", extra=LOG_EXTRA) - st.error(f"生成历史对比失败:{exc}") - - -def render_data_settings() -> None: - """渲染数据源配置界面.""" - st.subheader("Tushare 数据源") - cfg = get_config() - - col1, col2 = st.columns(2) - with col1: - tushare_token = st.text_input( - "Tushare Token", - value=cfg.tushare_token or "", - type="password", - help="从 tushare.pro 获取的 API token" - ) - - with col2: - auto_update = st.checkbox( - "启动时自动更新数据", - value=cfg.auto_update_data, - help="启动应用时自动检查并更新数据" - ) - - update_interval = st.slider( - "数据更新间隔(天)", - min_value=1, - max_value=30, - value=cfg.data_update_interval, - help="自动更新时检查的数据时间范围" - ) - - if st.button("保存数据源配置"): - cfg.tushare_token = tushare_token - cfg.auto_update_data = auto_update - cfg.data_update_interval = update_interval - save_config(cfg) - st.success("数据源配置已更新!") - - st.divider() - st.subheader("数据更新记录") - - with db_session() as session: - df = pd.read_sql_query( - """ - SELECT job_type, status, created_at, updated_at, error_msg - FROM fetch_jobs - ORDER BY created_at DESC - LIMIT 50 - """, - session - ) - - if not df.empty: - df["duration"] = (df["updated_at"] - df["created_at"]).dt.total_seconds().round(2) - df = df.drop(columns=["updated_at"]) - df = df.rename(columns={ - "job_type": "数据类型", - "status": "状态", - "created_at": "开始时间", - "error_msg": "错误信息", - "duration": "耗时(秒)" - }) - st.dataframe(df, width='stretch') - else: - st.info("暂无数据更新记录") +from app.utils.config import get_config def main() -> None: LOGGER.info("初始化 Streamlit UI", extra=LOG_EXTRA) st.set_page_config(page_title="多智能体个人投资助理", layout="wide") - - # 确保数据库表已创建 - from app.data.schema import initialize_database + initialize_database() - - # 检查是否需要自动更新数据 + cfg = get_config() if cfg.auto_update_data: LOGGER.info("检测到自动更新数据选项已启用,开始执行数据拉取", extra=LOG_EXTRA) try: - # 初始化数据库 - from app.data.schema import initialize_database - initialize_database() - - # 执行开机检查(包含数据拉取) - from app.ingest.checker import run_boot_check with st.spinner("正在自动更新数据..."): def progress_hook(message: str, progress: float) -> None: st.write(f"📊 {message} ({progress:.1%})") - + report = run_boot_check( - days=30, # 最近30天 + days=30, auto_fetch=True, progress_hook=progress_hook, - force_refresh=False + force_refresh=False, ) - - # 执行RSS新闻拉取 - from app.ingest.rss import ingest_configured_rss rss_count = ingest_configured_rss(hours_back=24, max_items_per_feed=50) - - LOGGER.info("自动数据更新完成:日线数据覆盖%s-%s,RSS新闻%s条", - report.start, report.end, rss_count, extra=LOG_EXTRA) + LOGGER.info( + "自动数据更新完成:日线数据覆盖%s-%s,RSS新闻%s条", + report.start, + report.end, + rss_count, + extra=LOG_EXTRA, + ) st.success(f"✅ 自动数据更新完成:获取RSS新闻 {rss_count} 条") - - except Exception as exc: + except Exception as exc: # noqa: BLE001 LOGGER.exception("自动数据更新失败", extra=LOG_EXTRA) st.error(f"❌ 自动数据更新失败:{exc}") - + render_global_dashboard() + tabs = st.tabs(["今日计划", "投资池/仓位", "回测与复盘", "行情可视化", "日志钻取", "数据与设置", "自检测试"]) LOGGER.debug( "Tabs 初始化完成:%s", - ["今日计划", "回测与复盘", "行情可视化", "日志钻取", "数据与设置", "自检测试"], + ["今日计划", "投资池/仓位", "回测与复盘", "行情可视化", "日志钻取", "数据与设置", "自检测试"], extra=LOG_EXTRA, ) + with tabs[0]: render_today_plan() with tabs[1]: @@ -2998,17 +85,12 @@ def main() -> None: with tabs[5]: st.header("系统设置") settings_tabs = st.tabs(["配置概览", "LLM 设置", "投资组合", "数据源"]) - with settings_tabs[0]: render_config_overview() - with settings_tabs[1]: render_llm_settings() - with settings_tabs[2]: - from app.ui.portfolio_config import render_portfolio_config render_portfolio_config() - with settings_tabs[3]: render_data_settings() with tabs[6]: diff --git a/app/ui/views/__init__.py b/app/ui/views/__init__.py new file mode 100644 index 0000000..e4f945a --- /dev/null +++ b/app/ui/views/__init__.py @@ -0,0 +1,24 @@ +"""View modules for Streamlit UI tabs.""" + +from .today import render_today_plan +from .pool import render_pool_overview +from .backtest import render_backtest_review +from .market import render_market_visualization +from .logs import render_log_viewer +from .settings import render_config_overview, render_llm_settings, render_data_settings +from .tests import render_tests +from .dashboard import render_global_dashboard, update_dashboard_sidebar + +__all__ = [ + "render_today_plan", + "render_pool_overview", + "render_backtest_review", + "render_market_visualization", + "render_log_viewer", + "render_config_overview", + "render_llm_settings", + "render_data_settings", + "render_tests", + "render_global_dashboard", + "update_dashboard_sidebar", +] diff --git a/app/ui/views/backtest.py b/app/ui/views/backtest.py new file mode 100644 index 0000000..6f8513a --- /dev/null +++ b/app/ui/views/backtest.py @@ -0,0 +1,790 @@ +"""回测与复盘相关视图。""" +from __future__ import annotations + +import json +import uuid +from dataclasses import asdict +from datetime import date +from typing import Dict, List, Optional + +import numpy as np +import pandas as pd +import plotly.express as px +import requests +from requests.exceptions import RequestException +import streamlit as st + +from app.agents.base import AgentContext +from app.agents.game import Decision +from app.agents.registry import default_agents +from app.backtest.decision_env import DecisionEnv, ParameterSpec +from app.backtest.engine import BacktestEngine, PortfolioState, BtConfig, run_backtest +from app.ingest.checker import run_boot_check +from app.ingest.tushare import run_ingestion +from app.llm.client import run_llm +from app.llm.metrics import reset as reset_llm_metrics +from app.llm.metrics import snapshot as snapshot_llm_metrics +from app.utils import alerts +from app.utils.config import get_config, save_config +from app.utils.tuning import log_tuning_result + +from app.ui.shared import LOGGER, LOG_EXTRA, default_backtest_range +from app.ui.views.dashboard import update_dashboard_sidebar + +_DECISION_ENV_SINGLE_RESULT_KEY = "decision_env_single_result" +_DECISION_ENV_BATCH_RESULTS_KEY = "decision_env_batch_results" + +def render_backtest_review() -> None: + """渲染回测执行、调参与结果复盘页面。""" + st.header("回测与复盘") + st.caption("1. 基于历史数据复盘当前策略;2. 借助强化学习/调参探索更优参数组合。") + app_cfg = get_config() + default_start, default_end = default_backtest_range(window_days=60) + LOGGER.debug( + "回测默认参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s", + default_start, + default_end, + "000001.SZ", + 0.035, + -0.015, + 10, + extra=LOG_EXTRA, + ) + + st.markdown("### 回测参数") + col1, col2 = st.columns(2) + start_date = col1.date_input("开始日期", value=default_start, key="bt_start_date") + end_date = col2.date_input("结束日期", value=default_end, key="bt_end_date") + universe_text = st.text_input("股票列表(逗号分隔)", value="000001.SZ", key="bt_universe") + col_target, col_stop, col_hold = st.columns(3) + target = col_target.number_input("目标收益(例:0.035 表示 3.5%)", value=0.035, step=0.005, format="%.3f", key="bt_target") + stop = col_stop.number_input("止损收益(例:-0.015 表示 -1.5%)", value=-0.015, step=0.005, format="%.3f", key="bt_stop") + hold_days = col_hold.number_input("持有期(交易日)", value=10, step=1, key="bt_hold_days") + LOGGER.debug( + "当前回测表单输入:start=%s end=%s universe_text=%s target=%.3f stop=%.3f hold_days=%s", + start_date, + end_date, + universe_text, + target, + stop, + hold_days, + extra=LOG_EXTRA, + ) + + tab_backtest, tab_rl = st.tabs(["回测验证", "强化学习调参"]) + + with tab_backtest: + st.markdown("#### 回测执行") + if st.button("运行回测", key="bt_run_button"): + LOGGER.info("用户点击运行回测按钮", extra=LOG_EXTRA) + decision_log_container = st.container() + status_box = st.status("准备执行回测...", expanded=True) + llm_stats_placeholder = st.empty() + decision_entries: List[str] = [] + + def _decision_callback(ts_code: str, trade_dt: date, ctx: AgentContext, decision: Decision) -> None: + ts_label = trade_dt.isoformat() + summary = "" + for dept_decision in decision.department_decisions.values(): + if getattr(dept_decision, "summary", ""): + summary = str(dept_decision.summary) + break + entry_lines = [ + f"**{ts_label} {ts_code}** → {decision.action.value} (信心 {decision.confidence:.2f})", + ] + if summary: + entry_lines.append(f"摘要:{summary}") + dep_highlights = [] + for dept_code, dept_decision in decision.department_decisions.items(): + dep_highlights.append( + f"{dept_code}:{dept_decision.action.value}({dept_decision.confidence:.2f})" + ) + if dep_highlights: + entry_lines.append("部门意见:" + ";".join(dep_highlights)) + decision_entries.append(" \n".join(entry_lines)) + decision_log_container.markdown("\n\n".join(decision_entries[-200:])) + status_box.write(f"{ts_label} {ts_code} → {decision.action.value} (信心 {decision.confidence:.2f})") + stats = snapshot_llm_metrics() + llm_stats_placeholder.json( + { + "LLM 调用次数": stats.get("total_calls", 0), + "Prompt Tokens": stats.get("total_prompt_tokens", 0), + "Completion Tokens": stats.get("total_completion_tokens", 0), + "按 Provider": stats.get("provider_calls", {}), + "按模型": stats.get("model_calls", {}), + } + ) + update_dashboard_sidebar(stats) + + reset_llm_metrics() + status_box.update(label="执行回测中...", state="running") + try: + universe = [code.strip() for code in universe_text.split(',') if code.strip()] + LOGGER.info( + "回测参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s", + start_date, + end_date, + universe, + target, + stop, + hold_days, + extra=LOG_EXTRA, + ) + backtest_cfg = BtConfig( + id="streamlit_demo", + name="Streamlit Demo Strategy", + start_date=start_date, + end_date=end_date, + universe=universe, + params={ + "target": target, + "stop": stop, + "hold_days": int(hold_days), + }, + ) + result = run_backtest(backtest_cfg, decision_callback=_decision_callback) + LOGGER.info( + "回测完成:nav_records=%s trades=%s", + len(result.nav_series), + len(result.trades), + extra=LOG_EXTRA, + ) + status_box.update(label="回测执行完成", state="complete") + st.success("回测执行完成,详见下方结果与统计。") + metrics = snapshot_llm_metrics() + llm_stats_placeholder.json( + { + "LLM 调用次数": metrics.get("total_calls", 0), + "Prompt Tokens": metrics.get("total_prompt_tokens", 0), + "Completion Tokens": metrics.get("total_completion_tokens", 0), + "按 Provider": metrics.get("provider_calls", {}), + "按模型": metrics.get("model_calls", {}), + } + ) + update_dashboard_sidebar(metrics) + st.session_state["backtest_last_result"] = {"nav_records": result.nav_series, "trades": result.trades} + st.json(st.session_state["backtest_last_result"]) + except Exception as exc: # noqa: BLE001 + LOGGER.exception("回测执行失败", extra=LOG_EXTRA) + status_box.update(label="回测执行失败", state="error") + st.error(f"回测执行失败:{exc}") + + last_result = st.session_state.get("backtest_last_result") + if last_result: + st.markdown("#### 最近回测输出") + st.json(last_result) + + st.divider() + # ADD: Comparison view for multiple backtest configurations + with st.expander("历史回测结果对比", expanded=False): + st.caption("从历史回测配置中选择多个进行净值曲线与指标对比。") + normalize_to_one = st.checkbox("归一化到 1 起点", value=True, key="bt_cmp_normalize") + use_log_y = st.checkbox("对数坐标", value=False, key="bt_cmp_log_y") + metric_options = ["总收益", "最大回撤", "交易数", "平均换手", "风险事件"] + selected_metrics = st.multiselect("显示指标列", metric_options, default=metric_options, key="bt_cmp_metrics") + try: + with db_session(read_only=True) as conn: + cfg_rows = conn.execute( + "SELECT id, name FROM bt_config ORDER BY rowid DESC LIMIT 50" + ).fetchall() + except Exception: # noqa: BLE001 + LOGGER.exception("读取 bt_config 失败", extra=LOG_EXTRA) + cfg_rows = [] + cfg_options = [f"{row['id']} | {row['name']}" for row in cfg_rows] + selected_labels = st.multiselect("选择配置", cfg_options, default=cfg_options[:2], key="bt_cmp_configs") + selected_ids = [label.split(" | ")[0].strip() for label in selected_labels] + nav_df = pd.DataFrame() + rpt_df = pd.DataFrame() + if selected_ids: + try: + with db_session(read_only=True) as conn: + nav_df = pd.read_sql_query( + "SELECT cfg_id, trade_date, nav FROM bt_nav WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))), + conn, + params=tuple(selected_ids), + ) + rpt_df = pd.read_sql_query( + "SELECT cfg_id, summary FROM bt_report WHERE cfg_id IN (%s)" % (",".join(["?"]*len(selected_ids))), + conn, + params=tuple(selected_ids), + ) + except Exception: # noqa: BLE001 + LOGGER.exception("读取回测结果失败", extra=LOG_EXTRA) + st.error("读取回测结果失败") + nav_df = pd.DataFrame() + rpt_df = pd.DataFrame() + if not nav_df.empty: + try: + nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"], errors="coerce") + # ADD: date window filter + overall_min = pd.to_datetime(nav_df["trade_date"].min()).date() + overall_max = pd.to_datetime(nav_df["trade_date"].max()).date() + col_d1, col_d2 = st.columns(2) + start_filter = col_d1.date_input("起始日期", value=overall_min, key="bt_cmp_start") + end_filter = col_d2.date_input("结束日期", value=overall_max, key="bt_cmp_end") + if start_filter > end_filter: + start_filter, end_filter = end_filter, start_filter + mask = (nav_df["trade_date"].dt.date >= start_filter) & (nav_df["trade_date"].dt.date <= end_filter) + nav_df = nav_df.loc[mask] + pivot = nav_df.pivot_table(index="trade_date", columns="cfg_id", values="nav") + if normalize_to_one: + pivot = pivot.apply(lambda s: s / s.dropna().iloc[0] if s.dropna().size else s) + import plotly.graph_objects as go + fig = go.Figure() + for col in pivot.columns: + fig.add_trace(go.Scatter(x=pivot.index, y=pivot[col], mode="lines", name=str(col))) + fig.update_layout(height=300, margin=dict(l=10, r=10, t=30, b=10)) + if use_log_y: + fig.update_yaxes(type="log") + st.plotly_chart(fig, width='stretch') + # ADD: export pivot + try: + csv_buf = pivot.reset_index() + csv_buf.columns = ["trade_date"] + [str(c) for c in pivot.columns] + st.download_button( + "下载曲线(CSV)", + data=csv_buf.to_csv(index=False), + file_name="bt_nav_compare.csv", + mime="text/csv", + key="dl_nav_compare", + ) + except Exception: + pass + except Exception: # noqa: BLE001 + LOGGER.debug("绘制对比曲线失败", extra=LOG_EXTRA) + if not rpt_df.empty: + try: + metrics_rows: List[Dict[str, object]] = [] + for _, row in rpt_df.iterrows(): + cfg_id = row["cfg_id"] + try: + summary = json.loads(row["summary"]) if isinstance(row["summary"], str) else (row["summary"] or {}) + except json.JSONDecodeError: + summary = {} + record = { + "cfg_id": cfg_id, + "总收益": summary.get("total_return"), + "最大回撤": summary.get("max_drawdown"), + "交易数": summary.get("trade_count"), + "平均换手": summary.get("avg_turnover"), + "风险事件": summary.get("risk_events"), + } + metrics_rows.append({k: v for k, v in record.items() if (k == "cfg_id" or k in selected_metrics)}) + if metrics_rows: + dfm = pd.DataFrame(metrics_rows) + st.dataframe(dfm, hide_index=True, width='stretch') + try: + st.download_button( + "下载指标(CSV)", + data=dfm.to_csv(index=False), + file_name="bt_metrics_compare.csv", + mime="text/csv", + key="dl_metrics_compare", + ) + except Exception: + pass + except Exception: # noqa: BLE001 + LOGGER.debug("渲染指标表失败", extra=LOG_EXTRA) + else: + st.info("请选择至少一个配置进行对比。") + + + + with tab_rl: + st.caption("使用 DecisionEnv 对代理权重进行强化学习调参,支持单次与批量实验。") + with st.expander("离线调参实验 (DecisionEnv)", expanded=False): + st.caption( + "使用 DecisionEnv 对代理权重做离线调参。请选择需要优化的代理并设定权重范围," + "系统会运行一次回测并返回收益、回撤等指标。若 LLM 网络不可用,将返回失败标记。" + ) + + disable_departments = st.checkbox( + "禁用部门 LLM(仅规则代理,适合离线快速评估)", + value=True, + help="关闭部门调用后不依赖外部 LLM 网络,仅根据规则代理权重模拟。", + ) + + default_experiment_id = f"streamlit_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + experiment_id = st.text_input( + "实验 ID", + value=default_experiment_id, + help="用于在 tuning_results 表中区分不同实验。", + ) + strategy_label = st.text_input( + "策略说明", + value="DecisionEnv", + help="可选:为本次调参记录一个策略名称或备注。", + ) + + agent_objects = default_agents() + agent_names = [agent.name for agent in agent_objects] + if not agent_names: + st.info("暂无可调整的代理。") + else: + selected_agents = st.multiselect( + "选择调参的代理权重", + agent_names, + default=agent_names[:2], + key="decision_env_agents", + ) + + specs: List[ParameterSpec] = [] + action_values: List[float] = [] + range_valid = True + for idx, agent_name in enumerate(selected_agents): + col_min, col_max, col_action = st.columns([1, 1, 2]) + min_key = f"decision_env_min_{agent_name}" + max_key = f"decision_env_max_{agent_name}" + action_key = f"decision_env_action_{agent_name}" + default_min = 0.0 + default_max = 1.0 + min_val = col_min.number_input( + f"{agent_name} 最小权重", + min_value=0.0, + max_value=1.0, + value=default_min, + step=0.05, + key=min_key, + ) + max_val = col_max.number_input( + f"{agent_name} 最大权重", + min_value=0.0, + max_value=1.0, + value=default_max, + step=0.05, + key=max_key, + ) + if max_val <= min_val: + range_valid = False + action_val = col_action.slider( + f"{agent_name} 动作 (0-1)", + min_value=0.0, + max_value=1.0, + value=0.5, + step=0.01, + key=action_key, + ) + specs.append( + ParameterSpec( + name=f"weight_{agent_name}", + target=f"agent_weights.{agent_name}", + minimum=min_val, + maximum=max_val, + ) + ) + action_values.append(action_val) + + run_decision_env = st.button("执行单次调参", key="run_decision_env_button") + just_finished_single = False + if run_decision_env: + if not selected_agents: + st.warning("请至少选择一个代理进行调参。") + elif not range_valid: + st.error("请确保所有代理的最大权重大于最小权重。") + else: + LOGGER.info( + "离线调参(单次)按钮点击,已选择代理=%s 动作=%s disable_departments=%s", + selected_agents, + action_values, + disable_departments, + extra=LOG_EXTRA, + ) + baseline_weights = cfg.agent_weights.as_dict() + for agent in agent_objects: + baseline_weights.setdefault(agent.name, 1.0) + + universe_env = [code.strip() for code in universe_text.split(',') if code.strip()] + if not universe_env: + st.error("请先指定至少一个股票代码。") + else: + bt_cfg_env = BtConfig( + id="decision_env_streamlit", + name="DecisionEnv Streamlit", + start_date=start_date, + end_date=end_date, + universe=universe_env, + params={ + "target": target, + "stop": stop, + "hold_days": int(hold_days), + }, + method=cfg.decision_method, + ) + env = DecisionEnv( + bt_config=bt_cfg_env, + parameter_specs=specs, + baseline_weights=baseline_weights, + disable_departments=disable_departments, + ) + env.reset() + LOGGER.debug( + "离线调参(单次)启动 DecisionEnv:cfg=%s 参数维度=%s", + bt_cfg_env, + len(specs), + extra=LOG_EXTRA, + ) + with st.spinner("正在执行离线调参……"): + try: + observation, reward, done, info = env.step(action_values) + LOGGER.info( + "离线调参(单次)完成,obs=%s reward=%.4f done=%s", + observation, + reward, + done, + extra=LOG_EXTRA, + ) + except Exception as exc: # noqa: BLE001 + LOGGER.exception("DecisionEnv 调用失败", extra=LOG_EXTRA) + st.error(f"离线调参失败:{exc}") + st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None) + else: + if observation.get("failure"): + st.error("调参失败:回测执行未完成,可能是 LLM 网络不可用或参数异常。") + st.json(observation) + st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None) + else: + resolved_experiment_id = experiment_id or str(uuid.uuid4()) + resolved_strategy = strategy_label or "DecisionEnv" + action_payload = { + name: value + for name, value in zip(selected_agents, action_values) + } + metrics_payload = dict(observation) + metrics_payload["reward"] = reward + log_success = False + try: + log_tuning_result( + experiment_id=resolved_experiment_id, + strategy=resolved_strategy, + action=action_payload, + reward=reward, + metrics=metrics_payload, + weights=info.get("weights", {}), + ) + except Exception: # noqa: BLE001 + LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA) + else: + log_success = True + LOGGER.info( + "离线调参(单次)日志写入成功:experiment=%s strategy=%s", + resolved_experiment_id, + resolved_strategy, + extra=LOG_EXTRA, + ) + st.session_state[_DECISION_ENV_SINGLE_RESULT_KEY] = { + "observation": dict(observation), + "reward": float(reward), + "weights": info.get("weights", {}), + "nav_series": info.get("nav_series"), + "trades": info.get("trades"), + "portfolio_snapshots": info.get("portfolio_snapshots"), + "portfolio_trades": info.get("portfolio_trades"), + "risk_breakdown": info.get("risk_breakdown"), + "selected_agents": list(selected_agents), + "action_values": list(action_values), + "experiment_id": resolved_experiment_id, + "strategy_label": resolved_strategy, + "logged": log_success, + } + just_finished_single = True + single_result = st.session_state.get(_DECISION_ENV_SINGLE_RESULT_KEY) + if single_result: + if just_finished_single: + st.success("离线调参完成") + else: + st.success("离线调参结果(最近一次运行)") + st.caption( + f"实验 ID:{single_result.get('experiment_id', '-') } | 策略:{single_result.get('strategy_label', 'DecisionEnv')}" + ) + observation = single_result.get("observation", {}) + reward = float(single_result.get("reward", 0.0)) + col_metrics = st.columns(4) + col_metrics[0].metric("总收益", f"{observation.get('total_return', 0.0):+.2%}") + col_metrics[1].metric("最大回撤", f"{observation.get('max_drawdown', 0.0):+.2%}") + col_metrics[2].metric("波动率", f"{observation.get('volatility', 0.0):+.2%}") + col_metrics[3].metric("奖励", f"{reward:+.4f}") + + turnover_ratio = float(observation.get("turnover", 0.0) or 0.0) + turnover_value = float(observation.get("turnover_value", 0.0) or 0.0) + risk_count = float(observation.get("risk_count", 0.0) or 0.0) + col_metrics_extra = st.columns(3) + col_metrics_extra[0].metric("平均换手率", f"{turnover_ratio:.2%}") + col_metrics_extra[1].metric("成交额", f"{turnover_value:,.0f}") + col_metrics_extra[2].metric("风险事件数", f"{int(risk_count)}") + + weights_dict = single_result.get("weights") or {} + if weights_dict: + st.write("调参后权重:") + st.json(weights_dict) + if st.button("保存这些权重为默认配置", key="save_decision_env_weights_single"): + try: + cfg.agent_weights.update_from_dict(weights_dict) + save_config(cfg) + except Exception as exc: # noqa: BLE001 + LOGGER.exception("保存权重失败", extra={**LOG_EXTRA, "error": str(exc)}) + st.error(f"写入配置失败:{exc}") + else: + st.success("代理权重已写入 config.json") + + if single_result.get("logged"): + st.caption("调参结果已写入 tuning_results 表。") + + nav_series = single_result.get("nav_series") or [] + if nav_series: + try: + nav_df = pd.DataFrame(nav_series) + if {"trade_date", "nav"}.issubset(nav_df.columns): + nav_df = nav_df.sort_values("trade_date") + nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"]) + st.line_chart(nav_df.set_index("trade_date")["nav"], height=220) + except Exception: # noqa: BLE001 + LOGGER.debug("导航曲线绘制失败", extra=LOG_EXTRA) + + trades = single_result.get("trades") or [] + if trades: + st.write("成交记录:") + st.dataframe(pd.DataFrame(trades), hide_index=True, width='stretch') + + snapshots = single_result.get("portfolio_snapshots") or [] + if snapshots: + with st.expander("投资组合快照", expanded=False): + st.dataframe(pd.DataFrame(snapshots), hide_index=True, width='stretch') + + portfolio_trades = single_result.get("portfolio_trades") or [] + if portfolio_trades: + with st.expander("组合成交明细", expanded=False): + st.dataframe(pd.DataFrame(portfolio_trades), hide_index=True, width='stretch') + + risk_breakdown = single_result.get("risk_breakdown") or {} + if risk_breakdown: + with st.expander("风险事件统计", expanded=False): + st.json(risk_breakdown) + + if st.button("清除单次调参结果", key="clear_decision_env_single"): + st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None) + st.success("已清除单次调参结果缓存。") + + st.divider() + st.caption("批量调参:在下方输入多组动作,每行表示一组 0-1 之间的值,用逗号分隔。") + default_grid = "\n".join( + [ + ",".join(["0.2" for _ in specs]), + ",".join(["0.5" for _ in specs]), + ",".join(["0.8" for _ in specs]), + ] + ) if specs else "" + action_grid_raw = st.text_area( + "动作列表", + value=default_grid, + height=120, + key="decision_env_batch_actions", + ) + run_batch = st.button("批量执行调参", key="run_decision_env_batch") + batch_just_ran = False + if run_batch: + if not selected_agents: + st.warning("请先选择调参代理。") + elif not range_valid: + st.error("请确保所有代理的最大权重大于最小权重。") + else: + LOGGER.info( + "离线调参(批量)按钮点击,已选择代理=%s disable_departments=%s", + selected_agents, + disable_departments, + extra=LOG_EXTRA, + ) + lines = [line.strip() for line in action_grid_raw.splitlines() if line.strip()] + if not lines: + st.warning("请在文本框中输入至少一组动作。") + else: + LOGGER.debug( + "离线调参(批量)原始输入=%s", + lines, + extra=LOG_EXTRA, + ) + parsed_actions: List[List[float]] = [] + for line in lines: + try: + values = [float(val.strip()) for val in line.split(',') if val.strip()] + except ValueError: + st.error(f"无法解析动作行:{line}") + parsed_actions = [] + break + if len(values) != len(specs): + st.error(f"动作维度不匹配(期望 {len(specs)} 个值):{line}") + parsed_actions = [] + break + parsed_actions.append(values) + if parsed_actions: + LOGGER.info( + "离线调参(批量)解析动作成功,数量=%s", + len(parsed_actions), + extra=LOG_EXTRA, + ) + baseline_weights = cfg.agent_weights.as_dict() + for agent in agent_objects: + baseline_weights.setdefault(agent.name, 1.0) + + universe_env = [code.strip() for code in universe_text.split(',') if code.strip()] + if not universe_env: + st.error("请先指定至少一个股票代码。") + else: + bt_cfg_env = BtConfig( + id="decision_env_streamlit_batch", + name="DecisionEnv Batch", + start_date=start_date, + end_date=end_date, + universe=universe_env, + params={ + "target": target, + "stop": stop, + "hold_days": int(hold_days), + }, + method=cfg.decision_method, + ) + env = DecisionEnv( + bt_config=bt_cfg_env, + parameter_specs=specs, + baseline_weights=baseline_weights, + disable_departments=disable_departments, + ) + results: List[Dict[str, object]] = [] + resolved_experiment_id = experiment_id or str(uuid.uuid4()) + resolved_strategy = strategy_label or "DecisionEnv" + LOGGER.debug( + "离线调参(批量)启动 DecisionEnv:cfg=%s 动作组=%s", + bt_cfg_env, + len(parsed_actions), + extra=LOG_EXTRA, + ) + with st.spinner("正在批量执行调参……"): + for idx, action_vals in enumerate(parsed_actions, start=1): + env.reset() + try: + observation, reward, done, info = env.step(action_vals) + except Exception as exc: # noqa: BLE001 + LOGGER.exception("批量调参失败", extra=LOG_EXTRA) + results.append( + { + "序号": idx, + "动作": action_vals, + "状态": "error", + "错误": str(exc), + } + ) + continue + if observation.get("failure"): + results.append( + { + "序号": idx, + "动作": action_vals, + "状态": "failure", + "奖励": -1.0, + } + ) + else: + LOGGER.info( + "离线调参(批量)第 %s 组完成,reward=%.4f obs=%s", + idx, + reward, + observation, + extra=LOG_EXTRA, + ) + action_payload = { + name: value + for name, value in zip(selected_agents, action_vals) + } + metrics_payload = dict(observation) + metrics_payload["reward"] = reward + weights_payload = info.get("weights", {}) + try: + log_tuning_result( + experiment_id=resolved_experiment_id, + strategy=resolved_strategy, + action=action_payload, + reward=reward, + metrics=metrics_payload, + weights=weights_payload, + ) + except Exception: # noqa: BLE001 + LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA) + results.append( + { + "序号": idx, + "动作": action_vals, + "状态": "ok", + "总收益": observation.get("total_return", 0.0), + "最大回撤": observation.get("max_drawdown", 0.0), + "波动率": observation.get("volatility", 0.0), + "奖励": reward, + "权重": weights_payload, + } + ) + st.session_state[_DECISION_ENV_BATCH_RESULTS_KEY] = { + "results": results, + "selectable": [ + row + for row in results + if row.get("状态") == "ok" and row.get("权重") + ], + "experiment_id": resolved_experiment_id, + "strategy_label": resolved_strategy, + } + batch_just_ran = True + LOGGER.info( + "离线调参(批量)执行结束,总结果条数=%s", + len(results), + extra=LOG_EXTRA, + ) + batch_state = st.session_state.get(_DECISION_ENV_BATCH_RESULTS_KEY) + if batch_state: + results = batch_state.get("results") or [] + if results: + if batch_just_ran: + st.success("批量调参完成") + else: + st.success("批量调参结果(最近一次运行)") + st.caption( + f"实验 ID:{batch_state.get('experiment_id', '-') } | 策略:{batch_state.get('strategy_label', 'DecisionEnv')}" + ) + results_df = pd.DataFrame(results) + st.write("批量调参结果:") + st.dataframe(results_df, hide_index=True, width='stretch') + selectable = batch_state.get("selectable") or [] + if selectable: + option_labels = [ + f"序号 {row['序号']} | 奖励 {row.get('奖励', 0.0):+.4f}" + for row in selectable + ] + selected_label = st.selectbox( + "选择要保存的记录", + option_labels, + key="decision_env_batch_select", + ) + selected_row = None + for label, row in zip(option_labels, selectable): + if label == selected_label: + selected_row = row + break + if selected_row and st.button( + "保存所选权重为默认配置", + key="save_decision_env_weights_batch", + ): + try: + cfg.agent_weights.update_from_dict(selected_row.get("权重", {})) + save_config(cfg) + except Exception as exc: # noqa: BLE001 + LOGGER.exception("批量保存权重失败", extra={**LOG_EXTRA, "error": str(exc)}) + st.error(f"写入配置失败:{exc}") + else: + st.success( + f"已将序号 {selected_row['序号']} 的权重写入 config.json" + ) + else: + st.caption("暂无成功的结果可供保存。") + else: + st.caption("批量调参在最近一次执行中未产生结果。") + if st.button("清除批量调参结果", key="clear_decision_env_batch"): + st.session_state.pop(_DECISION_ENV_BATCH_RESULTS_KEY, None) + st.session_state.pop("decision_env_batch_select", None) + st.success("已清除批量调参结果缓存。") diff --git a/app/ui/views/dashboard.py b/app/ui/views/dashboard.py new file mode 100644 index 0000000..90edac5 --- /dev/null +++ b/app/ui/views/dashboard.py @@ -0,0 +1,147 @@ +"""Sidebar dashboard for the Streamlit UI.""" +from __future__ import annotations + +from typing import Dict, Optional + +import streamlit as st + +from app.llm.metrics import ( + recent_decisions as llm_recent_decisions, + register_listener as register_llm_metrics_listener, + snapshot as snapshot_llm_metrics, +) +from app.utils import alerts + +from app.ui.shared import LOGGER, LOG_EXTRA + +_DASHBOARD_CONTAINERS: Optional[tuple[object, object]] = None +_DASHBOARD_ELEMENTS: Optional[Dict[str, object]] = None +_SIDEBAR_LISTENER_ATTACHED = False +_WARNINGS_PLACEHOLDER = None + + +def _ensure_dashboard_elements(metrics_container: object, decisions_container: object) -> Dict[str, object]: + elements = { + "metrics_calls": metrics_container.metric, + "metrics_prompt": metrics_container.metric, + "metrics_completion": metrics_container.metric, + "provider_distribution": metrics_container.empty(), + "model_distribution": metrics_container.empty(), + "decisions_list": decisions_container.empty(), + } + return elements + + +def update_dashboard_sidebar(metrics: Optional[Dict[str, object]] = None) -> None: + """Refresh sidebar metrics and warnings.""" + global _DASHBOARD_CONTAINERS + global _DASHBOARD_ELEMENTS + global _WARNINGS_PLACEHOLDER + + containers = _DASHBOARD_CONTAINERS + if not containers: + return + metrics_container, decisions_container = containers + elements = _DASHBOARD_ELEMENTS + if elements is None: + elements = _ensure_dashboard_elements(metrics_container, decisions_container) + _DASHBOARD_ELEMENTS = elements + + if metrics is None: + metrics = snapshot_llm_metrics() + + elements["metrics_calls"]("LLM 调用", metrics.get("total_calls", 0)) + elements["metrics_prompt"]("Prompt Tokens", metrics.get("total_prompt_tokens", 0)) + elements["metrics_completion"]("Completion Tokens", metrics.get("total_completion_tokens", 0)) + + provider_calls = metrics.get("provider_calls", {}) + model_calls = metrics.get("model_calls", {}) + + provider_placeholder = elements["provider_distribution"] + provider_placeholder.empty() + if provider_calls: + provider_placeholder.json(provider_calls) + else: + provider_placeholder.info("暂无 Provider 分布数据。") + + model_placeholder = elements["model_distribution"] + model_placeholder.empty() + if model_calls: + model_placeholder.json(model_calls) + else: + model_placeholder.info("暂无模型分布数据。") + + decisions = metrics.get("recent_decisions") or llm_recent_decisions(10) + decisions_placeholder = elements["decisions_list"] + decisions_placeholder.empty() + if decisions: + lines = [] + for record in reversed(decisions[-10:]): + ts_code = record.get("ts_code") + trade_date = record.get("trade_date") + action = record.get("action") + confidence = record.get("confidence", 0.0) + summary = record.get("summary") + line = f"**{trade_date} {ts_code}** → {action} (置信度 {confidence:.2f})" + if summary: + line += f"\n{summary}" + lines.append(line) + decisions_placeholder.markdown("\n\n".join(lines), unsafe_allow_html=True) + else: + decisions_placeholder.info("暂无决策记录。执行回测或实时评估后可在此查看。") + + if _WARNINGS_PLACEHOLDER is not None: + _WARNINGS_PLACEHOLDER.empty() + with _WARNINGS_PLACEHOLDER.container(): + st.subheader("数据告警") + warn_list = alerts.get_warnings() + if warn_list: + lines = [] + for warning in warn_list[-10:]: + detail = warning.get("detail") + source = warning.get("source") + ts = warning.get("ts") + label = warning.get("label") + line = f"- **{source or '未知来源'}** {label or ''}" + if detail: + line += f":{detail}" + if ts: + line += f"({ts})" + lines.append(line) + st.markdown("\n".join(lines)) + else: + st.caption("暂无数据告警。") + + +def _sidebar_metrics_listener(metrics: Dict[str, object]) -> None: + try: + update_dashboard_sidebar(metrics) + except Exception: # noqa: BLE001 + LOGGER.debug("侧边栏监听器刷新失败", exc_info=True, extra=LOG_EXTRA) + + +def render_global_dashboard() -> None: + """Render a persistent sidebar with realtime LLM stats and recent decisions.""" + + global _DASHBOARD_CONTAINERS + global _DASHBOARD_ELEMENTS + global _SIDEBAR_LISTENER_ATTACHED + global _WARNINGS_PLACEHOLDER + + warnings = alerts.get_warnings() + badge = f" ({len(warnings)})" if warnings else "" + st.sidebar.header(f"系统监控{badge}") + + metrics_container = st.sidebar.container() + decisions_container = st.sidebar.container() + st.sidebar.container() # legacy placeholder for layout spacing + warn_placeholder = st.sidebar.empty() + + _DASHBOARD_CONTAINERS = (metrics_container, decisions_container) + _DASHBOARD_ELEMENTS = _ensure_dashboard_elements(metrics_container, decisions_container) + _WARNINGS_PLACEHOLDER = warn_placeholder + + if not _SIDEBAR_LISTENER_ATTACHED: + register_llm_metrics_listener(_sidebar_metrics_listener) + _SIDEBAR_LISTENER_ATTACHED = True + update_dashboard_sidebar() diff --git a/app/ui/views/logs.py b/app/ui/views/logs.py new file mode 100644 index 0000000..16acef1 --- /dev/null +++ b/app/ui/views/logs.py @@ -0,0 +1,164 @@ +"""日志钻取视图。""" +from __future__ import annotations + +from datetime import date, timedelta +import pandas as pd +import streamlit as st + +from app.utils.db import db_session + +from app.ui.shared import LOGGER, LOG_EXTRA + +def render_log_viewer() -> None: + """渲染日志钻取与历史对比视图页面。""" + LOGGER.info("渲染日志视图页面", extra=LOG_EXTRA) + st.header("日志钻取与历史对比") + st.write("查看系统运行日志,支持时间范围筛选、关键词搜索和历史对比功能。") + + col1, col2 = st.columns(2) + with col1: + start_date = st.date_input("开始日期", value=date.today() - timedelta(days=7)) + with col2: + end_date = st.date_input("结束日期", value=date.today()) + + log_levels = ["ALL", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] + selected_level = st.selectbox("日志级别", log_levels, index=1) + + search_query = st.text_input("搜索关键词") + + with db_session(read_only=True) as conn: + stages = [row["stage"] for row in conn.execute("SELECT DISTINCT stage FROM run_log").fetchall()] + stages = [s for s in stages if s] + stages.insert(0, "ALL") + selected_stage = st.selectbox("执行阶段", stages) + + with st.spinner("加载日志数据中..."): + try: + with db_session(read_only=True) as conn: + query_parts = ["SELECT ts, stage, level, msg FROM run_log WHERE 1=1"] + params: list[object] = [] + + start_ts = f"{start_date.isoformat()}T00:00:00Z" + end_ts = f"{end_date.isoformat()}T23:59:59Z" + query_parts.append("AND ts BETWEEN ? AND ?") + params.extend([start_ts, end_ts]) + + if selected_level != "ALL": + query_parts.append("AND level = ?") + params.append(selected_level) + + if search_query: + query_parts.append("AND msg LIKE ?") + params.append(f"%{search_query}%") + + if selected_stage != "ALL": + query_parts.append("AND stage = ?") + params.append(selected_stage) + + query_parts.append("ORDER BY ts DESC") + + query = " ".join(query_parts) + rows = conn.execute(query, params).fetchall() + + if rows: + rows_dict = [{key: row[key] for key in row.keys()} for row in rows] + log_df = pd.DataFrame(rows_dict) + log_df["ts"] = pd.to_datetime(log_df["ts"]).dt.strftime("%Y-%m-%d %H:%M:%S") + for col in log_df.columns: + log_df[col] = log_df[col].astype(str) + else: + log_df = pd.DataFrame(columns=["ts", "stage", "level", "msg"]) + + st.dataframe( + log_df, + hide_index=True, + width="stretch", + column_config={ + "ts": st.column_config.TextColumn("时间"), + "stage": st.column_config.TextColumn("执行阶段"), + "level": st.column_config.TextColumn("日志级别"), + "msg": st.column_config.TextColumn("日志消息", width="large"), + }, + ) + + if not log_df.empty: + csv_data = log_df.to_csv(index=False).encode("utf-8") + st.download_button( + label="下载日志CSV", + data=csv_data, + file_name=f"logs_{start_date}_{end_date}.csv", + mime="text/csv", + key="download_logs", + ) + + json_data = log_df.to_json(orient="records", force_ascii=False, indent=2) + st.download_button( + label="下载日志JSON", + data=json_data, + file_name=f"logs_{start_date}_{end_date}.json", + mime="application/json", + key="download_logs_json", + ) + except Exception as exc: # noqa: BLE001 + LOGGER.exception("加载日志失败", extra=LOG_EXTRA) + st.error(f"加载日志数据失败:{exc}") + + st.subheader("历史对比") + st.write("选择两个时间点的日志进行对比分析。") + + col3, col4 = st.columns(2) + with col3: + compare_date1 = st.date_input("对比日期1", value=date.today() - timedelta(days=1)) + with col4: + compare_date2 = st.date_input("对比日期2", value=date.today()) + + comparison_stage = st.selectbox("对比阶段", stages, key="compare_stage") + st.write("选择需要比较的日志数量。") + compare_limit = st.slider("对比日志数量", min_value=10, max_value=200, value=50, step=10) + + if st.button("生成历史对比报告"): + with st.spinner("生成对比报告中..."): + try: + with db_session(read_only=True) as conn: + def load_logs(d: date) -> pd.DataFrame: + start_ts = f"{d.isoformat()}T00:00:00Z" + end_ts = f"{d.isoformat()}T23:59:59Z" + query = ["SELECT ts, level, msg FROM run_log WHERE ts BETWEEN ? AND ?"] + params: list[object] = [start_ts, end_ts] + if comparison_stage != "ALL": + query.append("AND stage = ?") + params.append(comparison_stage) + query.append("ORDER BY ts DESC LIMIT ?") + params.append(compare_limit) + sql = " ".join(query) + rows = conn.execute(sql, params).fetchall() + if not rows: + return pd.DataFrame(columns=["ts", "level", "msg"]) + df = pd.DataFrame([{k: row[k] for k in row.keys()} for row in rows]) + df["ts"] = pd.to_datetime(df["ts"]).dt.strftime("%Y-%m-%d %H:%M:%S") + return df + + df1 = load_logs(compare_date1) + df2 = load_logs(compare_date2) + + if df1.empty and df2.empty: + st.info("选定日期暂无日志可对比。") + else: + st.write("### 对比结果") + col_a, col_b = st.columns(2) + with col_a: + st.write(f"{compare_date1} 日日志") + st.dataframe(df1, hide_index=True, width="stretch") + with col_b: + st.write(f"{compare_date2} 日日志") + st.dataframe(df2, hide_index=True, width="stretch") + + summary = { + "日期1日志条数": int(len(df1)), + "日期2日志条数": int(len(df2)), + "新增日志条数": max(len(df2) - len(df1), 0), + } + st.write("摘要:", summary) + except Exception as exc: # noqa: BLE001 + LOGGER.exception("历史对比生成失败", extra=LOG_EXTRA) + st.error(f"生成历史对比失败:{exc}") diff --git a/app/ui/views/market.py b/app/ui/views/market.py new file mode 100644 index 0000000..db97641 --- /dev/null +++ b/app/ui/views/market.py @@ -0,0 +1,115 @@ +"""行情可视化页面。""" +from __future__ import annotations + +from datetime import date, timedelta + +import pandas as pd +import plotly.express as px +import plotly.graph_objects as go +import streamlit as st + +from app.utils.db import db_session + +from app.ui.shared import LOGGER, LOG_EXTRA + + +def _load_stock_options(limit: int = 500) -> list[str]: + try: + with db_session(read_only=True) as conn: + rows = conn.execute( + """ + SELECT DISTINCT ts_code + FROM daily + ORDER BY trade_date DESC + LIMIT ? + """, + (limit,), + ).fetchall() + except Exception: # noqa: BLE001 + LOGGER.exception("加载股票列表失败", extra=LOG_EXTRA) + return [] + return [row["ts_code"] for row in rows] + + +def _parse_ts_code(selection: str) -> str: + return selection.split(" ", 1)[0] + + +def _load_daily_frame(ts_code: str, start: date, end: date) -> pd.DataFrame: + with db_session(read_only=True) as conn: + df = pd.read_sql_query( + """ + SELECT trade_date, open, high, low, close, vol, amount + FROM daily + WHERE ts_code = ? AND trade_date BETWEEN ? AND ? + ORDER BY trade_date + """, + conn, + params=(ts_code, start.strftime("%Y%m%d"), end.strftime("%Y%m%d")), + ) + if df.empty: + return df + df["trade_date"] = pd.to_datetime(df["trade_date"], format="%Y%m%d") + return df + + +def render_market_visualization() -> None: + st.header("行情可视化") + st.caption("按标的查看 K 线、成交量以及常用指标。") + + options = _load_stock_options() + if not options: + st.warning("暂未加载到可用的行情标的,请先执行数据同步。") + return + + selection = st.selectbox("选择标的", options, index=0) + ts_code = _parse_ts_code(selection) + col1, col2 = st.columns(2) + with col1: + start_date = st.date_input("开始日期", value=date.today() - timedelta(days=120)) + with col2: + end_date = st.date_input("结束日期", value=date.today()) + + if start_date > end_date: + st.error("开始日期不能晚于结束日期。") + return + + try: + df = _load_daily_frame(ts_code, start_date, end_date) + except Exception as exc: # noqa: BLE001 + LOGGER.exception("加载行情数据失败", extra=LOG_EXTRA) + st.error(f"加载行情数据失败:{exc}") + return + + if df.empty: + st.info("所选区间内无行情数据。") + return + + st.metric("最新收盘价", f"{df['close'].iloc[-1]:.2f}") + fig = go.Figure( + data=[ + go.Candlestick( + x=df["trade_date"], + open=df["open"], + high=df["high"], + low=df["low"], + close=df["close"], + name="K线", + ) + ] + ) + fig.update_layout(title=f"{ts_code} K线图", xaxis_title="日期", yaxis_title="价格") + st.plotly_chart(fig, use_container_width=True) + + fig_vol = px.bar(df, x="trade_date", y="vol", title="成交量") + st.plotly_chart(fig_vol, use_container_width=True) + + df_ma = df.copy() + df_ma["MA5"] = df_ma["close"].rolling(window=5).mean() + df_ma["MA20"] = df_ma["close"].rolling(window=20).mean() + df_ma["MA60"] = df_ma["close"].rolling(window=60).mean() + + fig_ma = px.line(df_ma, x="trade_date", y=["close", "MA5", "MA20", "MA60"], title="均线对比") + st.plotly_chart(fig_ma, use_container_width=True) + + st.dataframe(df, hide_index=True, width='stretch') diff --git a/app/ui/views/pool.py b/app/ui/views/pool.py new file mode 100644 index 0000000..5956704 --- /dev/null +++ b/app/ui/views/pool.py @@ -0,0 +1,162 @@ +"""投资池与仓位概览页面。""" +from __future__ import annotations + +import pandas as pd +import plotly.express as px +import streamlit as st + +from app.utils.db import db_session +from app.utils.portfolio import ( + get_latest_snapshot, + list_investment_pool, + list_positions, + list_recent_trades, +) + +from app.ui.shared import LOGGER, LOG_EXTRA, get_latest_trade_date + + +def render_pool_overview() -> None: + """单独的投资池与仓位概览页面(从今日计划中提取)。""" + LOGGER.info("渲染投资池与仓位概览页面", extra=LOG_EXTRA) + st.header("投资池与仓位概览") + + snapshot = get_latest_snapshot() + if snapshot: + col_a, col_b, col_c = st.columns(3) + if snapshot.total_value is not None: + col_a.metric("组合净值", f"{snapshot.total_value:,.2f}") + if snapshot.cash is not None: + col_b.metric("现金余额", f"{snapshot.cash:,.2f}") + if snapshot.invested_value is not None: + col_c.metric("持仓市值", f"{snapshot.invested_value:,.2f}") + detail_cols = st.columns(4) + if snapshot.unrealized_pnl is not None: + detail_cols[0].metric("浮盈", f"{snapshot.unrealized_pnl:,.2f}") + if snapshot.realized_pnl is not None: + detail_cols[1].metric("已实现盈亏", f"{snapshot.realized_pnl:,.2f}") + if snapshot.net_flow is not None: + detail_cols[2].metric("净流入", f"{snapshot.net_flow:,.2f}") + if snapshot.exposure is not None: + detail_cols[3].metric("风险敞口", f"{snapshot.exposure:.2%}") + if snapshot.notes: + st.caption(f"备注:{snapshot.notes}") + else: + st.info("暂无组合快照,请在执行回测或实盘同步后写入 portfolio_snapshots。") + + try: + latest_date = get_latest_trade_date() + candidates = list_investment_pool(trade_date=latest_date) + except Exception: # noqa: BLE001 + LOGGER.exception("加载候选池失败", extra=LOG_EXTRA) + candidates = [] + + if candidates: + candidate_df = pd.DataFrame( + [ + { + "交易日": item.trade_date, + "代码": item.ts_code, + "评分": item.score, + "状态": item.status, + "标签": "、".join(item.tags) if item.tags else "-", + "理由": item.rationale or "", + } + for item in candidates + ] + ) + st.write("候选投资池:") + st.dataframe(candidate_df, width='stretch', hide_index=True) + else: + st.caption("候选投资池暂无数据。") + + positions = list_positions(active_only=False) + if positions: + position_df = pd.DataFrame( + [ + { + "ID": pos.id, + "代码": pos.ts_code, + "开仓日": pos.opened_date, + "平仓日": pos.closed_date or "-", + "状态": pos.status, + "数量": pos.quantity, + "成本": pos.cost_price, + "现价": pos.market_price, + "市值": pos.market_value, + "浮盈": pos.unrealized_pnl, + "已实现": pos.realized_pnl, + "目标权重": pos.target_weight, + } + for pos in positions + ] + ) + st.write("组合持仓:") + st.dataframe(position_df, width='stretch', hide_index=True) + else: + st.caption("组合持仓暂无记录。") + + trades = list_recent_trades(limit=20) + if trades: + trades_df = pd.DataFrame(trades) + st.write("近期成交:") + st.dataframe(trades_df, width='stretch', hide_index=True) + else: + st.caption("近期成交暂无记录。") + + st.caption("数据来源:agent_utils、investment_pool、portfolio_positions、portfolio_trades、portfolio_snapshots。") + + if st.button("执行对比", type="secondary"): + with st.spinner("执行日志对比分析中..."): + try: + with db_session(read_only=True) as conn: + query_date1 = f"{compare_date1.isoformat()}T00:00:00Z" # type: ignore[name-defined] + query_date2 = f"{compare_date1.isoformat()}T23:59:59Z" # type: ignore[name-defined] + logs1 = conn.execute( + "SELECT level, COUNT(*) as count FROM run_log WHERE ts BETWEEN ? AND ? GROUP BY level", + (query_date1, query_date2), + ).fetchall() + + query_date3 = f"{compare_date2.isoformat()}T00:00:00Z" # type: ignore[name-defined] + query_date4 = f"{compare_date2.isoformat()}T23:59:59Z" # type: ignore[name-defined] + logs2 = conn.execute( + "SELECT level, COUNT(*) as count FROM run_log WHERE ts BETWEEN ? AND ? GROUP BY level", + (query_date3, query_date4), + ).fetchall() + + df1 = pd.DataFrame(logs1, columns=["level", "count"]) + df1["date"] = compare_date1.strftime("%Y-%m-%d") # type: ignore[name-defined] + df2 = pd.DataFrame(logs2, columns=["level", "count"]) + df2["date"] = compare_date2.strftime("%Y-%m-%d") # type: ignore[name-defined] + + for df in (df1, df2): + for col in df.columns: + if col != "level": + df[col] = df[col].astype(object) + + compare_df = pd.concat([df1, df2]) + fig = px.bar( + compare_df, + x="level", + y="count", + color="date", + barmode="group", + title=f"日志级别分布对比 ({compare_date1} vs {compare_date2})", # type: ignore[name-defined] + ) + st.plotly_chart(fig, width='stretch') + + st.write("日志统计对比:") + date1_str = compare_date1.strftime("%Y%m%d") # type: ignore[name-defined] + date2_str = compare_date2.strftime("%Y%m%d") # type: ignore[name-defined] + merged_df = df1.merge( + df2, + on="level", + suffixes=(f"_{date1_str}", f"_{date2_str}"), + how="outer", + ).fillna(0) + st.dataframe(merged_df, hide_index=True, width="stretch") + except Exception as exc: # noqa: BLE001 + LOGGER.exception("日志对比失败", extra=LOG_EXTRA) + st.error(f"日志对比分析失败:{exc}") + + return diff --git a/app/ui/views/settings.py b/app/ui/views/settings.py new file mode 100644 index 0000000..0d2c738 --- /dev/null +++ b/app/ui/views/settings.py @@ -0,0 +1,603 @@ +"""系统设置相关视图。""" +from __future__ import annotations + +from datetime import datetime, timedelta +from typing import Dict, List, Optional + +import pandas as pd +import requests +from requests.exceptions import RequestException +import streamlit as st + +from app.llm.client import llm_config_snapshot +from app.utils.config import ( + ALLOWED_LLM_STRATEGIES, + DEFAULT_LLM_BASE_URLS, + DepartmentSettings, + LLMEndpoint, + LLMProvider, + get_config, + save_config, +) +from app.utils.db import db_session + +from app.ui.shared import LOGGER, LOG_EXTRA + +_MODEL_CACHE: Dict[str, Dict[str, object]] = {} +_CACHE_TTL_SECONDS = 300 + +def _discover_provider_models(provider: LLMProvider, base_override: str = "", api_override: Optional[str] = None) -> tuple[list[str], Optional[str]]: + """Attempt to query provider API and return available model ids.""" + + base_url = (base_override or provider.base_url or DEFAULT_LLM_BASE_URLS.get(provider.key, "")).strip() + if not base_url: + return [], "请先填写 Base URL" + timeout = float(provider.default_timeout or 30.0) + mode = provider.mode or ("ollama" if provider.key == "ollama" else "openai") + + # ADD: simple cache by provider+base URL + cache_key = f"{provider.key}|{base_url}" + now = datetime.now() + cached = _MODEL_CACHE.get(cache_key) + if cached: + ts = cached.get("ts") + if isinstance(ts, float) and (now.timestamp() - ts) < _CACHE_TTL_SECONDS: + models = list(cached.get("models") or []) + return models, None + + try: + if mode == "ollama": + url = base_url.rstrip('/') + "/api/tags" + response = requests.get(url, timeout=timeout) + response.raise_for_status() + data = response.json() + models = [] + for item in data.get("models", []) or data.get("data", []): + name = item.get("name") or item.get("model") or item.get("tag") + if name: + models.append(str(name).strip()) + _MODEL_CACHE[cache_key] = {"ts": now.timestamp(), "models": sorted(set(models))} + return sorted(set(models)), None + + api_key = (api_override or provider.api_key or "").strip() + if not api_key: + return [], "缺少 API Key" + url = base_url.rstrip('/') + "/v1/models" + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + } + response = requests.get(url, headers=headers, timeout=timeout) + response.raise_for_status() + payload = response.json() + models = [ + str(item.get("id")).strip() + for item in payload.get("data", []) + if item.get("id") + ] + _MODEL_CACHE[cache_key] = {"ts": now.timestamp(), "models": sorted(set(models))} + return sorted(set(models)), None + except RequestException as exc: # noqa: BLE001 + return [], f"HTTP 错误:{exc}" + except Exception as exc: # noqa: BLE001 + return [], f"解析失败:{exc}" + +def render_config_overview() -> None: + """Render a concise overview of persisted configuration values.""" + + LOGGER.info("渲染配置概览页", extra=LOG_EXTRA) + cfg = get_config() + + st.subheader("核心配置概览") + col1, col2, col3 = st.columns(3) + col1.metric("决策方式", cfg.decision_method.upper()) + col2.metric("自动更新数据", "启用" if cfg.auto_update_data else "关闭") + col3.metric("数据更新间隔(天)", cfg.data_update_interval) + + col4, col5, col6 = st.columns(3) + col4.metric("强制刷新", "开启" if cfg.force_refresh else "关闭") + col5.metric("TuShare Token", "已配置" if cfg.tushare_token else "未配置") + col6.metric("配置文件", cfg.data_paths.config_file.name) + st.caption(f"配置文件路径:{cfg.data_paths.config_file}") + + st.divider() + st.subheader("RSS 数据源状态") + rss_sources = cfg.rss_sources or {} + if rss_sources: + rows: List[Dict[str, object]] = [] + for name, payload in rss_sources.items(): + if isinstance(payload, dict): + rows.append( + { + "名称": name, + "启用": "是" if payload.get("enabled", True) else "否", + "URL": payload.get("url", "-"), + "关键词数": len(payload.get("keywords", []) or []), + } + ) + elif isinstance(payload, bool): + rows.append( + { + "名称": name, + "启用": "是" if payload else "否", + "URL": "-", + "关键词数": 0, + } + ) + if rows: + st.dataframe(pd.DataFrame(rows), hide_index=True, width="stretch") + else: + st.info("未配置 RSS 数据源。") + else: + st.info("未在配置文件中找到 RSS 数据源。") + + st.divider() + st.subheader("部门配置") + dept_rows: List[Dict[str, object]] = [] + for code, dept in cfg.departments.items(): + dept_rows.append( + { + "部门": dept.title or code, + "代码": code, + "权重": dept.weight, + "LLM 策略": dept.llm.strategy, + "模板": dept.prompt_template_id or f"{code}_dept", + "模板版本": dept.prompt_template_version or "(激活版本)", + } + ) + if dept_rows: + st.dataframe(pd.DataFrame(dept_rows), hide_index=True, width="stretch") + else: + st.info("尚未配置任何部门。") + + st.divider() + st.subheader("LLM 成本控制") + cost = cfg.llm_cost + col_a, col_b, col_c, col_d = st.columns(4) + col_a.metric("成本控制", "启用" if cost.enabled else "关闭") + col_b.metric("小时预算($)", f"{cost.hourly_budget:.2f}") + col_c.metric("日预算($)", f"{cost.daily_budget:.2f}") + col_d.metric("月预算($)", f"{cost.monthly_budget:.2f}") + + if cost.model_weights: + weight_rows = ( + pd.DataFrame( + [ + {"模型": model, "占比上限": f"{limit * 100:.0f}%"} + for model, limit in cost.model_weights.items() + ] + ) + ) + st.dataframe(weight_rows, hide_index=True, width="stretch") + else: + st.caption("未配置模型占比限制。") + + st.divider() + st.caption("提示:数据源、LLM 及投资组合设置可在对应标签页中调整。") + +def render_llm_settings() -> None: + cfg = get_config() + st.subheader("LLM 设置") + providers = cfg.llm_providers + provider_keys = sorted(providers.keys()) + st.caption("先在 Provider 中维护基础连接(URL、Key、模型),再为全局与各部门设置个性化参数。") + + provider_select_col, provider_manage_col = st.columns([3, 1]) + if provider_keys: + try: + default_provider = cfg.llm.primary.provider or provider_keys[0] + provider_index = provider_keys.index(default_provider) + except ValueError: + provider_index = 0 + selected_provider = provider_select_col.selectbox( + "选择 Provider", + provider_keys, + index=provider_index, + key="llm_provider_select", + ) + else: + selected_provider = None + provider_select_col.info("尚未配置 Provider,请先创建。") + + new_provider_name = provider_manage_col.text_input("新增 Provider", key="new_provider_name") + if provider_manage_col.button("创建 Provider", key="create_provider_btn"): + key = (new_provider_name or "").strip().lower() + if not key: + st.warning("请输入有效的 Provider 名称。") + elif key in providers: + st.warning(f"Provider {key} 已存在。") + else: + providers[key] = LLMProvider(key=key) + cfg.llm_providers = providers + save_config() + st.success(f"已创建 Provider {key}。") + st.rerun() + + if selected_provider: + provider_cfg = providers.get(selected_provider, LLMProvider(key=selected_provider)) + title_key = f"provider_title_{selected_provider}" + base_key = f"provider_base_{selected_provider}" + api_key_key = f"provider_api_{selected_provider}" + mode_key = f"provider_mode_{selected_provider}" + enabled_key = f"provider_enabled_{selected_provider}" + + title_val = st.text_input("备注名称", value=provider_cfg.title or "", key=title_key) + base_val = st.text_input("Base URL", value=provider_cfg.base_url or "", key=base_key, help="调用地址,例如:https://api.openai.com") + api_val = st.text_input("API Key", value=provider_cfg.api_key or "", key=api_key_key, type="password") + + enabled_val = st.checkbox("启用", value=provider_cfg.enabled, key=enabled_key) + mode_val = st.selectbox("模式", options=["openai", "ollama"], index=0 if provider_cfg.mode == "openai" else 1, key=mode_key) + st.markdown("可用模型:") + if provider_cfg.models: + st.code("\n".join(provider_cfg.models), language="text") + else: + st.info("尚未获取模型列表,可点击下方按钮自动拉取。") + try: + cache_key = f"{selected_provider}|{(base_val or '').strip()}" + entry = _MODEL_CACHE.get(cache_key) + if entry and isinstance(entry.get("ts"), float): + ts = datetime.fromtimestamp(entry["ts"]).strftime("%Y-%m-%d %H:%M:%S") + st.caption(f"最近拉取时间:{ts}") + except Exception: + pass + + fetch_key = f"fetch_models_{selected_provider}" + if st.button("获取模型列表", key=fetch_key): + with st.spinner("正在获取模型列表..."): + models, error = _discover_provider_models(provider_cfg, base_val, api_val) + if error: + st.error(error) + else: + provider_cfg.models = models + if models and (not provider_cfg.default_model or provider_cfg.default_model not in models): + provider_cfg.default_model = models[0] + providers[selected_provider] = provider_cfg + cfg.llm_providers = providers + cfg.sync_runtime_llm() + save_config() + st.success(f"共获取 {len(models)} 个模型。") + st.rerun() + + if st.button("保存 Provider", key=f"save_provider_{selected_provider}"): + provider_cfg.title = title_val.strip() + provider_cfg.base_url = base_val.strip() + provider_cfg.api_key = api_val.strip() or None + provider_cfg.enabled = enabled_val + provider_cfg.mode = mode_val + providers[selected_provider] = provider_cfg + cfg.llm_providers = providers + cfg.sync_runtime_llm() + save_config() + st.success("Provider 已保存。") + st.session_state[title_key] = provider_cfg.title or "" + + provider_in_use = (cfg.llm.primary.provider == selected_provider) or any( + ep.provider == selected_provider for ep in cfg.llm.ensemble + ) + if not provider_in_use: + for dept in cfg.departments.values(): + if dept.llm.primary.provider == selected_provider or any(ep.provider == selected_provider for ep in dept.llm.ensemble): + provider_in_use = True + break + if st.button( + "删除 Provider", + key=f"delete_provider_{selected_provider}", + disabled=provider_in_use or len(providers) <= 1, + ): + providers.pop(selected_provider, None) + cfg.llm_providers = providers + cfg.sync_runtime_llm() + save_config() + st.success("Provider 已删除。") + st.rerun() + + st.markdown("##### 全局推理配置") + if not provider_keys: + st.warning("请先配置至少一个 Provider。") + else: + global_cfg = cfg.llm + primary = global_cfg.primary + try: + provider_index = provider_keys.index(primary.provider or provider_keys[0]) + except ValueError: + provider_index = 0 + selected_global_provider = st.selectbox( + "主模型 Provider", + provider_keys, + index=provider_index, + key="global_provider_select", + ) + provider_cfg = providers.get(selected_global_provider) + available_models = provider_cfg.models if provider_cfg else [] + default_model = primary.model or (provider_cfg.default_model if provider_cfg else None) + if available_models: + options = available_models + ["自定义"] + try: + model_index = available_models.index(default_model) + model_choice = st.selectbox("主模型", options, index=model_index, key="global_model_choice") + except ValueError: + model_choice = st.selectbox("主模型", options, index=len(options) - 1, key="global_model_choice") + if model_choice == "自定义": + model_val = st.text_input("自定义模型", value=default_model or "", key="global_model_custom").strip() + else: + model_val = model_choice + else: + model_val = st.text_input("主模型", value=default_model or "", key="global_model_custom").strip() + + temp_default = primary.temperature if primary.temperature is not None else (provider_cfg.default_temperature if provider_cfg else 0.2) + temp_val = st.slider("主模型温度", min_value=0.0, max_value=2.0, value=float(temp_default), step=0.05, key="global_temp") + timeout_default = primary.timeout if primary.timeout is not None else (provider_cfg.default_timeout if provider_cfg else 30.0) + timeout_val = st.number_input("主模型超时(秒)", min_value=5, max_value=300, value=int(timeout_default), step=5, key="global_timeout") + prompt_template_val = st.text_area( + "主模型 Prompt 模板(可选)", + value=primary.prompt_template or provider_cfg.prompt_template if provider_cfg else "", + height=120, + key="global_prompt_template", + ) + + strategy_val = st.selectbox("推理策略", sorted(ALLOWED_LLM_STRATEGIES), index=sorted(ALLOWED_LLM_STRATEGIES).index(global_cfg.strategy) if global_cfg.strategy in ALLOWED_LLM_STRATEGIES else 0, key="global_strategy") + show_ensemble = strategy_val != "single" + majority_threshold_val = st.number_input( + "多数投票门槛", + min_value=1, + max_value=10, + value=int(global_cfg.majority_threshold), + step=1, + key="global_majority", + disabled=not show_ensemble, + ) + if not show_ensemble: + majority_threshold_val = 1 + + ensemble_rows: List[Dict[str, str]] = [] + if show_ensemble: + ensemble_rows = [ + { + "provider": ep.provider, + "model": ep.model or "", + "temperature": "" if ep.temperature is None else f"{ep.temperature:.3f}", + "timeout": "" if ep.timeout is None else str(int(ep.timeout)), + "prompt_template": ep.prompt_template or "", + } + for ep in global_cfg.ensemble + ] or [{"provider": primary.provider or selected_global_provider, "model": "", "temperature": "", "timeout": "", "prompt_template": ""}] + + ensemble_editor = st.data_editor( + ensemble_rows, + num_rows="dynamic", + key="global_ensemble_editor", + width='stretch', + hide_index=True, + column_config={ + "provider": st.column_config.SelectboxColumn("Provider", options=provider_keys), + "model": st.column_config.TextColumn("模型"), + "temperature": st.column_config.TextColumn("温度"), + "timeout": st.column_config.TextColumn("超时(秒)"), + "prompt_template": st.column_config.TextColumn("Prompt 模板"), + }, + ) + if hasattr(ensemble_editor, "to_dict"): + ensemble_rows = ensemble_editor.to_dict("records") + else: + ensemble_rows = ensemble_editor + else: + st.info("当前策略为单模型,未启用协作模型。") + + if st.button("保存全局配置", key="save_global_llm"): + primary.provider = selected_global_provider + primary.model = model_val or None + primary.temperature = float(temp_val) + primary.timeout = float(timeout_val) + primary.prompt_template = prompt_template_val.strip() or None + primary.base_url = None + primary.api_key = None + + new_ensemble: List[LLMEndpoint] = [] + if show_ensemble: + for row in ensemble_rows: + provider_val = (row.get("provider") or "").strip().lower() + if not provider_val: + continue + model_raw = (row.get("model") or "").strip() or None + temp_raw = (row.get("temperature") or "").strip() + timeout_raw = (row.get("timeout") or "").strip() + prompt_raw = (row.get("prompt_template") or "").strip() + new_ensemble.append( + LLMEndpoint( + provider=provider_val, + model=model_raw, + temperature=float(temp_raw) if temp_raw else None, + timeout=float(timeout_raw) if timeout_raw else None, + prompt_template=prompt_raw or None, + ) + ) + cfg.llm.ensemble = new_ensemble + cfg.llm.strategy = strategy_val + cfg.llm.majority_threshold = int(majority_threshold_val) + cfg.sync_runtime_llm() + save_config() + st.success("全局 LLM 配置已保存。") + st.json(llm_config_snapshot()) + + st.markdown("##### 部门配置") + dept_settings = cfg.departments or {} + dept_rows = [ + { + "code": code, + "title": dept.title, + "description": dept.description, + "weight": float(dept.weight), + "strategy": dept.llm.strategy, + "majority_threshold": dept.llm.majority_threshold, + "provider": dept.llm.primary.provider or (provider_keys[0] if provider_keys else ""), + "model": dept.llm.primary.model or "", + "temperature": "" if dept.llm.primary.temperature is None else f"{dept.llm.primary.temperature:.3f}", + "timeout": "" if dept.llm.primary.timeout is None else str(int(dept.llm.primary.timeout)), + "prompt_template": dept.llm.primary.prompt_template or "", + } + for code, dept in sorted(dept_settings.items()) + ] + + if not dept_rows: + st.info("当前未配置部门,可在 config.json 中添加。") + dept_rows = [] + + dept_editor = st.data_editor( + dept_rows, + num_rows="fixed", + key="department_editor", + width='stretch', + hide_index=True, + column_config={ + "code": st.column_config.TextColumn("编码", disabled=True), + "title": st.column_config.TextColumn("名称"), + "description": st.column_config.TextColumn("说明"), + "weight": st.column_config.NumberColumn("权重", min_value=0.0, max_value=10.0, step=0.1), + "strategy": st.column_config.SelectboxColumn("策略", options=sorted(ALLOWED_LLM_STRATEGIES)), + "majority_threshold": st.column_config.NumberColumn("投票阈值", min_value=1, max_value=10, step=1), + "provider": st.column_config.SelectboxColumn("Provider", options=provider_keys or [""]), + "model": st.column_config.TextColumn("模型"), + "temperature": st.column_config.TextColumn("温度"), + "timeout": st.column_config.TextColumn("超时(秒)"), + "prompt_template": st.column_config.TextColumn("Prompt 模板"), + }, + ) + + if hasattr(dept_editor, "to_dict"): + dept_rows = dept_editor.to_dict("records") + else: + dept_rows = dept_editor + + col_reset, col_save = st.columns([1, 1]) + + if col_save.button("保存部门配置"): + updated_departments: Dict[str, DepartmentSettings] = {} + for row in dept_rows: + code = row.get("code") + if not code: + continue + existing = dept_settings.get(code) or DepartmentSettings(code=code, title=code) + existing.title = row.get("title") or existing.title + existing.description = row.get("description") or "" + try: + existing.weight = max(0.0, float(row.get("weight", existing.weight))) + except (TypeError, ValueError): + pass + + strategy_val = (row.get("strategy") or existing.llm.strategy).lower() + if strategy_val in ALLOWED_LLM_STRATEGIES: + existing.llm.strategy = strategy_val + if existing.llm.strategy == "single": + existing.llm.majority_threshold = 1 + existing.llm.ensemble = [] + else: + majority_raw = row.get("majority_threshold") + try: + majority_val = int(majority_raw) + if majority_val > 0: + existing.llm.majority_threshold = majority_val + except (TypeError, ValueError): + pass + + provider_val = (row.get("provider") or existing.llm.primary.provider or (provider_keys[0] if provider_keys else "ollama")).strip().lower() + model_val = (row.get("model") or "").strip() or None + temp_raw = (row.get("temperature") or "").strip() + timeout_raw = (row.get("timeout") or "").strip() + prompt_raw = (row.get("prompt_template") or "").strip() + + endpoint = existing.llm.primary or LLMEndpoint() + endpoint.provider = provider_val + endpoint.model = model_val + endpoint.temperature = float(temp_raw) if temp_raw else None + endpoint.timeout = float(timeout_raw) if timeout_raw else None + endpoint.prompt_template = prompt_raw or None + endpoint.base_url = None + endpoint.api_key = None + existing.llm.primary = endpoint + if existing.llm.strategy != "single": + existing.llm.ensemble = [] + + updated_departments[code] = existing + + if updated_departments: + cfg.departments = updated_departments + cfg.sync_runtime_llm() + save_config() + st.success("部门配置已更新。") + else: + st.warning("未能解析部门配置输入。") + + if col_reset.button("恢复默认部门"): + from app.utils.config import _default_departments + + cfg.departments = _default_departments() + cfg.sync_runtime_llm() + save_config() + st.success("已恢复默认部门配置。") + st.rerun() + + st.caption("部门配置存储为独立 LLM 参数,执行时会自动套用对应 Provider 的连接信息。") + +def render_data_settings() -> None: + """渲染数据源配置界面.""" + st.subheader("Tushare 数据源") + cfg = get_config() + + col1, col2 = st.columns(2) + with col1: + tushare_token = st.text_input( + "Tushare Token", + value=cfg.tushare_token or "", + type="password", + help="从 tushare.pro 获取的 API token" + ) + + with col2: + auto_update = st.checkbox( + "启动时自动更新数据", + value=cfg.auto_update_data, + help="启动应用时自动检查并更新数据" + ) + + update_interval = st.slider( + "数据更新间隔(天)", + min_value=1, + max_value=30, + value=cfg.data_update_interval, + help="自动更新时检查的数据时间范围" + ) + + if st.button("保存数据源配置"): + cfg.tushare_token = tushare_token + cfg.auto_update_data = auto_update + cfg.data_update_interval = update_interval + save_config(cfg) + st.success("数据源配置已更新!") + + st.divider() + st.subheader("数据更新记录") + + with db_session() as session: + df = pd.read_sql_query( + """ + SELECT job_type, status, created_at, updated_at, error_msg + FROM fetch_jobs + ORDER BY created_at DESC + LIMIT 50 + """, + session + ) + + if not df.empty: + df["duration"] = (df["updated_at"] - df["created_at"]).dt.total_seconds().round(2) + df = df.drop(columns=["updated_at"]) + df = df.rename(columns={ + "job_type": "数据类型", + "status": "状态", + "created_at": "开始时间", + "error_msg": "错误信息", + "duration": "耗时(秒)" + }) + st.dataframe(df, width='stretch') + else: + st.info("暂无数据更新记录") diff --git a/app/ui/views/tests.py b/app/ui/views/tests.py new file mode 100644 index 0000000..f5e0253 --- /dev/null +++ b/app/ui/views/tests.py @@ -0,0 +1,194 @@ +"""自检测试视图。""" +from __future__ import annotations + +from datetime import date + +import streamlit as st + +from app.data.schema import initialize_database +from app.ingest.checker import run_boot_check +from app.ingest.tushare import FetchJob, run_ingestion +from app.llm.client import llm_config_snapshot, run_llm +from app.utils import alerts +from app.utils.config import get_config, save_config + +from app.ui.shared import LOGGER, LOG_EXTRA +from app.ui.views.dashboard import update_dashboard_sidebar + +def render_tests() -> None: + LOGGER.info("渲染自检页面", extra=LOG_EXTRA) + st.header("自检测试") + st.write("用于快速检查数据库与数据拉取是否正常工作。") + + if st.button("测试数据库初始化"): + LOGGER.info("点击测试数据库初始化按钮", extra=LOG_EXTRA) + with st.spinner("正在检查数据库..."): + result = initialize_database() + if result.skipped: + LOGGER.info("数据库已存在,无需初始化", extra=LOG_EXTRA) + st.success("数据库已存在,检查通过。") + else: + LOGGER.info("数据库初始化完成,执行语句数=%s", result.executed, extra=LOG_EXTRA) + st.success(f"数据库初始化完成,共执行 {result.executed} 条语句。") + + st.divider() + + if st.button("测试 TuShare 拉取(示例 2024-01-01 至 2024-01-03)"): + LOGGER.info("点击示例 TuShare 拉取按钮", extra=LOG_EXTRA) + with st.spinner("正在调用 TuShare 接口..."): + try: + run_ingestion( + FetchJob( + name="streamlit_self_test", + start=date(2024, 1, 1), + end=date(2024, 1, 3), + ts_codes=("000001.SZ",), + ), + include_limits=False, + ) + LOGGER.info("示例 TuShare 拉取成功", extra=LOG_EXTRA) + st.success("TuShare 示例拉取完成,数据已写入数据库。") + except Exception as exc: # noqa: BLE001 + LOGGER.exception("示例 TuShare 拉取失败", extra=LOG_EXTRA) + st.error(f"拉取失败:{exc}") + alerts.add_warning("TuShare", "示例拉取失败", str(exc)) + update_dashboard_sidebar() + + st.info("注意:TuShare 拉取依赖网络与 Token,若环境未配置将出现错误提示。") + + st.divider() + + st.subheader("RSS 数据测试") + st.write("用于验证 RSS 配置是否能够正常抓取新闻并写入数据库。") + rss_url = st.text_input( + "测试 RSS 地址", + value="https://rsshub.app/cls/depth/1000", + help="留空则使用默认配置的全部 RSS 来源。", + ).strip() + rss_hours = int( + st.number_input( + "回溯窗口(小时)", + min_value=1, + max_value=168, + value=24, + step=6, + help="仅抓取最近指定小时内的新闻。", + ) + ) + rss_limit = int( + st.number_input( + "单源抓取条数", + min_value=1, + max_value=200, + value=50, + step=10, + ) + ) + if st.button("运行 RSS 测试"): + from app.ingest import rss as rss_ingest + + LOGGER.info( + "点击 RSS 测试按钮 rss_url=%s hours=%s limit=%s", + rss_url, + rss_hours, + rss_limit, + extra=LOG_EXTRA, + ) + with st.spinner("正在抓取 RSS 新闻..."): + try: + if rss_url: + items = rss_ingest.fetch_rss_feed( + rss_url, + hours_back=rss_hours, + max_items=rss_limit, + ) + count = rss_ingest.save_news_items(items) + else: + count = rss_ingest.ingest_configured_rss( + hours_back=rss_hours, + max_items_per_feed=rss_limit, + ) + st.success(f"RSS 测试完成,新增 {count} 条新闻记录。") + except Exception as exc: # noqa: BLE001 + LOGGER.exception("RSS 测试失败", extra=LOG_EXTRA) + st.error(f"RSS 测试失败:{exc}") + alerts.add_warning("RSS", "RSS 测试执行失败", str(exc)) + update_dashboard_sidebar() + + st.divider() + days = int( + st.number_input( + "检查窗口(天数)", + min_value=30, + max_value=10950, + value=365, + step=30, + ) + ) + LOGGER.debug("检查窗口天数=%s", days, extra=LOG_EXTRA) + cfg = get_config() + force_refresh = st.checkbox( + "强制刷新数据(关闭增量跳过)", + value=cfg.force_refresh, + help="勾选后将重新拉取所选区间全部数据", + ) + if force_refresh != cfg.force_refresh: + cfg.force_refresh = force_refresh + LOGGER.info("更新 force_refresh=%s", force_refresh, extra=LOG_EXTRA) + save_config() + + if st.button("执行手动数据同步"): + LOGGER.info("点击执行手动数据同步按钮", extra=LOG_EXTRA) + progress_bar = st.progress(0.0) + status_placeholder = st.empty() + log_placeholder = st.empty() + messages: list[str] = [] + + def hook(message: str, value: float) -> None: + progress_bar.progress(min(max(value, 0.0), 1.0)) + status_placeholder.write(message) + messages.append(message) + LOGGER.debug("手动数据同步进度:%s -> %.2f", message, value, extra=LOG_EXTRA) + + with st.spinner("正在执行手动数据同步..."): + try: + report = run_boot_check( + days=days, + progress_hook=hook, + force_refresh=force_refresh, + ) + LOGGER.info("手动数据同步成功", extra=LOG_EXTRA) + st.success("手动数据同步完成,以下为数据覆盖摘要。") + st.json(report.to_dict()) + if messages: + log_placeholder.markdown("\n".join(f"- {msg}" for msg in messages)) + except Exception as exc: # noqa: BLE001 + LOGGER.exception("手动数据同步失败", extra=LOG_EXTRA) + st.error(f"手动数据同步失败:{exc}") + alerts.add_warning("数据同步", "手动数据同步失败", str(exc)) + update_dashboard_sidebar() + if messages: + log_placeholder.markdown("\n".join(f"- {msg}" for msg in messages)) + finally: + progress_bar.progress(1.0) + + st.divider() + st.subheader("LLM 接口测试") + st.json(llm_config_snapshot()) + llm_prompt = st.text_area("测试 Prompt", value="请概述今天的市场重点。", height=160) + system_prompt = st.text_area( + "System Prompt (可选)", + value="你是一名量化策略研究助手,用简洁中文回答。", + height=100, + ) + if st.button("执行 LLM 测试"): + with st.spinner("正在调用 LLM..."): + try: + response = run_llm(llm_prompt, system=system_prompt or None) + except Exception as exc: # noqa: BLE001 + LOGGER.exception("LLM 测试失败", extra=LOG_EXTRA) + st.error(f"LLM 调用失败:{exc}") + else: + LOGGER.info("LLM 测试成功", extra=LOG_EXTRA) + st.success("LLM 调用成功,以下为返回内容:") + st.write(response) diff --git a/app/ui/views/today.py b/app/ui/views/today.py new file mode 100644 index 0000000..808a5a8 --- /dev/null +++ b/app/ui/views/today.py @@ -0,0 +1,677 @@ +"""今日计划页面视图。""" +from __future__ import annotations + +import json +from collections import Counter +from datetime import date, datetime, timedelta +from typing import Dict, List, Optional + +import numpy as np +import pandas as pd +import streamlit as st + +from app.backtest.engine import BacktestEngine, PortfolioState, BtConfig +from app.utils.portfolio import list_investment_pool +from app.utils.db import db_session + +from app.ui.shared import ( + LOGGER, + LOG_EXTRA, + get_latest_trade_date, + get_query_params, + set_query_params, +) + + +def render_today_plan() -> None: + LOGGER.info("渲染今日计划页面", extra=LOG_EXTRA) + st.header("今日计划") + latest_trade_date = get_latest_trade_date() + if latest_trade_date: + st.caption(f"最新交易日:{latest_trade_date.isoformat()}(统计数据请见左侧系统监控)") + else: + st.caption("统计与决策概览现已移至左侧'系统监控'侧栏。") + try: + with db_session(read_only=True) as conn: + date_rows = conn.execute( + """ + SELECT DISTINCT trade_date + FROM agent_utils + ORDER BY trade_date DESC + LIMIT 30 + """ + ).fetchall() + except Exception: # noqa: BLE001 + LOGGER.exception("加载 agent_utils 失败", extra=LOG_EXTRA) + st.warning("暂未写入部门/代理决策,请先运行回测或策略评估流程。") + return + + trade_dates = [row["trade_date"] for row in date_rows] + if not trade_dates: + st.info("暂无决策记录,完成一次回测后即可在此查看部门意见与投票结果。") + return + + query = get_query_params() + default_trade_date = query.get("date", [trade_dates[0]])[0] + try: + default_idx = trade_dates.index(default_trade_date) + except ValueError: + default_idx = 0 + trade_date = st.selectbox("交易日", trade_dates, index=default_idx) + + with db_session(read_only=True) as conn: + code_rows = conn.execute( + """ + SELECT DISTINCT ts_code + FROM agent_utils + WHERE trade_date = ? + ORDER BY ts_code + """, + (trade_date,), + ).fetchall() + symbols = [row["ts_code"] for row in code_rows] + + detail_tab, assistant_tab = st.tabs(["标的详情", "投资助理模式"]) + with assistant_tab: + _render_today_plan_assistant_view(trade_date) + + with detail_tab: + if not symbols: + st.info("所选交易日暂无 agent_utils 记录。") + else: + _render_today_plan_symbol_view(trade_date, symbols, query) + + +def _render_today_plan_assistant_view(trade_date: str | int | date) -> None: + st.info("已开启投资助理模式:以下内容为组合级(去标的)建议,不包含任何具体标的代码。") + try: + candidates = list_investment_pool(trade_date=trade_date) + if candidates: + scores = [float(item.score or 0.0) for item in candidates] + statuses = [item.status or "UNKNOWN" for item in candidates] + tags: List[str] = [] + rationales: List[str] = [] + for item in candidates: + if getattr(item, "tags", None): + tags.extend(item.tags) + if getattr(item, "rationale", None): + rationales.append(str(item.rationale)) + cnt = Counter(statuses) + tag_cnt = Counter(tags) + st.subheader("候选池聚合概览(已匿名化)") + col_a, col_b, col_c = st.columns(3) + col_a.metric("候选数", f"{len(candidates)}") + col_b.metric("平均评分", f"{np.mean(scores):.3f}" if scores else "-") + col_c.metric("中位评分", f"{np.median(scores):.3f}" if scores else "-") + + st.write("状态分布:") + st.json(dict(cnt)) + + if tag_cnt: + st.write("常见标签(示例):") + st.json(dict(tag_cnt.most_common(10))) + + if rationales: + st.write("汇总理由(节选,不含代码):") + seen = set() + excerpts = [] + for rationale in rationales: + text = rationale.strip() + if text and text not in seen: + seen.add(text) + excerpts.append(text) + if len(excerpts) >= 3: + break + for idx, excerpt in enumerate(excerpts, start=1): + st.markdown(f"**理由 {idx}:** {excerpt}") + + avg_score = float(np.mean(scores)) if scores else 0.0 + suggest_pct = max(0.0, min(0.3, 0.10 + (avg_score - 0.5) * 0.2)) + st.subheader("组合级建议(不指定标的)") + st.write( + f"基于候选池平均评分 {avg_score:.3f},建议今日用于新增买入的现金比例约为 {suggest_pct:.0%}。" + ) + st.write( + "建议分配思路:在候选池中挑选若干得分较高的标的按目标权重等比例分配,或以分批买入的方式分摊入场时点。" + ) + if st.button("生成组合级操作建议(仅输出,不执行)"): + st.success("已生成组合级建议(仅供参考)。") + st.write({ + "候选数": len(candidates), + "平均评分": avg_score, + "建议新增买入比例": f"{suggest_pct:.0%}", + }) + else: + st.info("所选交易日暂无候选投资池数据。") + except Exception: # noqa: BLE001 + LOGGER.exception("加载候选池聚合信息失败", extra=LOG_EXTRA) + st.error("加载候选池数据时发生错误。") + + +def _render_today_plan_symbol_view( + trade_date: str | int | date, + symbols: List[str], + query_params: Dict[str, List[str]], +) -> None: + default_ts = query_params.get("code", [symbols[0]])[0] + try: + default_ts_idx = symbols.index(default_ts) + except ValueError: + default_ts_idx = 0 + ts_code = st.selectbox("标的", symbols, index=default_ts_idx) + batch_symbols = st.multiselect("批量重评估(可多选)", symbols, default=[]) + + if st.button("一键重评估所有标的", type="primary", width='stretch'): + with st.spinner("正在对所有标的进行重评估,请稍候..."): + try: + trade_date_obj: Optional[date] = None + try: + trade_date_obj = date.fromisoformat(str(trade_date)) + except Exception: + try: + trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date() + except Exception: + pass + if trade_date_obj is None: + raise ValueError(f"无法解析交易日:{trade_date}") + + progress = st.progress(0.0) + changes_all: List[Dict[str, object]] = [] + success_count = 0 + error_count = 0 + + for idx, code in enumerate(symbols, start=1): + try: + with db_session(read_only=True) as conn: + before_rows = conn.execute( + "SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?", + (trade_date, code), + ).fetchall() + before_map = {row["agent"]: row["action"] for row in before_rows} + + cfg = BtConfig( + id="reeval_ui_all", + name="UI All Re-eval", + start_date=trade_date_obj, + end_date=trade_date_obj, + universe=[code], + params={}, + ) + engine = BacktestEngine(cfg) + state = PortfolioState() + engine.simulate_day(trade_date_obj, state) + + with db_session(read_only=True) as conn: + after_rows = conn.execute( + "SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?", + (trade_date, code), + ).fetchall() + for row in after_rows: + agent = row["agent"] + new_action = row["action"] + old_action = before_map.get(agent) + if new_action != old_action: + changes_all.append( + {"代码": code, "代理": agent, "原动作": old_action, "新动作": new_action} + ) + success_count += 1 + except Exception: # noqa: BLE001 + LOGGER.exception("重评估 %s 失败", code, extra=LOG_EXTRA) + error_count += 1 + + progress.progress(idx / len(symbols)) + + if error_count > 0: + st.error(f"一键重评估完成:成功 {success_count} 个,失败 {error_count} 个") + else: + st.success(f"一键重评估完成:所有 {success_count} 个标的重评估成功") + + if changes_all: + st.write("检测到以下动作变更:") + st.dataframe(pd.DataFrame(changes_all), hide_index=True, width='stretch') + + st.rerun() + except Exception as exc: # noqa: BLE001 + LOGGER.exception("一键重评估失败", extra=LOG_EXTRA) + st.error(f"一键重评估执行过程中发生错误:{exc}") + + set_query_params(date=str(trade_date), code=str(ts_code)) + + with db_session(read_only=True) as conn: + rows = conn.execute( + """ + SELECT agent, action, utils, feasible, weight + FROM agent_utils + WHERE trade_date = ? AND ts_code = ? + ORDER BY CASE WHEN agent = 'global' THEN 1 ELSE 0 END, agent + """, + (trade_date, ts_code), + ).fetchall() + + if not rows: + st.info("未查询到详细决策记录,稍后再试。") + return + + try: + feasible_actions = json.loads(rows[0]["feasible"] or "[]") + except (KeyError, TypeError, json.JSONDecodeError): + feasible_actions = [] + + global_info = None + dept_records: List[Dict[str, object]] = [] + dept_details: Dict[str, Dict[str, object]] = {} + agent_records: List[Dict[str, object]] = [] + + for item in rows: + agent_name = item["agent"] + action = item["action"] + weight = float(item["weight"] or 0.0) + try: + utils = json.loads(item["utils"] or "{}") + except json.JSONDecodeError: + utils = {} + + if agent_name == "global": + global_info = { + "action": action, + "confidence": float(utils.get("_confidence", 0.0)), + "target_weight": float(utils.get("_target_weight", 0.0)), + "department_votes": utils.get("_department_votes", {}), + "requires_review": bool(utils.get("_requires_review", False)), + "scope_values": utils.get("_scope_values", {}), + "close_series": utils.get("_close_series", []), + "turnover_series": utils.get("_turnover_series", []), + "department_supplements": utils.get("_department_supplements", {}), + "department_dialogue": utils.get("_department_dialogue", {}), + "department_telemetry": utils.get("_department_telemetry", {}), + } + continue + + if agent_name.startswith("dept_"): + code = agent_name.split("dept_", 1)[-1] + signals = utils.get("_signals", []) + risks = utils.get("_risks", []) + supplements = utils.get("_supplements", []) + dialogue = utils.get("_dialogue", []) + telemetry = utils.get("_telemetry", {}) + dept_records.append( + { + "部门": code, + "行动": action, + "信心": float(utils.get("_confidence", 0.0)), + "权重": weight, + "摘要": utils.get("_summary", ""), + "核心信号": ";".join(signals) if isinstance(signals, list) else signals, + "风险提示": ";".join(risks) if isinstance(risks, list) else risks, + "补充次数": len(supplements) if isinstance(supplements, list) else 0, + } + ) + dept_details[code] = { + "supplements": supplements if isinstance(supplements, list) else [], + "dialogue": dialogue if isinstance(dialogue, list) else [], + "summary": utils.get("_summary", ""), + "signals": signals, + "risks": risks, + "telemetry": telemetry if isinstance(telemetry, dict) else {}, + } + else: + score_map = { + key: float(val) + for key, val in utils.items() + if not str(key).startswith("_") + } + agent_records.append( + { + "代理": agent_name, + "建议动作": action, + "权重": weight, + "SELL": score_map.get("SELL", 0.0), + "HOLD": score_map.get("HOLD", 0.0), + "BUY_S": score_map.get("BUY_S", 0.0), + "BUY_M": score_map.get("BUY_M", 0.0), + "BUY_L": score_map.get("BUY_L", 0.0), + } + ) + + if feasible_actions: + st.caption(f"可行操作集合:{', '.join(feasible_actions)}") + + st.subheader("全局策略") + if global_info: + col1, col2, col3 = st.columns(3) + col1.metric("最终行动", global_info["action"]) + col2.metric("信心", f"{global_info['confidence']:.2f}") + col3.metric("目标权重", f"{global_info['target_weight']:+.2%}") + if global_info["department_votes"]: + st.json(global_info["department_votes"]) + if global_info["requires_review"]: + st.warning("部门分歧较大,已标记为需人工复核。") + with st.expander("基础上下文数据", expanded=False): + scope = global_info.get("scope_values") or {} + close_series = global_info.get("close_series") or [] + turnover_series = global_info.get("turnover_series") or [] + st.write("最新字段:") + if scope: + st.json(scope) + st.download_button( + "下载字段(JSON)", + data=json.dumps(scope, ensure_ascii=False, indent=2), + file_name=f"{ts_code}_{trade_date}_scope.json", + mime="application/json", + key="dl_scope_json", + ) + if close_series: + st.write("收盘价时间序列 (最近窗口):") + st.json(close_series) + try: + import csv + import io + + buf = io.StringIO() + writer = csv.writer(buf) + writer.writerow(["trade_date", "close"]) + for dt_val, val in close_series: + writer.writerow([dt_val, val]) + st.download_button( + "下载收盘价(CSV)", + data=buf.getvalue(), + file_name=f"{ts_code}_{trade_date}_close_series.csv", + mime="text/csv", + key="dl_close_csv", + ) + except Exception: # noqa: BLE001 + pass + if turnover_series: + st.write("换手率时间序列 (最近窗口):") + st.json(turnover_series) + dept_sup = global_info.get("department_supplements") or {} + dept_dialogue = global_info.get("department_dialogue") or {} + dept_telemetry = global_info.get("department_telemetry") or {} + if dept_sup or dept_dialogue: + with st.expander("部门补数与对话记录", expanded=False): + if dept_sup: + st.write("补充数据:") + st.json(dept_sup) + if dept_dialogue: + st.write("对话片段:") + st.json(dept_dialogue) + if dept_telemetry: + with st.expander("部门 LLM 元数据", expanded=False): + st.json(dept_telemetry) + else: + st.info("暂未写入全局策略摘要。") + + st.subheader("部门意见") + if dept_records: + keyword = st.text_input("筛选摘要/信号关键词", value="") + filtered = dept_records + if keyword.strip(): + kw = keyword.strip() + filtered = [ + item + for item in dept_records + if kw in str(item.get("摘要", "")) or kw in str(item.get("核心信号", "")) + ] + min_conf = st.slider("最低信心过滤", 0.0, 1.0, 0.0, 0.05) + sort_col = st.selectbox("排序列", ["信心", "权重"], index=0) + filtered = [row for row in filtered if float(row.get("信心", 0.0)) >= min_conf] + filtered = sorted(filtered, key=lambda r: float(r.get(sort_col, 0.0)), reverse=True) + dept_df = pd.DataFrame(filtered) + st.dataframe(dept_df, width='stretch', hide_index=True) + try: + st.download_button( + "下载部门意见(CSV)", + data=dept_df.to_csv(index=False), + file_name=f"{trade_date}_{ts_code}_departments.csv", + mime="text/csv", + key="dl_dept_csv", + ) + except Exception: # noqa: BLE001 + pass + for code, details in dept_details.items(): + with st.expander(f"{code} 补充详情", expanded=False): + supplements = details.get("supplements", []) + dialogue = details.get("dialogue", []) + if supplements: + st.write("补充数据:") + st.json(supplements) + else: + st.caption("无补充数据请求。") + if dialogue: + st.write("对话记录:") + for idx, line in enumerate(dialogue, start=1): + st.markdown(f"**回合 {idx}:** {line}") + else: + st.caption("无额外对话。") + telemetry = details.get("telemetry") or {} + if telemetry: + st.write("LLM 元数据:") + st.json(telemetry) + else: + st.info("暂无部门记录。") + + st.subheader("代理评分") + if agent_records: + sort_agent_by = st.selectbox( + "代理排序", + ["权重", "SELL", "HOLD", "BUY_S", "BUY_M", "BUY_L"], + index=1, + ) + agent_df = pd.DataFrame(agent_records) + if sort_agent_by in agent_df.columns: + agent_df = agent_df.sort_values(sort_agent_by, ascending=False) + st.dataframe(agent_df, width='stretch', hide_index=True) + try: + st.download_button( + "下载代理评分(CSV)", + data=agent_df.to_csv(index=False), + file_name=f"{trade_date}_{ts_code}_agents.csv", + mime="text/csv", + key="dl_agent_csv", + ) + except Exception: # noqa: BLE001 + pass + else: + st.info("暂无基础代理评分。") + + st.divider() + st.subheader("相关新闻") + try: + with db_session(read_only=True) as conn: + try: + trade_date_obj = date.fromisoformat(str(trade_date)) + except Exception: + try: + trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date() + except Exception: + trade_date_obj = date.today() - timedelta(days=7) + + news_query = """ + SELECT id, title, source, pub_time, sentiment, heat, entities + FROM news + WHERE ts_code = ? AND pub_time >= ? + ORDER BY pub_time DESC + LIMIT 10 + """ + seven_days_ago = (trade_date_obj - timedelta(days=7)).strftime("%Y-%m-%d") + news_rows = conn.execute(news_query, (ts_code, seven_days_ago)).fetchall() + + if news_rows: + news_data = [] + for row in news_rows: + entities_info = {} + try: + if row["entities"]: + entities_info = json.loads(row["entities"]) + except (json.JSONDecodeError, TypeError): + pass + + news_item = { + "标题": row["title"], + "来源": row["source"], + "发布时间": row["pub_time"], + "情感指数": f"{row['sentiment']:.2f}" if row["sentiment"] is not None else "-", + "热度评分": f"{row['heat']:.2f}" if row["heat"] is not None else "-", + } + + industries = entities_info.get("industries", []) + if industries: + news_item["相关行业"] = "、".join(industries[:3]) + + news_data.append(news_item) + + news_df = pd.DataFrame(news_data) + for col in news_df.columns: + news_df[col] = news_df[col].astype(str) + st.dataframe(news_df, width='stretch', hide_index=True) + + st.write("详细新闻内容:") + for idx, row in enumerate(news_rows): + with st.expander(f"{idx+1}. {row['title']}", expanded=False): + st.write(f"**来源:** {row['source']}") + st.write(f"**发布时间:** {row['pub_time']}") + + entities_info = {} + try: + if row["entities"]: + entities_info = json.loads(row["entities"]) + except (json.JSONDecodeError, TypeError): + pass + + sentiment_display = f"{row['sentiment']:.2f}" if row["sentiment"] is not None else "-" + heat_display = f"{row['heat']:.2f}" if row["heat"] is not None else "-" + st.write(f"**情感指数:** {sentiment_display} | **热度评分:** {heat_display}") + + industries = entities_info.get("industries", []) + if industries: + st.write(f"**相关行业:** {'、'.join(industries)}") + + important_keywords = entities_info.get("important_keywords", []) + if important_keywords: + st.write(f"**重要关键词:** {'、'.join(important_keywords)}") + + url = entities_info.get("source_url", "") + if url: + st.markdown(f"[查看原文]({url})", unsafe_allow_html=True) + else: + st.info(f"近7天内暂无关于 {ts_code} 的新闻。") + except Exception as exc: # noqa: BLE001 + LOGGER.exception("获取新闻数据失败", extra=LOG_EXTRA) + st.error(f"获取新闻数据时发生错误:{exc}") + + st.divider() + st.info("投资池与仓位概览已移至单独页面。请在侧边或页面导航中选择“投资池/仓位”以查看详细信息。") + + st.divider() + st.subheader("策略重评估") + st.caption("对当前选中的交易日与标的,立即触发一次策略评估并回写 agent_utils。") + cols_re = st.columns([1, 1]) + if cols_re[0].button("对该标的重评估", key="reevaluate_current_symbol"): + with st.spinner("正在重评估..."): + try: + trade_date_obj: Optional[date] = None + try: + trade_date_obj = date.fromisoformat(str(trade_date)) + except Exception: + try: + trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date() + except Exception: + pass + if trade_date_obj is None: + raise ValueError(f"无法解析交易日:{trade_date}") + with db_session(read_only=True) as conn: + before_rows = conn.execute( + """ + SELECT agent, action, utils FROM agent_utils + WHERE trade_date = ? AND ts_code = ? + """, + (trade_date, ts_code), + ).fetchall() + before_map = {row["agent"]: (row["action"], row["utils"]) for row in before_rows} + cfg = BtConfig( + id="reeval_ui", + name="UI Re-evaluation", + start_date=trade_date_obj, + end_date=trade_date_obj, + universe=[ts_code], + params={}, + ) + engine = BacktestEngine(cfg) + state = PortfolioState() + engine.simulate_day(trade_date_obj, state) + with db_session(read_only=True) as conn: + after_rows = conn.execute( + """ + SELECT agent, action, utils FROM agent_utils + WHERE trade_date = ? AND ts_code = ? + """, + (trade_date, ts_code), + ).fetchall() + changes = [] + for row in after_rows: + agent = row["agent"] + new_action = row["action"] + old_action, _old_utils = before_map.get(agent, (None, None)) + if new_action != old_action: + changes.append({"代理": agent, "原动作": old_action, "新动作": new_action}) + if changes: + st.success("重评估完成,检测到动作变更:") + st.dataframe(pd.DataFrame(changes), hide_index=True, width='stretch') + else: + st.success("重评估完成,无动作变更。") + st.rerun() + except Exception as exc: # noqa: BLE001 + LOGGER.exception("重评估失败", extra=LOG_EXTRA) + st.error(f"重评估失败:{exc}") + if cols_re[1].button("批量重评估(所选)", key="reevaluate_batch", disabled=not batch_symbols): + with st.spinner("批量重评估执行中..."): + try: + trade_date_obj: Optional[date] = None + try: + trade_date_obj = date.fromisoformat(str(trade_date)) + except Exception: + try: + trade_date_obj = datetime.strptime(str(trade_date), "%Y%m%d").date() + except Exception: + pass + if trade_date_obj is None: + raise ValueError(f"无法解析交易日:{trade_date}") + progress = st.progress(0.0) + changes_all: List[Dict[str, object]] = [] + for idx, code in enumerate(batch_symbols, start=1): + with db_session(read_only=True) as conn: + before_rows = conn.execute( + "SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?", + (trade_date, code), + ).fetchall() + before_map = {row["agent"]: row["action"] for row in before_rows} + cfg = BtConfig( + id="reeval_ui_batch", + name="UI Batch Re-eval", + start_date=trade_date_obj, + end_date=trade_date_obj, + universe=[code], + params={}, + ) + engine = BacktestEngine(cfg) + state = PortfolioState() + engine.simulate_day(trade_date_obj, state) + with db_session(read_only=True) as conn: + after_rows = conn.execute( + "SELECT agent, action FROM agent_utils WHERE trade_date = ? AND ts_code = ?", + (trade_date, code), + ).fetchall() + for row in after_rows: + agent = row["agent"] + new_action = row["action"] + old_action = before_map.get(agent) + if new_action != old_action: + changes_all.append({"代码": code, "代理": agent, "原动作": old_action, "新动作": new_action}) + progress.progress(idx / max(1, len(batch_symbols))) + st.success("批量重评估完成。") + if changes_all: + st.dataframe(pd.DataFrame(changes_all), hide_index=True, width='stretch') + st.rerun() + except Exception as exc: # noqa: BLE001 + LOGGER.exception("批量重评估失败", extra=LOG_EXTRA) + st.error(f"批量重评估失败:{exc}")