add stock name, industry and timestamp to investment pool schema
This commit is contained in:
parent
3fc563e72a
commit
3563220385
@ -2,6 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import date, datetime
|
||||
@ -970,12 +971,17 @@ class BacktestEngine:
|
||||
for code, dept in decision.department_decisions.items()
|
||||
}
|
||||
|
||||
stock_info = self.data_broker.get_stock_info(context.ts_code, context.trade_date)
|
||||
name = stock_info.get("name") if stock_info else None
|
||||
industry = stock_info.get("industry") if stock_info else None
|
||||
|
||||
with db_session() as conn:
|
||||
self._ensure_investment_pool_columns(conn)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO investment_pool
|
||||
(trade_date, ts_code, score, status, rationale, tags, metadata)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
(trade_date, ts_code, score, status, rationale, tags, metadata, name, industry)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
context.trade_date,
|
||||
@ -985,9 +991,44 @@ class BacktestEngine:
|
||||
summary or None,
|
||||
json.dumps(_department_tags(decision), ensure_ascii=False),
|
||||
json.dumps(metadata, ensure_ascii=False),
|
||||
name,
|
||||
industry,
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_investment_pool_columns(conn: sqlite3.Connection) -> None:
|
||||
try:
|
||||
info = conn.execute("PRAGMA table_info(investment_pool)").fetchall()
|
||||
except sqlite3.Error:
|
||||
return
|
||||
|
||||
columns = {
|
||||
(row[1] if not isinstance(row, sqlite3.Row) else row["name"])
|
||||
for row in info
|
||||
if row is not None
|
||||
}
|
||||
if "name" not in columns:
|
||||
try:
|
||||
conn.execute("ALTER TABLE investment_pool ADD COLUMN name TEXT")
|
||||
except sqlite3.Error:
|
||||
pass
|
||||
if "industry" not in columns:
|
||||
try:
|
||||
conn.execute("ALTER TABLE investment_pool ADD COLUMN industry TEXT")
|
||||
except sqlite3.Error:
|
||||
pass
|
||||
if "created_at" not in columns:
|
||||
try:
|
||||
conn.execute(
|
||||
"ALTER TABLE investment_pool ADD COLUMN created_at TEXT DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now'))"
|
||||
)
|
||||
except sqlite3.Error:
|
||||
try:
|
||||
conn.execute("ALTER TABLE investment_pool ADD COLUMN created_at TEXT")
|
||||
except sqlite3.Error:
|
||||
pass
|
||||
|
||||
def _persist_portfolio(
|
||||
self,
|
||||
trade_date: str,
|
||||
|
||||
@ -26,6 +26,7 @@ from app.llm.templates import TemplateRegistry
|
||||
from app.utils import alerts
|
||||
from app.utils.config import get_config, save_config
|
||||
from app.utils.tuning import log_tuning_result
|
||||
from app.utils.portfolio import list_investment_pool
|
||||
|
||||
from app.utils.db import db_session
|
||||
|
||||
@ -59,7 +60,18 @@ def render_backtest_review() -> None:
|
||||
col1, col2 = st.columns(2)
|
||||
start_date = col1.date_input("开始日期", value=default_start, key="bt_start_date")
|
||||
end_date = col2.date_input("结束日期", value=default_end, key="bt_end_date")
|
||||
universe_text = st.text_input("股票列表(逗号分隔)", value="000001.SZ", key="bt_universe")
|
||||
|
||||
latest_candidates = list_investment_pool(limit=50)
|
||||
candidate_codes = [item.ts_code for item in latest_candidates]
|
||||
default_universe = ",".join(candidate_codes) if candidate_codes else "000001.SZ"
|
||||
universe_text = st.text_input(
|
||||
"股票列表(逗号分隔)",
|
||||
value=default_universe,
|
||||
key="bt_universe",
|
||||
help="默认载入最新候选池,如需自定义可直接编辑。",
|
||||
)
|
||||
if candidate_codes:
|
||||
st.caption(f"候选池载入 {len(candidate_codes)} 个标的:{'、'.join(candidate_codes[:10])}{'…' if len(candidate_codes)>10 else ''}")
|
||||
col_target, col_stop, col_hold, col_cap = st.columns(4)
|
||||
target = col_target.number_input("目标收益(例:0.035 表示 3.5%)", value=0.035, step=0.005, format="%.3f", key="bt_target")
|
||||
stop = col_stop.number_input("止损收益(例:-0.015 表示 -1.5%)", value=-0.015, step=0.005, format="%.3f", key="bt_stop")
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
from datetime import date, datetime, timedelta
|
||||
from typing import Dict, List, Optional
|
||||
import json
|
||||
import sqlite3
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@ -15,6 +16,44 @@ from app.utils.data_access import DataBroker
|
||||
from app.utils.db import db_session
|
||||
from app.utils.logging import get_logger
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
LOG_EXTRA = {"stage": "stock_eval"}
|
||||
|
||||
|
||||
def _ensure_investment_pool_schema(conn: sqlite3.Connection) -> None:
|
||||
"""Ensure investment_pool table has latest optional columns."""
|
||||
try:
|
||||
info = conn.execute("PRAGMA table_info(investment_pool)").fetchall()
|
||||
except sqlite3.Error:
|
||||
return
|
||||
|
||||
columns = {
|
||||
(row["name"] if isinstance(row, sqlite3.Row) else row[1])
|
||||
for row in info
|
||||
if row is not None
|
||||
}
|
||||
|
||||
if "name" not in columns:
|
||||
try:
|
||||
conn.execute("ALTER TABLE investment_pool ADD COLUMN name TEXT")
|
||||
except sqlite3.Error:
|
||||
pass
|
||||
if "industry" not in columns:
|
||||
try:
|
||||
conn.execute("ALTER TABLE investment_pool ADD COLUMN industry TEXT")
|
||||
except sqlite3.Error:
|
||||
pass
|
||||
if "created_at" not in columns:
|
||||
try:
|
||||
conn.execute(
|
||||
"ALTER TABLE investment_pool ADD COLUMN created_at TEXT DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now'))"
|
||||
)
|
||||
except sqlite3.Error:
|
||||
try:
|
||||
conn.execute("ALTER TABLE investment_pool ADD COLUMN created_at TEXT")
|
||||
except sqlite3.Error:
|
||||
pass
|
||||
|
||||
|
||||
def _get_latest_trading_date() -> date:
|
||||
"""获取数据库中的最新交易日期"""
|
||||
@ -510,6 +549,7 @@ def _add_to_stock_pool(
|
||||
)
|
||||
|
||||
with db_session() as conn:
|
||||
_ensure_investment_pool_schema(conn)
|
||||
conn.execute("DELETE FROM investment_pool WHERE trade_date = ?", (trade_date,))
|
||||
if payload:
|
||||
conn.executemany(
|
||||
|
||||
@ -11,7 +11,7 @@ import pandas as pd
|
||||
import streamlit as st
|
||||
|
||||
from app.backtest.engine import BacktestEngine, PortfolioState, BtConfig
|
||||
from app.utils.portfolio import list_investment_pool
|
||||
from app.utils.portfolio import InvestmentCandidate, list_investment_pool
|
||||
from app.utils.db import db_session
|
||||
|
||||
from app.ui.shared import (
|
||||
@ -175,24 +175,39 @@ def render_today_plan() -> None:
|
||||
).fetchall()
|
||||
symbols = [row["ts_code"] for row in code_rows]
|
||||
|
||||
candidate_records = list_investment_pool(trade_date=trade_date)
|
||||
if candidate_records:
|
||||
st.caption(
|
||||
f"候选池包含 {len(candidate_records)} 个标的:"
|
||||
+ "、".join(item.ts_code for item in candidate_records[:12])
|
||||
+ ("…" if len(candidate_records) > 12 else "")
|
||||
)
|
||||
|
||||
if candidate_records:
|
||||
candidate_codes = [item.ts_code for item in candidate_records]
|
||||
symbols = list(dict.fromkeys(candidate_codes + symbols))
|
||||
|
||||
detail_tab, assistant_tab = st.tabs(["标的详情", "投资助理模式"])
|
||||
with assistant_tab:
|
||||
_render_today_plan_assistant_view(trade_date)
|
||||
_render_today_plan_assistant_view(trade_date, candidate_records)
|
||||
|
||||
with detail_tab:
|
||||
if not symbols:
|
||||
st.info("所选交易日暂无 agent_utils 记录。")
|
||||
else:
|
||||
_render_today_plan_symbol_view(trade_date, symbols, query)
|
||||
_render_today_plan_symbol_view(trade_date, symbols, query, candidate_records)
|
||||
|
||||
|
||||
def _render_today_plan_assistant_view(trade_date: str | int | date) -> None:
|
||||
def _render_today_plan_assistant_view(
|
||||
trade_date: str | int | date,
|
||||
candidate_records: List[InvestmentCandidate],
|
||||
) -> None:
|
||||
# 确保日期格式为字符串
|
||||
if isinstance(trade_date, date):
|
||||
trade_date = trade_date.strftime("%Y%m%d")
|
||||
st.info("已开启投资助理模式:以下内容为组合级(去标的)建议,不包含任何具体标的代码。")
|
||||
try:
|
||||
candidates = list_investment_pool(trade_date=trade_date)
|
||||
candidates = candidate_records or list_investment_pool(trade_date=trade_date)
|
||||
if candidates:
|
||||
scores = [float(item.score or 0.0) for item in candidates]
|
||||
statuses = [item.status or "UNKNOWN" for item in candidates]
|
||||
@ -259,6 +274,7 @@ def _render_today_plan_symbol_view(
|
||||
trade_date: str | int | date,
|
||||
symbols: List[str],
|
||||
query_params: Dict[str, List[str]],
|
||||
candidate_records: List[InvestmentCandidate],
|
||||
) -> None:
|
||||
default_ts = query_params.get("code", [symbols[0]])[0]
|
||||
try:
|
||||
@ -266,7 +282,9 @@ def _render_today_plan_symbol_view(
|
||||
except ValueError:
|
||||
default_ts_idx = 0
|
||||
ts_code = st.selectbox("标的", symbols, index=default_ts_idx)
|
||||
batch_symbols = st.multiselect("批量重评估(可多选)", symbols, default=[])
|
||||
candidate_code_set = {item.ts_code for item in candidate_records}
|
||||
default_batch = [code for code in symbols if code in candidate_code_set]
|
||||
batch_symbols = st.multiselect("批量重评估(可多选)", symbols, default=default_batch[:10])
|
||||
|
||||
if st.button("一键重评估所有标的", type="primary", width='stretch'):
|
||||
with st.spinner("正在对所有标的进行重评估,请稍候..."):
|
||||
@ -318,6 +336,16 @@ def _render_today_plan_symbol_view(
|
||||
st.info("未查询到详细决策记录,稍后再试。")
|
||||
return
|
||||
|
||||
candidate_map = {item.ts_code: item for item in candidate_records}
|
||||
candidate_info = candidate_map.get(ts_code)
|
||||
if candidate_info:
|
||||
info_cols = st.columns(3)
|
||||
info_cols[0].metric("候选评分", f"{(candidate_info.score or 0):.3f}")
|
||||
info_cols[1].metric("状态", candidate_info.status or "-")
|
||||
info_cols[2].metric("更新时间", candidate_info.created_at or "-")
|
||||
if candidate_info.rationale:
|
||||
st.caption(f"候选理由:{candidate_info.rationale}")
|
||||
|
||||
try:
|
||||
feasible_actions = json.loads(rows[0]["feasible"] or "[]")
|
||||
except (KeyError, TypeError, json.JSONDecodeError):
|
||||
|
||||
@ -34,6 +34,7 @@ class InvestmentCandidate:
|
||||
metadata: Dict[str, Any]
|
||||
name: Optional[str] = None
|
||||
industry: Optional[str] = None
|
||||
created_at: Optional[str] = None
|
||||
|
||||
|
||||
def list_investment_pool(
|
||||
@ -44,31 +45,55 @@ def list_investment_pool(
|
||||
) -> List[InvestmentCandidate]:
|
||||
"""Return investment candidates for the given trade date (latest if None)."""
|
||||
|
||||
query = [
|
||||
"SELECT trade_date, ts_code, score, status, rationale, tags, metadata, name, industry",
|
||||
"FROM investment_pool",
|
||||
]
|
||||
params: List[Any] = []
|
||||
|
||||
if trade_date:
|
||||
query.append("WHERE trade_date = ?")
|
||||
params.append(trade_date)
|
||||
else:
|
||||
query.append(
|
||||
"WHERE trade_date = (SELECT MAX(trade_date) FROM investment_pool)"
|
||||
)
|
||||
|
||||
if status:
|
||||
placeholders = ", ".join("?" for _ in status)
|
||||
query.append(f"AND status IN ({placeholders})")
|
||||
params.extend(list(status))
|
||||
|
||||
query.append("ORDER BY (score IS NULL), score DESC, ts_code")
|
||||
query.append("LIMIT ?")
|
||||
params.append(int(limit))
|
||||
|
||||
sql = "\n".join(query)
|
||||
with db_session(read_only=True) as conn:
|
||||
try:
|
||||
info = conn.execute("PRAGMA table_info(investment_pool)").fetchall()
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.exception("无法读取 investment_pool 结构", extra=LOG_EXTRA)
|
||||
return []
|
||||
|
||||
available_columns = {
|
||||
(row[1] if not isinstance(row, dict) else row.get("name"))
|
||||
for row in info
|
||||
}
|
||||
|
||||
select_columns = [
|
||||
"trade_date",
|
||||
"ts_code",
|
||||
"score",
|
||||
"status",
|
||||
"rationale",
|
||||
"tags",
|
||||
"metadata",
|
||||
]
|
||||
optional_columns = [col for col in ("name", "industry", "created_at") if col in available_columns]
|
||||
select_columns.extend(optional_columns)
|
||||
|
||||
column_clause = ", ".join(select_columns)
|
||||
query = [
|
||||
f"SELECT {column_clause}",
|
||||
"FROM investment_pool",
|
||||
]
|
||||
params: List[Any] = []
|
||||
|
||||
if trade_date:
|
||||
query.append("WHERE trade_date = ?")
|
||||
params.append(trade_date)
|
||||
else:
|
||||
query.append(
|
||||
"WHERE trade_date = (SELECT MAX(trade_date) FROM investment_pool)"
|
||||
)
|
||||
|
||||
if status:
|
||||
placeholders = ", ".join("?" for _ in status)
|
||||
query.append(f"AND status IN ({placeholders})")
|
||||
params.extend(list(status))
|
||||
|
||||
query.append("ORDER BY (score IS NULL), score DESC, ts_code")
|
||||
query.append("LIMIT ?")
|
||||
params.append(int(limit))
|
||||
|
||||
sql = "\n".join(query)
|
||||
try:
|
||||
rows = conn.execute(sql, params).fetchall()
|
||||
except Exception: # noqa: BLE001
|
||||
@ -77,6 +102,7 @@ def list_investment_pool(
|
||||
|
||||
candidates: List[InvestmentCandidate] = []
|
||||
for row in rows:
|
||||
row_keys = set(row.keys()) if hasattr(row, "keys") else set()
|
||||
candidates.append(
|
||||
InvestmentCandidate(
|
||||
trade_date=row["trade_date"],
|
||||
@ -86,8 +112,9 @@ def list_investment_pool(
|
||||
rationale=row["rationale"],
|
||||
tags=list(_loads_or_default(row["tags"], [])),
|
||||
metadata=dict(_loads_or_default(row["metadata"], {})),
|
||||
name=row["name"],
|
||||
industry=row["industry"],
|
||||
name=row["name"] if "name" in row_keys else None,
|
||||
industry=row["industry"] if "industry" in row_keys else None,
|
||||
created_at=row["created_at"] if "created_at" in row_keys else None,
|
||||
)
|
||||
)
|
||||
return candidates
|
||||
|
||||
Loading…
Reference in New Issue
Block a user