add stock name, industry and timestamp to investment pool schema
This commit is contained in:
parent
3fc563e72a
commit
3563220385
@ -2,6 +2,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import sqlite3
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import date, datetime
|
from datetime import date, datetime
|
||||||
@ -970,12 +971,17 @@ class BacktestEngine:
|
|||||||
for code, dept in decision.department_decisions.items()
|
for code, dept in decision.department_decisions.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
stock_info = self.data_broker.get_stock_info(context.ts_code, context.trade_date)
|
||||||
|
name = stock_info.get("name") if stock_info else None
|
||||||
|
industry = stock_info.get("industry") if stock_info else None
|
||||||
|
|
||||||
with db_session() as conn:
|
with db_session() as conn:
|
||||||
|
self._ensure_investment_pool_columns(conn)
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"""
|
"""
|
||||||
INSERT OR REPLACE INTO investment_pool
|
INSERT OR REPLACE INTO investment_pool
|
||||||
(trade_date, ts_code, score, status, rationale, tags, metadata)
|
(trade_date, ts_code, score, status, rationale, tags, metadata, name, industry)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
context.trade_date,
|
context.trade_date,
|
||||||
@ -985,9 +991,44 @@ class BacktestEngine:
|
|||||||
summary or None,
|
summary or None,
|
||||||
json.dumps(_department_tags(decision), ensure_ascii=False),
|
json.dumps(_department_tags(decision), ensure_ascii=False),
|
||||||
json.dumps(metadata, ensure_ascii=False),
|
json.dumps(metadata, ensure_ascii=False),
|
||||||
|
name,
|
||||||
|
industry,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _ensure_investment_pool_columns(conn: sqlite3.Connection) -> None:
|
||||||
|
try:
|
||||||
|
info = conn.execute("PRAGMA table_info(investment_pool)").fetchall()
|
||||||
|
except sqlite3.Error:
|
||||||
|
return
|
||||||
|
|
||||||
|
columns = {
|
||||||
|
(row[1] if not isinstance(row, sqlite3.Row) else row["name"])
|
||||||
|
for row in info
|
||||||
|
if row is not None
|
||||||
|
}
|
||||||
|
if "name" not in columns:
|
||||||
|
try:
|
||||||
|
conn.execute("ALTER TABLE investment_pool ADD COLUMN name TEXT")
|
||||||
|
except sqlite3.Error:
|
||||||
|
pass
|
||||||
|
if "industry" not in columns:
|
||||||
|
try:
|
||||||
|
conn.execute("ALTER TABLE investment_pool ADD COLUMN industry TEXT")
|
||||||
|
except sqlite3.Error:
|
||||||
|
pass
|
||||||
|
if "created_at" not in columns:
|
||||||
|
try:
|
||||||
|
conn.execute(
|
||||||
|
"ALTER TABLE investment_pool ADD COLUMN created_at TEXT DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now'))"
|
||||||
|
)
|
||||||
|
except sqlite3.Error:
|
||||||
|
try:
|
||||||
|
conn.execute("ALTER TABLE investment_pool ADD COLUMN created_at TEXT")
|
||||||
|
except sqlite3.Error:
|
||||||
|
pass
|
||||||
|
|
||||||
def _persist_portfolio(
|
def _persist_portfolio(
|
||||||
self,
|
self,
|
||||||
trade_date: str,
|
trade_date: str,
|
||||||
|
|||||||
@ -26,6 +26,7 @@ from app.llm.templates import TemplateRegistry
|
|||||||
from app.utils import alerts
|
from app.utils import alerts
|
||||||
from app.utils.config import get_config, save_config
|
from app.utils.config import get_config, save_config
|
||||||
from app.utils.tuning import log_tuning_result
|
from app.utils.tuning import log_tuning_result
|
||||||
|
from app.utils.portfolio import list_investment_pool
|
||||||
|
|
||||||
from app.utils.db import db_session
|
from app.utils.db import db_session
|
||||||
|
|
||||||
@ -59,7 +60,18 @@ def render_backtest_review() -> None:
|
|||||||
col1, col2 = st.columns(2)
|
col1, col2 = st.columns(2)
|
||||||
start_date = col1.date_input("开始日期", value=default_start, key="bt_start_date")
|
start_date = col1.date_input("开始日期", value=default_start, key="bt_start_date")
|
||||||
end_date = col2.date_input("结束日期", value=default_end, key="bt_end_date")
|
end_date = col2.date_input("结束日期", value=default_end, key="bt_end_date")
|
||||||
universe_text = st.text_input("股票列表(逗号分隔)", value="000001.SZ", key="bt_universe")
|
|
||||||
|
latest_candidates = list_investment_pool(limit=50)
|
||||||
|
candidate_codes = [item.ts_code for item in latest_candidates]
|
||||||
|
default_universe = ",".join(candidate_codes) if candidate_codes else "000001.SZ"
|
||||||
|
universe_text = st.text_input(
|
||||||
|
"股票列表(逗号分隔)",
|
||||||
|
value=default_universe,
|
||||||
|
key="bt_universe",
|
||||||
|
help="默认载入最新候选池,如需自定义可直接编辑。",
|
||||||
|
)
|
||||||
|
if candidate_codes:
|
||||||
|
st.caption(f"候选池载入 {len(candidate_codes)} 个标的:{'、'.join(candidate_codes[:10])}{'…' if len(candidate_codes)>10 else ''}")
|
||||||
col_target, col_stop, col_hold, col_cap = st.columns(4)
|
col_target, col_stop, col_hold, col_cap = st.columns(4)
|
||||||
target = col_target.number_input("目标收益(例:0.035 表示 3.5%)", value=0.035, step=0.005, format="%.3f", key="bt_target")
|
target = col_target.number_input("目标收益(例:0.035 表示 3.5%)", value=0.035, step=0.005, format="%.3f", key="bt_target")
|
||||||
stop = col_stop.number_input("止损收益(例:-0.015 表示 -1.5%)", value=-0.015, step=0.005, format="%.3f", key="bt_stop")
|
stop = col_stop.number_input("止损收益(例:-0.015 表示 -1.5%)", value=-0.015, step=0.005, format="%.3f", key="bt_stop")
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
from datetime import date, datetime, timedelta
|
from datetime import date, datetime, timedelta
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
import json
|
import json
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@ -15,6 +16,44 @@ from app.utils.data_access import DataBroker
|
|||||||
from app.utils.db import db_session
|
from app.utils.db import db_session
|
||||||
from app.utils.logging import get_logger
|
from app.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOGGER = get_logger(__name__)
|
||||||
|
LOG_EXTRA = {"stage": "stock_eval"}
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_investment_pool_schema(conn: sqlite3.Connection) -> None:
|
||||||
|
"""Ensure investment_pool table has latest optional columns."""
|
||||||
|
try:
|
||||||
|
info = conn.execute("PRAGMA table_info(investment_pool)").fetchall()
|
||||||
|
except sqlite3.Error:
|
||||||
|
return
|
||||||
|
|
||||||
|
columns = {
|
||||||
|
(row["name"] if isinstance(row, sqlite3.Row) else row[1])
|
||||||
|
for row in info
|
||||||
|
if row is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
if "name" not in columns:
|
||||||
|
try:
|
||||||
|
conn.execute("ALTER TABLE investment_pool ADD COLUMN name TEXT")
|
||||||
|
except sqlite3.Error:
|
||||||
|
pass
|
||||||
|
if "industry" not in columns:
|
||||||
|
try:
|
||||||
|
conn.execute("ALTER TABLE investment_pool ADD COLUMN industry TEXT")
|
||||||
|
except sqlite3.Error:
|
||||||
|
pass
|
||||||
|
if "created_at" not in columns:
|
||||||
|
try:
|
||||||
|
conn.execute(
|
||||||
|
"ALTER TABLE investment_pool ADD COLUMN created_at TEXT DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now'))"
|
||||||
|
)
|
||||||
|
except sqlite3.Error:
|
||||||
|
try:
|
||||||
|
conn.execute("ALTER TABLE investment_pool ADD COLUMN created_at TEXT")
|
||||||
|
except sqlite3.Error:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _get_latest_trading_date() -> date:
|
def _get_latest_trading_date() -> date:
|
||||||
"""获取数据库中的最新交易日期"""
|
"""获取数据库中的最新交易日期"""
|
||||||
@ -510,6 +549,7 @@ def _add_to_stock_pool(
|
|||||||
)
|
)
|
||||||
|
|
||||||
with db_session() as conn:
|
with db_session() as conn:
|
||||||
|
_ensure_investment_pool_schema(conn)
|
||||||
conn.execute("DELETE FROM investment_pool WHERE trade_date = ?", (trade_date,))
|
conn.execute("DELETE FROM investment_pool WHERE trade_date = ?", (trade_date,))
|
||||||
if payload:
|
if payload:
|
||||||
conn.executemany(
|
conn.executemany(
|
||||||
|
|||||||
@ -11,7 +11,7 @@ import pandas as pd
|
|||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
from app.backtest.engine import BacktestEngine, PortfolioState, BtConfig
|
from app.backtest.engine import BacktestEngine, PortfolioState, BtConfig
|
||||||
from app.utils.portfolio import list_investment_pool
|
from app.utils.portfolio import InvestmentCandidate, list_investment_pool
|
||||||
from app.utils.db import db_session
|
from app.utils.db import db_session
|
||||||
|
|
||||||
from app.ui.shared import (
|
from app.ui.shared import (
|
||||||
@ -175,24 +175,39 @@ def render_today_plan() -> None:
|
|||||||
).fetchall()
|
).fetchall()
|
||||||
symbols = [row["ts_code"] for row in code_rows]
|
symbols = [row["ts_code"] for row in code_rows]
|
||||||
|
|
||||||
|
candidate_records = list_investment_pool(trade_date=trade_date)
|
||||||
|
if candidate_records:
|
||||||
|
st.caption(
|
||||||
|
f"候选池包含 {len(candidate_records)} 个标的:"
|
||||||
|
+ "、".join(item.ts_code for item in candidate_records[:12])
|
||||||
|
+ ("…" if len(candidate_records) > 12 else "")
|
||||||
|
)
|
||||||
|
|
||||||
|
if candidate_records:
|
||||||
|
candidate_codes = [item.ts_code for item in candidate_records]
|
||||||
|
symbols = list(dict.fromkeys(candidate_codes + symbols))
|
||||||
|
|
||||||
detail_tab, assistant_tab = st.tabs(["标的详情", "投资助理模式"])
|
detail_tab, assistant_tab = st.tabs(["标的详情", "投资助理模式"])
|
||||||
with assistant_tab:
|
with assistant_tab:
|
||||||
_render_today_plan_assistant_view(trade_date)
|
_render_today_plan_assistant_view(trade_date, candidate_records)
|
||||||
|
|
||||||
with detail_tab:
|
with detail_tab:
|
||||||
if not symbols:
|
if not symbols:
|
||||||
st.info("所选交易日暂无 agent_utils 记录。")
|
st.info("所选交易日暂无 agent_utils 记录。")
|
||||||
else:
|
else:
|
||||||
_render_today_plan_symbol_view(trade_date, symbols, query)
|
_render_today_plan_symbol_view(trade_date, symbols, query, candidate_records)
|
||||||
|
|
||||||
|
|
||||||
def _render_today_plan_assistant_view(trade_date: str | int | date) -> None:
|
def _render_today_plan_assistant_view(
|
||||||
|
trade_date: str | int | date,
|
||||||
|
candidate_records: List[InvestmentCandidate],
|
||||||
|
) -> None:
|
||||||
# 确保日期格式为字符串
|
# 确保日期格式为字符串
|
||||||
if isinstance(trade_date, date):
|
if isinstance(trade_date, date):
|
||||||
trade_date = trade_date.strftime("%Y%m%d")
|
trade_date = trade_date.strftime("%Y%m%d")
|
||||||
st.info("已开启投资助理模式:以下内容为组合级(去标的)建议,不包含任何具体标的代码。")
|
st.info("已开启投资助理模式:以下内容为组合级(去标的)建议,不包含任何具体标的代码。")
|
||||||
try:
|
try:
|
||||||
candidates = list_investment_pool(trade_date=trade_date)
|
candidates = candidate_records or list_investment_pool(trade_date=trade_date)
|
||||||
if candidates:
|
if candidates:
|
||||||
scores = [float(item.score or 0.0) for item in candidates]
|
scores = [float(item.score or 0.0) for item in candidates]
|
||||||
statuses = [item.status or "UNKNOWN" for item in candidates]
|
statuses = [item.status or "UNKNOWN" for item in candidates]
|
||||||
@ -259,6 +274,7 @@ def _render_today_plan_symbol_view(
|
|||||||
trade_date: str | int | date,
|
trade_date: str | int | date,
|
||||||
symbols: List[str],
|
symbols: List[str],
|
||||||
query_params: Dict[str, List[str]],
|
query_params: Dict[str, List[str]],
|
||||||
|
candidate_records: List[InvestmentCandidate],
|
||||||
) -> None:
|
) -> None:
|
||||||
default_ts = query_params.get("code", [symbols[0]])[0]
|
default_ts = query_params.get("code", [symbols[0]])[0]
|
||||||
try:
|
try:
|
||||||
@ -266,7 +282,9 @@ def _render_today_plan_symbol_view(
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
default_ts_idx = 0
|
default_ts_idx = 0
|
||||||
ts_code = st.selectbox("标的", symbols, index=default_ts_idx)
|
ts_code = st.selectbox("标的", symbols, index=default_ts_idx)
|
||||||
batch_symbols = st.multiselect("批量重评估(可多选)", symbols, default=[])
|
candidate_code_set = {item.ts_code for item in candidate_records}
|
||||||
|
default_batch = [code for code in symbols if code in candidate_code_set]
|
||||||
|
batch_symbols = st.multiselect("批量重评估(可多选)", symbols, default=default_batch[:10])
|
||||||
|
|
||||||
if st.button("一键重评估所有标的", type="primary", width='stretch'):
|
if st.button("一键重评估所有标的", type="primary", width='stretch'):
|
||||||
with st.spinner("正在对所有标的进行重评估,请稍候..."):
|
with st.spinner("正在对所有标的进行重评估,请稍候..."):
|
||||||
@ -318,6 +336,16 @@ def _render_today_plan_symbol_view(
|
|||||||
st.info("未查询到详细决策记录,稍后再试。")
|
st.info("未查询到详细决策记录,稍后再试。")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
candidate_map = {item.ts_code: item for item in candidate_records}
|
||||||
|
candidate_info = candidate_map.get(ts_code)
|
||||||
|
if candidate_info:
|
||||||
|
info_cols = st.columns(3)
|
||||||
|
info_cols[0].metric("候选评分", f"{(candidate_info.score or 0):.3f}")
|
||||||
|
info_cols[1].metric("状态", candidate_info.status or "-")
|
||||||
|
info_cols[2].metric("更新时间", candidate_info.created_at or "-")
|
||||||
|
if candidate_info.rationale:
|
||||||
|
st.caption(f"候选理由:{candidate_info.rationale}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
feasible_actions = json.loads(rows[0]["feasible"] or "[]")
|
feasible_actions = json.loads(rows[0]["feasible"] or "[]")
|
||||||
except (KeyError, TypeError, json.JSONDecodeError):
|
except (KeyError, TypeError, json.JSONDecodeError):
|
||||||
|
|||||||
@ -34,6 +34,7 @@ class InvestmentCandidate:
|
|||||||
metadata: Dict[str, Any]
|
metadata: Dict[str, Any]
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
industry: Optional[str] = None
|
industry: Optional[str] = None
|
||||||
|
created_at: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
def list_investment_pool(
|
def list_investment_pool(
|
||||||
@ -44,8 +45,33 @@ def list_investment_pool(
|
|||||||
) -> List[InvestmentCandidate]:
|
) -> List[InvestmentCandidate]:
|
||||||
"""Return investment candidates for the given trade date (latest if None)."""
|
"""Return investment candidates for the given trade date (latest if None)."""
|
||||||
|
|
||||||
|
with db_session(read_only=True) as conn:
|
||||||
|
try:
|
||||||
|
info = conn.execute("PRAGMA table_info(investment_pool)").fetchall()
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
LOGGER.exception("无法读取 investment_pool 结构", extra=LOG_EXTRA)
|
||||||
|
return []
|
||||||
|
|
||||||
|
available_columns = {
|
||||||
|
(row[1] if not isinstance(row, dict) else row.get("name"))
|
||||||
|
for row in info
|
||||||
|
}
|
||||||
|
|
||||||
|
select_columns = [
|
||||||
|
"trade_date",
|
||||||
|
"ts_code",
|
||||||
|
"score",
|
||||||
|
"status",
|
||||||
|
"rationale",
|
||||||
|
"tags",
|
||||||
|
"metadata",
|
||||||
|
]
|
||||||
|
optional_columns = [col for col in ("name", "industry", "created_at") if col in available_columns]
|
||||||
|
select_columns.extend(optional_columns)
|
||||||
|
|
||||||
|
column_clause = ", ".join(select_columns)
|
||||||
query = [
|
query = [
|
||||||
"SELECT trade_date, ts_code, score, status, rationale, tags, metadata, name, industry",
|
f"SELECT {column_clause}",
|
||||||
"FROM investment_pool",
|
"FROM investment_pool",
|
||||||
]
|
]
|
||||||
params: List[Any] = []
|
params: List[Any] = []
|
||||||
@ -68,7 +94,6 @@ def list_investment_pool(
|
|||||||
params.append(int(limit))
|
params.append(int(limit))
|
||||||
|
|
||||||
sql = "\n".join(query)
|
sql = "\n".join(query)
|
||||||
with db_session(read_only=True) as conn:
|
|
||||||
try:
|
try:
|
||||||
rows = conn.execute(sql, params).fetchall()
|
rows = conn.execute(sql, params).fetchall()
|
||||||
except Exception: # noqa: BLE001
|
except Exception: # noqa: BLE001
|
||||||
@ -77,6 +102,7 @@ def list_investment_pool(
|
|||||||
|
|
||||||
candidates: List[InvestmentCandidate] = []
|
candidates: List[InvestmentCandidate] = []
|
||||||
for row in rows:
|
for row in rows:
|
||||||
|
row_keys = set(row.keys()) if hasattr(row, "keys") else set()
|
||||||
candidates.append(
|
candidates.append(
|
||||||
InvestmentCandidate(
|
InvestmentCandidate(
|
||||||
trade_date=row["trade_date"],
|
trade_date=row["trade_date"],
|
||||||
@ -86,8 +112,9 @@ def list_investment_pool(
|
|||||||
rationale=row["rationale"],
|
rationale=row["rationale"],
|
||||||
tags=list(_loads_or_default(row["tags"], [])),
|
tags=list(_loads_or_default(row["tags"], [])),
|
||||||
metadata=dict(_loads_or_default(row["metadata"], {})),
|
metadata=dict(_loads_or_default(row["metadata"], {})),
|
||||||
name=row["name"],
|
name=row["name"] if "name" in row_keys else None,
|
||||||
industry=row["industry"],
|
industry=row["industry"] if "industry" in row_keys else None,
|
||||||
|
created_at=row["created_at"] if "created_at" in row_keys else None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return candidates
|
return candidates
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user