add stock name, industry and timestamp to investment pool schema

This commit is contained in:
Your Name 2025-10-11 20:40:06 +08:00
parent 3fc563e72a
commit 3563220385
5 changed files with 183 additions and 35 deletions

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import json import json
import sqlite3
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import date, datetime from datetime import date, datetime
@ -970,12 +971,17 @@ class BacktestEngine:
for code, dept in decision.department_decisions.items() for code, dept in decision.department_decisions.items()
} }
stock_info = self.data_broker.get_stock_info(context.ts_code, context.trade_date)
name = stock_info.get("name") if stock_info else None
industry = stock_info.get("industry") if stock_info else None
with db_session() as conn: with db_session() as conn:
self._ensure_investment_pool_columns(conn)
conn.execute( conn.execute(
""" """
INSERT OR REPLACE INTO investment_pool INSERT OR REPLACE INTO investment_pool
(trade_date, ts_code, score, status, rationale, tags, metadata) (trade_date, ts_code, score, status, rationale, tags, metadata, name, industry)
VALUES (?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""", """,
( (
context.trade_date, context.trade_date,
@ -985,9 +991,44 @@ class BacktestEngine:
summary or None, summary or None,
json.dumps(_department_tags(decision), ensure_ascii=False), json.dumps(_department_tags(decision), ensure_ascii=False),
json.dumps(metadata, ensure_ascii=False), json.dumps(metadata, ensure_ascii=False),
name,
industry,
), ),
) )
@staticmethod
def _ensure_investment_pool_columns(conn: sqlite3.Connection) -> None:
try:
info = conn.execute("PRAGMA table_info(investment_pool)").fetchall()
except sqlite3.Error:
return
columns = {
(row[1] if not isinstance(row, sqlite3.Row) else row["name"])
for row in info
if row is not None
}
if "name" not in columns:
try:
conn.execute("ALTER TABLE investment_pool ADD COLUMN name TEXT")
except sqlite3.Error:
pass
if "industry" not in columns:
try:
conn.execute("ALTER TABLE investment_pool ADD COLUMN industry TEXT")
except sqlite3.Error:
pass
if "created_at" not in columns:
try:
conn.execute(
"ALTER TABLE investment_pool ADD COLUMN created_at TEXT DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now'))"
)
except sqlite3.Error:
try:
conn.execute("ALTER TABLE investment_pool ADD COLUMN created_at TEXT")
except sqlite3.Error:
pass
def _persist_portfolio( def _persist_portfolio(
self, self,
trade_date: str, trade_date: str,

View File

@ -26,6 +26,7 @@ from app.llm.templates import TemplateRegistry
from app.utils import alerts from app.utils import alerts
from app.utils.config import get_config, save_config from app.utils.config import get_config, save_config
from app.utils.tuning import log_tuning_result from app.utils.tuning import log_tuning_result
from app.utils.portfolio import list_investment_pool
from app.utils.db import db_session from app.utils.db import db_session
@ -59,7 +60,18 @@ def render_backtest_review() -> None:
col1, col2 = st.columns(2) col1, col2 = st.columns(2)
start_date = col1.date_input("开始日期", value=default_start, key="bt_start_date") start_date = col1.date_input("开始日期", value=default_start, key="bt_start_date")
end_date = col2.date_input("结束日期", value=default_end, key="bt_end_date") end_date = col2.date_input("结束日期", value=default_end, key="bt_end_date")
universe_text = st.text_input("股票列表(逗号分隔)", value="000001.SZ", key="bt_universe")
latest_candidates = list_investment_pool(limit=50)
candidate_codes = [item.ts_code for item in latest_candidates]
default_universe = ",".join(candidate_codes) if candidate_codes else "000001.SZ"
universe_text = st.text_input(
"股票列表(逗号分隔)",
value=default_universe,
key="bt_universe",
help="默认载入最新候选池,如需自定义可直接编辑。",
)
if candidate_codes:
st.caption(f"候选池载入 {len(candidate_codes)} 个标的:{''.join(candidate_codes[:10])}{'' if len(candidate_codes)>10 else ''}")
col_target, col_stop, col_hold, col_cap = st.columns(4) col_target, col_stop, col_hold, col_cap = st.columns(4)
target = col_target.number_input("目标收益0.035 表示 3.5%", value=0.035, step=0.005, format="%.3f", key="bt_target") target = col_target.number_input("目标收益0.035 表示 3.5%", value=0.035, step=0.005, format="%.3f", key="bt_target")
stop = col_stop.number_input("止损收益(例:-0.015 表示 -1.5%", value=-0.015, step=0.005, format="%.3f", key="bt_stop") stop = col_stop.number_input("止损收益(例:-0.015 表示 -1.5%", value=-0.015, step=0.005, format="%.3f", key="bt_stop")

View File

@ -2,6 +2,7 @@
from datetime import date, datetime, timedelta from datetime import date, datetime, timedelta
from typing import Dict, List, Optional from typing import Dict, List, Optional
import json import json
import sqlite3
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@ -15,6 +16,44 @@ from app.utils.data_access import DataBroker
from app.utils.db import db_session from app.utils.db import db_session
from app.utils.logging import get_logger from app.utils.logging import get_logger
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "stock_eval"}
def _ensure_investment_pool_schema(conn: sqlite3.Connection) -> None:
"""Ensure investment_pool table has latest optional columns."""
try:
info = conn.execute("PRAGMA table_info(investment_pool)").fetchall()
except sqlite3.Error:
return
columns = {
(row["name"] if isinstance(row, sqlite3.Row) else row[1])
for row in info
if row is not None
}
if "name" not in columns:
try:
conn.execute("ALTER TABLE investment_pool ADD COLUMN name TEXT")
except sqlite3.Error:
pass
if "industry" not in columns:
try:
conn.execute("ALTER TABLE investment_pool ADD COLUMN industry TEXT")
except sqlite3.Error:
pass
if "created_at" not in columns:
try:
conn.execute(
"ALTER TABLE investment_pool ADD COLUMN created_at TEXT DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now'))"
)
except sqlite3.Error:
try:
conn.execute("ALTER TABLE investment_pool ADD COLUMN created_at TEXT")
except sqlite3.Error:
pass
def _get_latest_trading_date() -> date: def _get_latest_trading_date() -> date:
"""获取数据库中的最新交易日期""" """获取数据库中的最新交易日期"""
@ -510,6 +549,7 @@ def _add_to_stock_pool(
) )
with db_session() as conn: with db_session() as conn:
_ensure_investment_pool_schema(conn)
conn.execute("DELETE FROM investment_pool WHERE trade_date = ?", (trade_date,)) conn.execute("DELETE FROM investment_pool WHERE trade_date = ?", (trade_date,))
if payload: if payload:
conn.executemany( conn.executemany(

View File

@ -11,7 +11,7 @@ import pandas as pd
import streamlit as st import streamlit as st
from app.backtest.engine import BacktestEngine, PortfolioState, BtConfig from app.backtest.engine import BacktestEngine, PortfolioState, BtConfig
from app.utils.portfolio import list_investment_pool from app.utils.portfolio import InvestmentCandidate, list_investment_pool
from app.utils.db import db_session from app.utils.db import db_session
from app.ui.shared import ( from app.ui.shared import (
@ -175,24 +175,39 @@ def render_today_plan() -> None:
).fetchall() ).fetchall()
symbols = [row["ts_code"] for row in code_rows] symbols = [row["ts_code"] for row in code_rows]
candidate_records = list_investment_pool(trade_date=trade_date)
if candidate_records:
st.caption(
f"候选池包含 {len(candidate_records)} 个标的:"
+ "".join(item.ts_code for item in candidate_records[:12])
+ ("" if len(candidate_records) > 12 else "")
)
if candidate_records:
candidate_codes = [item.ts_code for item in candidate_records]
symbols = list(dict.fromkeys(candidate_codes + symbols))
detail_tab, assistant_tab = st.tabs(["标的详情", "投资助理模式"]) detail_tab, assistant_tab = st.tabs(["标的详情", "投资助理模式"])
with assistant_tab: with assistant_tab:
_render_today_plan_assistant_view(trade_date) _render_today_plan_assistant_view(trade_date, candidate_records)
with detail_tab: with detail_tab:
if not symbols: if not symbols:
st.info("所选交易日暂无 agent_utils 记录。") st.info("所选交易日暂无 agent_utils 记录。")
else: else:
_render_today_plan_symbol_view(trade_date, symbols, query) _render_today_plan_symbol_view(trade_date, symbols, query, candidate_records)
def _render_today_plan_assistant_view(trade_date: str | int | date) -> None: def _render_today_plan_assistant_view(
trade_date: str | int | date,
candidate_records: List[InvestmentCandidate],
) -> None:
# 确保日期格式为字符串 # 确保日期格式为字符串
if isinstance(trade_date, date): if isinstance(trade_date, date):
trade_date = trade_date.strftime("%Y%m%d") trade_date = trade_date.strftime("%Y%m%d")
st.info("已开启投资助理模式:以下内容为组合级(去标的)建议,不包含任何具体标的代码。") st.info("已开启投资助理模式:以下内容为组合级(去标的)建议,不包含任何具体标的代码。")
try: try:
candidates = list_investment_pool(trade_date=trade_date) candidates = candidate_records or list_investment_pool(trade_date=trade_date)
if candidates: if candidates:
scores = [float(item.score or 0.0) for item in candidates] scores = [float(item.score or 0.0) for item in candidates]
statuses = [item.status or "UNKNOWN" for item in candidates] statuses = [item.status or "UNKNOWN" for item in candidates]
@ -259,6 +274,7 @@ def _render_today_plan_symbol_view(
trade_date: str | int | date, trade_date: str | int | date,
symbols: List[str], symbols: List[str],
query_params: Dict[str, List[str]], query_params: Dict[str, List[str]],
candidate_records: List[InvestmentCandidate],
) -> None: ) -> None:
default_ts = query_params.get("code", [symbols[0]])[0] default_ts = query_params.get("code", [symbols[0]])[0]
try: try:
@ -266,7 +282,9 @@ def _render_today_plan_symbol_view(
except ValueError: except ValueError:
default_ts_idx = 0 default_ts_idx = 0
ts_code = st.selectbox("标的", symbols, index=default_ts_idx) ts_code = st.selectbox("标的", symbols, index=default_ts_idx)
batch_symbols = st.multiselect("批量重评估(可多选)", symbols, default=[]) candidate_code_set = {item.ts_code for item in candidate_records}
default_batch = [code for code in symbols if code in candidate_code_set]
batch_symbols = st.multiselect("批量重评估(可多选)", symbols, default=default_batch[:10])
if st.button("一键重评估所有标的", type="primary", width='stretch'): if st.button("一键重评估所有标的", type="primary", width='stretch'):
with st.spinner("正在对所有标的进行重评估,请稍候..."): with st.spinner("正在对所有标的进行重评估,请稍候..."):
@ -318,6 +336,16 @@ def _render_today_plan_symbol_view(
st.info("未查询到详细决策记录,稍后再试。") st.info("未查询到详细决策记录,稍后再试。")
return return
candidate_map = {item.ts_code: item for item in candidate_records}
candidate_info = candidate_map.get(ts_code)
if candidate_info:
info_cols = st.columns(3)
info_cols[0].metric("候选评分", f"{(candidate_info.score or 0):.3f}")
info_cols[1].metric("状态", candidate_info.status or "-")
info_cols[2].metric("更新时间", candidate_info.created_at or "-")
if candidate_info.rationale:
st.caption(f"候选理由:{candidate_info.rationale}")
try: try:
feasible_actions = json.loads(rows[0]["feasible"] or "[]") feasible_actions = json.loads(rows[0]["feasible"] or "[]")
except (KeyError, TypeError, json.JSONDecodeError): except (KeyError, TypeError, json.JSONDecodeError):

View File

@ -34,6 +34,7 @@ class InvestmentCandidate:
metadata: Dict[str, Any] metadata: Dict[str, Any]
name: Optional[str] = None name: Optional[str] = None
industry: Optional[str] = None industry: Optional[str] = None
created_at: Optional[str] = None
def list_investment_pool( def list_investment_pool(
@ -44,31 +45,55 @@ def list_investment_pool(
) -> List[InvestmentCandidate]: ) -> List[InvestmentCandidate]:
"""Return investment candidates for the given trade date (latest if None).""" """Return investment candidates for the given trade date (latest if None)."""
query = [
"SELECT trade_date, ts_code, score, status, rationale, tags, metadata, name, industry",
"FROM investment_pool",
]
params: List[Any] = []
if trade_date:
query.append("WHERE trade_date = ?")
params.append(trade_date)
else:
query.append(
"WHERE trade_date = (SELECT MAX(trade_date) FROM investment_pool)"
)
if status:
placeholders = ", ".join("?" for _ in status)
query.append(f"AND status IN ({placeholders})")
params.extend(list(status))
query.append("ORDER BY (score IS NULL), score DESC, ts_code")
query.append("LIMIT ?")
params.append(int(limit))
sql = "\n".join(query)
with db_session(read_only=True) as conn: with db_session(read_only=True) as conn:
try:
info = conn.execute("PRAGMA table_info(investment_pool)").fetchall()
except Exception: # noqa: BLE001
LOGGER.exception("无法读取 investment_pool 结构", extra=LOG_EXTRA)
return []
available_columns = {
(row[1] if not isinstance(row, dict) else row.get("name"))
for row in info
}
select_columns = [
"trade_date",
"ts_code",
"score",
"status",
"rationale",
"tags",
"metadata",
]
optional_columns = [col for col in ("name", "industry", "created_at") if col in available_columns]
select_columns.extend(optional_columns)
column_clause = ", ".join(select_columns)
query = [
f"SELECT {column_clause}",
"FROM investment_pool",
]
params: List[Any] = []
if trade_date:
query.append("WHERE trade_date = ?")
params.append(trade_date)
else:
query.append(
"WHERE trade_date = (SELECT MAX(trade_date) FROM investment_pool)"
)
if status:
placeholders = ", ".join("?" for _ in status)
query.append(f"AND status IN ({placeholders})")
params.extend(list(status))
query.append("ORDER BY (score IS NULL), score DESC, ts_code")
query.append("LIMIT ?")
params.append(int(limit))
sql = "\n".join(query)
try: try:
rows = conn.execute(sql, params).fetchall() rows = conn.execute(sql, params).fetchall()
except Exception: # noqa: BLE001 except Exception: # noqa: BLE001
@ -77,6 +102,7 @@ def list_investment_pool(
candidates: List[InvestmentCandidate] = [] candidates: List[InvestmentCandidate] = []
for row in rows: for row in rows:
row_keys = set(row.keys()) if hasattr(row, "keys") else set()
candidates.append( candidates.append(
InvestmentCandidate( InvestmentCandidate(
trade_date=row["trade_date"], trade_date=row["trade_date"],
@ -86,8 +112,9 @@ def list_investment_pool(
rationale=row["rationale"], rationale=row["rationale"],
tags=list(_loads_or_default(row["tags"], [])), tags=list(_loads_or_default(row["tags"], [])),
metadata=dict(_loads_or_default(row["metadata"], {})), metadata=dict(_loads_or_default(row["metadata"], {})),
name=row["name"], name=row["name"] if "name" in row_keys else None,
industry=row["industry"], industry=row["industry"] if "industry" in row_keys else None,
created_at=row["created_at"] if "created_at" in row_keys else None,
) )
) )
return candidates return candidates