From 35632203857507dd20c193106ef7c8caefed9479 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 11 Oct 2025 20:40:06 +0800 Subject: [PATCH] add stock name, industry and timestamp to investment pool schema --- app/backtest/engine.py | 45 +++++++++++++++++++++- app/ui/views/backtest.py | 14 ++++++- app/ui/views/stock_eval.py | 40 +++++++++++++++++++ app/ui/views/today.py | 40 ++++++++++++++++--- app/utils/portfolio.py | 79 +++++++++++++++++++++++++------------- 5 files changed, 183 insertions(+), 35 deletions(-) diff --git a/app/backtest/engine.py b/app/backtest/engine.py index 36ea420..8a08188 100644 --- a/app/backtest/engine.py +++ b/app/backtest/engine.py @@ -2,6 +2,7 @@ from __future__ import annotations import json +import sqlite3 from collections import defaultdict from dataclasses import dataclass, field from datetime import date, datetime @@ -970,12 +971,17 @@ class BacktestEngine: for code, dept in decision.department_decisions.items() } + stock_info = self.data_broker.get_stock_info(context.ts_code, context.trade_date) + name = stock_info.get("name") if stock_info else None + industry = stock_info.get("industry") if stock_info else None + with db_session() as conn: + self._ensure_investment_pool_columns(conn) conn.execute( """ INSERT OR REPLACE INTO investment_pool - (trade_date, ts_code, score, status, rationale, tags, metadata) - VALUES (?, ?, ?, ?, ?, ?, ?) + (trade_date, ts_code, score, status, rationale, tags, metadata, name, industry) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( context.trade_date, @@ -985,9 +991,44 @@ class BacktestEngine: summary or None, json.dumps(_department_tags(decision), ensure_ascii=False), json.dumps(metadata, ensure_ascii=False), + name, + industry, ), ) + @staticmethod + def _ensure_investment_pool_columns(conn: sqlite3.Connection) -> None: + try: + info = conn.execute("PRAGMA table_info(investment_pool)").fetchall() + except sqlite3.Error: + return + + columns = { + (row[1] if not isinstance(row, sqlite3.Row) else row["name"]) + for row in info + if row is not None + } + if "name" not in columns: + try: + conn.execute("ALTER TABLE investment_pool ADD COLUMN name TEXT") + except sqlite3.Error: + pass + if "industry" not in columns: + try: + conn.execute("ALTER TABLE investment_pool ADD COLUMN industry TEXT") + except sqlite3.Error: + pass + if "created_at" not in columns: + try: + conn.execute( + "ALTER TABLE investment_pool ADD COLUMN created_at TEXT DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now'))" + ) + except sqlite3.Error: + try: + conn.execute("ALTER TABLE investment_pool ADD COLUMN created_at TEXT") + except sqlite3.Error: + pass + def _persist_portfolio( self, trade_date: str, diff --git a/app/ui/views/backtest.py b/app/ui/views/backtest.py index 9c38c42..da10714 100644 --- a/app/ui/views/backtest.py +++ b/app/ui/views/backtest.py @@ -26,6 +26,7 @@ from app.llm.templates import TemplateRegistry from app.utils import alerts from app.utils.config import get_config, save_config from app.utils.tuning import log_tuning_result +from app.utils.portfolio import list_investment_pool from app.utils.db import db_session @@ -59,7 +60,18 @@ def render_backtest_review() -> None: col1, col2 = st.columns(2) start_date = col1.date_input("开始日期", value=default_start, key="bt_start_date") end_date = col2.date_input("结束日期", value=default_end, key="bt_end_date") - universe_text = st.text_input("股票列表(逗号分隔)", value="000001.SZ", key="bt_universe") + + latest_candidates = list_investment_pool(limit=50) + candidate_codes = [item.ts_code for item in latest_candidates] + default_universe = ",".join(candidate_codes) if candidate_codes else "000001.SZ" + universe_text = st.text_input( + "股票列表(逗号分隔)", + value=default_universe, + key="bt_universe", + help="默认载入最新候选池,如需自定义可直接编辑。", + ) + if candidate_codes: + st.caption(f"候选池载入 {len(candidate_codes)} 个标的:{'、'.join(candidate_codes[:10])}{'…' if len(candidate_codes)>10 else ''}") col_target, col_stop, col_hold, col_cap = st.columns(4) target = col_target.number_input("目标收益(例:0.035 表示 3.5%)", value=0.035, step=0.005, format="%.3f", key="bt_target") stop = col_stop.number_input("止损收益(例:-0.015 表示 -1.5%)", value=-0.015, step=0.005, format="%.3f", key="bt_stop") diff --git a/app/ui/views/stock_eval.py b/app/ui/views/stock_eval.py index 3c6651c..0c6b3e9 100644 --- a/app/ui/views/stock_eval.py +++ b/app/ui/views/stock_eval.py @@ -2,6 +2,7 @@ from datetime import date, datetime, timedelta from typing import Dict, List, Optional import json +import sqlite3 import numpy as np import pandas as pd @@ -15,6 +16,44 @@ from app.utils.data_access import DataBroker from app.utils.db import db_session from app.utils.logging import get_logger +LOGGER = get_logger(__name__) +LOG_EXTRA = {"stage": "stock_eval"} + + +def _ensure_investment_pool_schema(conn: sqlite3.Connection) -> None: + """Ensure investment_pool table has latest optional columns.""" + try: + info = conn.execute("PRAGMA table_info(investment_pool)").fetchall() + except sqlite3.Error: + return + + columns = { + (row["name"] if isinstance(row, sqlite3.Row) else row[1]) + for row in info + if row is not None + } + + if "name" not in columns: + try: + conn.execute("ALTER TABLE investment_pool ADD COLUMN name TEXT") + except sqlite3.Error: + pass + if "industry" not in columns: + try: + conn.execute("ALTER TABLE investment_pool ADD COLUMN industry TEXT") + except sqlite3.Error: + pass + if "created_at" not in columns: + try: + conn.execute( + "ALTER TABLE investment_pool ADD COLUMN created_at TEXT DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now'))" + ) + except sqlite3.Error: + try: + conn.execute("ALTER TABLE investment_pool ADD COLUMN created_at TEXT") + except sqlite3.Error: + pass + def _get_latest_trading_date() -> date: """获取数据库中的最新交易日期""" @@ -510,6 +549,7 @@ def _add_to_stock_pool( ) with db_session() as conn: + _ensure_investment_pool_schema(conn) conn.execute("DELETE FROM investment_pool WHERE trade_date = ?", (trade_date,)) if payload: conn.executemany( diff --git a/app/ui/views/today.py b/app/ui/views/today.py index 5b4356c..f5f00dc 100644 --- a/app/ui/views/today.py +++ b/app/ui/views/today.py @@ -11,7 +11,7 @@ import pandas as pd import streamlit as st from app.backtest.engine import BacktestEngine, PortfolioState, BtConfig -from app.utils.portfolio import list_investment_pool +from app.utils.portfolio import InvestmentCandidate, list_investment_pool from app.utils.db import db_session from app.ui.shared import ( @@ -175,24 +175,39 @@ def render_today_plan() -> None: ).fetchall() symbols = [row["ts_code"] for row in code_rows] + candidate_records = list_investment_pool(trade_date=trade_date) + if candidate_records: + st.caption( + f"候选池包含 {len(candidate_records)} 个标的:" + + "、".join(item.ts_code for item in candidate_records[:12]) + + ("…" if len(candidate_records) > 12 else "") + ) + + if candidate_records: + candidate_codes = [item.ts_code for item in candidate_records] + symbols = list(dict.fromkeys(candidate_codes + symbols)) + detail_tab, assistant_tab = st.tabs(["标的详情", "投资助理模式"]) with assistant_tab: - _render_today_plan_assistant_view(trade_date) + _render_today_plan_assistant_view(trade_date, candidate_records) with detail_tab: if not symbols: st.info("所选交易日暂无 agent_utils 记录。") else: - _render_today_plan_symbol_view(trade_date, symbols, query) + _render_today_plan_symbol_view(trade_date, symbols, query, candidate_records) -def _render_today_plan_assistant_view(trade_date: str | int | date) -> None: +def _render_today_plan_assistant_view( + trade_date: str | int | date, + candidate_records: List[InvestmentCandidate], +) -> None: # 确保日期格式为字符串 if isinstance(trade_date, date): trade_date = trade_date.strftime("%Y%m%d") st.info("已开启投资助理模式:以下内容为组合级(去标的)建议,不包含任何具体标的代码。") try: - candidates = list_investment_pool(trade_date=trade_date) + candidates = candidate_records or list_investment_pool(trade_date=trade_date) if candidates: scores = [float(item.score or 0.0) for item in candidates] statuses = [item.status or "UNKNOWN" for item in candidates] @@ -259,6 +274,7 @@ def _render_today_plan_symbol_view( trade_date: str | int | date, symbols: List[str], query_params: Dict[str, List[str]], + candidate_records: List[InvestmentCandidate], ) -> None: default_ts = query_params.get("code", [symbols[0]])[0] try: @@ -266,7 +282,9 @@ def _render_today_plan_symbol_view( except ValueError: default_ts_idx = 0 ts_code = st.selectbox("标的", symbols, index=default_ts_idx) - batch_symbols = st.multiselect("批量重评估(可多选)", symbols, default=[]) + candidate_code_set = {item.ts_code for item in candidate_records} + default_batch = [code for code in symbols if code in candidate_code_set] + batch_symbols = st.multiselect("批量重评估(可多选)", symbols, default=default_batch[:10]) if st.button("一键重评估所有标的", type="primary", width='stretch'): with st.spinner("正在对所有标的进行重评估,请稍候..."): @@ -318,6 +336,16 @@ def _render_today_plan_symbol_view( st.info("未查询到详细决策记录,稍后再试。") return + candidate_map = {item.ts_code: item for item in candidate_records} + candidate_info = candidate_map.get(ts_code) + if candidate_info: + info_cols = st.columns(3) + info_cols[0].metric("候选评分", f"{(candidate_info.score or 0):.3f}") + info_cols[1].metric("状态", candidate_info.status or "-") + info_cols[2].metric("更新时间", candidate_info.created_at or "-") + if candidate_info.rationale: + st.caption(f"候选理由:{candidate_info.rationale}") + try: feasible_actions = json.loads(rows[0]["feasible"] or "[]") except (KeyError, TypeError, json.JSONDecodeError): diff --git a/app/utils/portfolio.py b/app/utils/portfolio.py index 3748b25..021b31a 100644 --- a/app/utils/portfolio.py +++ b/app/utils/portfolio.py @@ -34,6 +34,7 @@ class InvestmentCandidate: metadata: Dict[str, Any] name: Optional[str] = None industry: Optional[str] = None + created_at: Optional[str] = None def list_investment_pool( @@ -44,31 +45,55 @@ def list_investment_pool( ) -> List[InvestmentCandidate]: """Return investment candidates for the given trade date (latest if None).""" - query = [ - "SELECT trade_date, ts_code, score, status, rationale, tags, metadata, name, industry", - "FROM investment_pool", - ] - params: List[Any] = [] - - if trade_date: - query.append("WHERE trade_date = ?") - params.append(trade_date) - else: - query.append( - "WHERE trade_date = (SELECT MAX(trade_date) FROM investment_pool)" - ) - - if status: - placeholders = ", ".join("?" for _ in status) - query.append(f"AND status IN ({placeholders})") - params.extend(list(status)) - - query.append("ORDER BY (score IS NULL), score DESC, ts_code") - query.append("LIMIT ?") - params.append(int(limit)) - - sql = "\n".join(query) with db_session(read_only=True) as conn: + try: + info = conn.execute("PRAGMA table_info(investment_pool)").fetchall() + except Exception: # noqa: BLE001 + LOGGER.exception("无法读取 investment_pool 结构", extra=LOG_EXTRA) + return [] + + available_columns = { + (row[1] if not isinstance(row, dict) else row.get("name")) + for row in info + } + + select_columns = [ + "trade_date", + "ts_code", + "score", + "status", + "rationale", + "tags", + "metadata", + ] + optional_columns = [col for col in ("name", "industry", "created_at") if col in available_columns] + select_columns.extend(optional_columns) + + column_clause = ", ".join(select_columns) + query = [ + f"SELECT {column_clause}", + "FROM investment_pool", + ] + params: List[Any] = [] + + if trade_date: + query.append("WHERE trade_date = ?") + params.append(trade_date) + else: + query.append( + "WHERE trade_date = (SELECT MAX(trade_date) FROM investment_pool)" + ) + + if status: + placeholders = ", ".join("?" for _ in status) + query.append(f"AND status IN ({placeholders})") + params.extend(list(status)) + + query.append("ORDER BY (score IS NULL), score DESC, ts_code") + query.append("LIMIT ?") + params.append(int(limit)) + + sql = "\n".join(query) try: rows = conn.execute(sql, params).fetchall() except Exception: # noqa: BLE001 @@ -77,6 +102,7 @@ def list_investment_pool( candidates: List[InvestmentCandidate] = [] for row in rows: + row_keys = set(row.keys()) if hasattr(row, "keys") else set() candidates.append( InvestmentCandidate( trade_date=row["trade_date"], @@ -86,8 +112,9 @@ def list_investment_pool( rationale=row["rationale"], tags=list(_loads_or_default(row["tags"], [])), metadata=dict(_loads_or_default(row["metadata"], {})), - name=row["name"], - industry=row["industry"], + name=row["name"] if "name" in row_keys else None, + industry=row["industry"] if "industry" in row_keys else None, + created_at=row["created_at"] if "created_at" in row_keys else None, ) ) return candidates