add stock name, industry and timestamp to investment pool schema

This commit is contained in:
Your Name 2025-10-11 20:40:06 +08:00
parent 3fc563e72a
commit 3563220385
5 changed files with 183 additions and 35 deletions

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import json
import sqlite3
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import date, datetime
@ -970,12 +971,17 @@ class BacktestEngine:
for code, dept in decision.department_decisions.items()
}
stock_info = self.data_broker.get_stock_info(context.ts_code, context.trade_date)
name = stock_info.get("name") if stock_info else None
industry = stock_info.get("industry") if stock_info else None
with db_session() as conn:
self._ensure_investment_pool_columns(conn)
conn.execute(
"""
INSERT OR REPLACE INTO investment_pool
(trade_date, ts_code, score, status, rationale, tags, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?)
(trade_date, ts_code, score, status, rationale, tags, metadata, name, industry)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
context.trade_date,
@ -985,9 +991,44 @@ class BacktestEngine:
summary or None,
json.dumps(_department_tags(decision), ensure_ascii=False),
json.dumps(metadata, ensure_ascii=False),
name,
industry,
),
)
@staticmethod
def _ensure_investment_pool_columns(conn: sqlite3.Connection) -> None:
try:
info = conn.execute("PRAGMA table_info(investment_pool)").fetchall()
except sqlite3.Error:
return
columns = {
(row[1] if not isinstance(row, sqlite3.Row) else row["name"])
for row in info
if row is not None
}
if "name" not in columns:
try:
conn.execute("ALTER TABLE investment_pool ADD COLUMN name TEXT")
except sqlite3.Error:
pass
if "industry" not in columns:
try:
conn.execute("ALTER TABLE investment_pool ADD COLUMN industry TEXT")
except sqlite3.Error:
pass
if "created_at" not in columns:
try:
conn.execute(
"ALTER TABLE investment_pool ADD COLUMN created_at TEXT DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now'))"
)
except sqlite3.Error:
try:
conn.execute("ALTER TABLE investment_pool ADD COLUMN created_at TEXT")
except sqlite3.Error:
pass
def _persist_portfolio(
self,
trade_date: str,

View File

@ -26,6 +26,7 @@ from app.llm.templates import TemplateRegistry
from app.utils import alerts
from app.utils.config import get_config, save_config
from app.utils.tuning import log_tuning_result
from app.utils.portfolio import list_investment_pool
from app.utils.db import db_session
@ -59,7 +60,18 @@ def render_backtest_review() -> None:
col1, col2 = st.columns(2)
start_date = col1.date_input("开始日期", value=default_start, key="bt_start_date")
end_date = col2.date_input("结束日期", value=default_end, key="bt_end_date")
universe_text = st.text_input("股票列表(逗号分隔)", value="000001.SZ", key="bt_universe")
latest_candidates = list_investment_pool(limit=50)
candidate_codes = [item.ts_code for item in latest_candidates]
default_universe = ",".join(candidate_codes) if candidate_codes else "000001.SZ"
universe_text = st.text_input(
"股票列表(逗号分隔)",
value=default_universe,
key="bt_universe",
help="默认载入最新候选池,如需自定义可直接编辑。",
)
if candidate_codes:
st.caption(f"候选池载入 {len(candidate_codes)} 个标的:{''.join(candidate_codes[:10])}{'' if len(candidate_codes)>10 else ''}")
col_target, col_stop, col_hold, col_cap = st.columns(4)
target = col_target.number_input("目标收益0.035 表示 3.5%", value=0.035, step=0.005, format="%.3f", key="bt_target")
stop = col_stop.number_input("止损收益(例:-0.015 表示 -1.5%", value=-0.015, step=0.005, format="%.3f", key="bt_stop")

View File

@ -2,6 +2,7 @@
from datetime import date, datetime, timedelta
from typing import Dict, List, Optional
import json
import sqlite3
import numpy as np
import pandas as pd
@ -15,6 +16,44 @@ from app.utils.data_access import DataBroker
from app.utils.db import db_session
from app.utils.logging import get_logger
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "stock_eval"}
def _ensure_investment_pool_schema(conn: sqlite3.Connection) -> None:
"""Ensure investment_pool table has latest optional columns."""
try:
info = conn.execute("PRAGMA table_info(investment_pool)").fetchall()
except sqlite3.Error:
return
columns = {
(row["name"] if isinstance(row, sqlite3.Row) else row[1])
for row in info
if row is not None
}
if "name" not in columns:
try:
conn.execute("ALTER TABLE investment_pool ADD COLUMN name TEXT")
except sqlite3.Error:
pass
if "industry" not in columns:
try:
conn.execute("ALTER TABLE investment_pool ADD COLUMN industry TEXT")
except sqlite3.Error:
pass
if "created_at" not in columns:
try:
conn.execute(
"ALTER TABLE investment_pool ADD COLUMN created_at TEXT DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now'))"
)
except sqlite3.Error:
try:
conn.execute("ALTER TABLE investment_pool ADD COLUMN created_at TEXT")
except sqlite3.Error:
pass
def _get_latest_trading_date() -> date:
"""获取数据库中的最新交易日期"""
@ -510,6 +549,7 @@ def _add_to_stock_pool(
)
with db_session() as conn:
_ensure_investment_pool_schema(conn)
conn.execute("DELETE FROM investment_pool WHERE trade_date = ?", (trade_date,))
if payload:
conn.executemany(

View File

@ -11,7 +11,7 @@ import pandas as pd
import streamlit as st
from app.backtest.engine import BacktestEngine, PortfolioState, BtConfig
from app.utils.portfolio import list_investment_pool
from app.utils.portfolio import InvestmentCandidate, list_investment_pool
from app.utils.db import db_session
from app.ui.shared import (
@ -175,24 +175,39 @@ def render_today_plan() -> None:
).fetchall()
symbols = [row["ts_code"] for row in code_rows]
candidate_records = list_investment_pool(trade_date=trade_date)
if candidate_records:
st.caption(
f"候选池包含 {len(candidate_records)} 个标的:"
+ "".join(item.ts_code for item in candidate_records[:12])
+ ("" if len(candidate_records) > 12 else "")
)
if candidate_records:
candidate_codes = [item.ts_code for item in candidate_records]
symbols = list(dict.fromkeys(candidate_codes + symbols))
detail_tab, assistant_tab = st.tabs(["标的详情", "投资助理模式"])
with assistant_tab:
_render_today_plan_assistant_view(trade_date)
_render_today_plan_assistant_view(trade_date, candidate_records)
with detail_tab:
if not symbols:
st.info("所选交易日暂无 agent_utils 记录。")
else:
_render_today_plan_symbol_view(trade_date, symbols, query)
_render_today_plan_symbol_view(trade_date, symbols, query, candidate_records)
def _render_today_plan_assistant_view(trade_date: str | int | date) -> None:
def _render_today_plan_assistant_view(
trade_date: str | int | date,
candidate_records: List[InvestmentCandidate],
) -> None:
# 确保日期格式为字符串
if isinstance(trade_date, date):
trade_date = trade_date.strftime("%Y%m%d")
st.info("已开启投资助理模式:以下内容为组合级(去标的)建议,不包含任何具体标的代码。")
try:
candidates = list_investment_pool(trade_date=trade_date)
candidates = candidate_records or list_investment_pool(trade_date=trade_date)
if candidates:
scores = [float(item.score or 0.0) for item in candidates]
statuses = [item.status or "UNKNOWN" for item in candidates]
@ -259,6 +274,7 @@ def _render_today_plan_symbol_view(
trade_date: str | int | date,
symbols: List[str],
query_params: Dict[str, List[str]],
candidate_records: List[InvestmentCandidate],
) -> None:
default_ts = query_params.get("code", [symbols[0]])[0]
try:
@ -266,7 +282,9 @@ def _render_today_plan_symbol_view(
except ValueError:
default_ts_idx = 0
ts_code = st.selectbox("标的", symbols, index=default_ts_idx)
batch_symbols = st.multiselect("批量重评估(可多选)", symbols, default=[])
candidate_code_set = {item.ts_code for item in candidate_records}
default_batch = [code for code in symbols if code in candidate_code_set]
batch_symbols = st.multiselect("批量重评估(可多选)", symbols, default=default_batch[:10])
if st.button("一键重评估所有标的", type="primary", width='stretch'):
with st.spinner("正在对所有标的进行重评估,请稍候..."):
@ -318,6 +336,16 @@ def _render_today_plan_symbol_view(
st.info("未查询到详细决策记录,稍后再试。")
return
candidate_map = {item.ts_code: item for item in candidate_records}
candidate_info = candidate_map.get(ts_code)
if candidate_info:
info_cols = st.columns(3)
info_cols[0].metric("候选评分", f"{(candidate_info.score or 0):.3f}")
info_cols[1].metric("状态", candidate_info.status or "-")
info_cols[2].metric("更新时间", candidate_info.created_at or "-")
if candidate_info.rationale:
st.caption(f"候选理由:{candidate_info.rationale}")
try:
feasible_actions = json.loads(rows[0]["feasible"] or "[]")
except (KeyError, TypeError, json.JSONDecodeError):

View File

@ -34,6 +34,7 @@ class InvestmentCandidate:
metadata: Dict[str, Any]
name: Optional[str] = None
industry: Optional[str] = None
created_at: Optional[str] = None
def list_investment_pool(
@ -44,31 +45,55 @@ def list_investment_pool(
) -> List[InvestmentCandidate]:
"""Return investment candidates for the given trade date (latest if None)."""
query = [
"SELECT trade_date, ts_code, score, status, rationale, tags, metadata, name, industry",
"FROM investment_pool",
]
params: List[Any] = []
if trade_date:
query.append("WHERE trade_date = ?")
params.append(trade_date)
else:
query.append(
"WHERE trade_date = (SELECT MAX(trade_date) FROM investment_pool)"
)
if status:
placeholders = ", ".join("?" for _ in status)
query.append(f"AND status IN ({placeholders})")
params.extend(list(status))
query.append("ORDER BY (score IS NULL), score DESC, ts_code")
query.append("LIMIT ?")
params.append(int(limit))
sql = "\n".join(query)
with db_session(read_only=True) as conn:
try:
info = conn.execute("PRAGMA table_info(investment_pool)").fetchall()
except Exception: # noqa: BLE001
LOGGER.exception("无法读取 investment_pool 结构", extra=LOG_EXTRA)
return []
available_columns = {
(row[1] if not isinstance(row, dict) else row.get("name"))
for row in info
}
select_columns = [
"trade_date",
"ts_code",
"score",
"status",
"rationale",
"tags",
"metadata",
]
optional_columns = [col for col in ("name", "industry", "created_at") if col in available_columns]
select_columns.extend(optional_columns)
column_clause = ", ".join(select_columns)
query = [
f"SELECT {column_clause}",
"FROM investment_pool",
]
params: List[Any] = []
if trade_date:
query.append("WHERE trade_date = ?")
params.append(trade_date)
else:
query.append(
"WHERE trade_date = (SELECT MAX(trade_date) FROM investment_pool)"
)
if status:
placeholders = ", ".join("?" for _ in status)
query.append(f"AND status IN ({placeholders})")
params.extend(list(status))
query.append("ORDER BY (score IS NULL), score DESC, ts_code")
query.append("LIMIT ?")
params.append(int(limit))
sql = "\n".join(query)
try:
rows = conn.execute(sql, params).fetchall()
except Exception: # noqa: BLE001
@ -77,6 +102,7 @@ def list_investment_pool(
candidates: List[InvestmentCandidate] = []
for row in rows:
row_keys = set(row.keys()) if hasattr(row, "keys") else set()
candidates.append(
InvestmentCandidate(
trade_date=row["trade_date"],
@ -86,8 +112,9 @@ def list_investment_pool(
rationale=row["rationale"],
tags=list(_loads_or_default(row["tags"], [])),
metadata=dict(_loads_or_default(row["metadata"], {})),
name=row["name"],
industry=row["industry"],
name=row["name"] if "name" in row_keys else None,
industry=row["industry"] if "industry" in row_keys else None,
created_at=row["created_at"] if "created_at" in row_keys else None,
)
)
return candidates