This commit is contained in:
sam 2025-09-27 17:47:16 +08:00
parent a3b16ffa8d
commit 9c7a68d313
2 changed files with 178 additions and 33 deletions

View File

@ -29,6 +29,13 @@ LOG_EXTRA = {"stage": "data_ingest"}
_CALL_QUEUE = deque()
def _normalize_date_str(value: Optional[str]) -> Optional[str]:
if value is None:
return None
text = str(value).strip()
return text or None
def _respect_rate_limit(cfg) -> None:
max_calls = cfg.max_calls_per_minute
if max_calls <= 0:
@ -44,23 +51,6 @@ def _respect_rate_limit(cfg) -> None:
_CALL_QUEUE.append(time.time())
def _existing_date_range(
    table: str,
    date_col: str,
    ts_code: str | None = None,
) -> Tuple[str | None, str | None]:
    """Return the smallest and largest ``date_col`` value stored in ``table``.

    Args:
        table: Table to inspect.
        date_col: Column whose minimum and maximum are queried.
        ts_code: When truthy, only rows with this ``ts_code`` are considered.

    Returns:
        ``(min, max)`` of ``date_col``; ``(None, None)`` when no row comes back.
    """
    # NOTE: table and column names are interpolated directly into the SQL
    # text, so they must come from trusted code; only ts_code is bound as a
    # query parameter.
    query = f"SELECT MIN({date_col}) AS min_d, MAX({date_col}) AS max_d FROM {table}"
    params: Tuple = ()
    if ts_code:
        query += " WHERE ts_code = ?"
        params = (ts_code,)
    with db_session(read_only=True) as conn:
        row = conn.execute(query, params).fetchone()
    if row is None:
        return None, None
    return row["min_d"], row["max_d"]
def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]:
if df is None or df.empty:
return []
@ -352,25 +342,63 @@ def _record_exists(
def _should_skip_range(table: str, date_col: str, start: date, end: date, ts_code: str | None = None) -> bool:
    """Decide whether ``table`` already covers ``[start, end]`` so a fetch can be skipped.

    Args:
        table: Table whose coverage is checked.
        date_col: Date column used for the coverage check.
        start: First day of the requested range.
        end: Last day of the requested range.
        ts_code: Optional security code. When given, the requested range is
            first clipped to the listing window from ``_listing_window``.

    Returns:
        ``True`` when the stored data spans the effective range (and, for
        table-wide checks, holds at least the expected number of distinct
        days), or when the clipped range for ``ts_code`` is empty.
        ``False`` otherwise.
    """
    effective_start = _format_date(start)
    effective_end = _format_date(end)

    if ts_code:
        # Clip the request to the period the security was actually listed.
        list_date, delist_date = _listing_window(ts_code)
        if list_date:
            effective_start = max(effective_start, list_date)
        if delist_date:
            effective_end = min(effective_end, delist_date)
        if effective_start > effective_end:
            # The listing window does not intersect the requested range, so
            # there is nothing to backfill.
            LOGGER.debug(
                "股票 %s 在目标区间之外,跳过补数",
                ts_code,
                extra=LOG_EXTRA,
            )
            return True

    # A falsy ts_code leaves the query unfiltered, so one call covers both
    # the per-security and the table-wide case.
    stats = _range_stats(table, date_col, effective_start, effective_end, ts_code=ts_code)

    if stats["min"] is None or stats["max"] is None:
        return False
    if stats["min"] > effective_start or stats["max"] < effective_end:
        return False

    if ts_code is None:
        # Table-wide check: matching endpoints are not enough, the number of
        # distinct stored days must also reach the expected trading-day count.
        expected_days = _expected_trading_days(effective_start, effective_end)
        if expected_days and (stats["distinct"] or 0) < expected_days:
            return False

    return True
def _range_stats(
    table: str,
    date_col: str,
    start_str: str,
    end_str: str,
    ts_code: str | None = None,
) -> Dict[str, Optional[str]]:
    """Summarise the ``date_col`` values of ``table`` that fall in a closed range.

    Args:
        table: Table to query.
        date_col: Date column that is filtered and aggregated.
        start_str: Lower bound of the ``BETWEEN`` filter (inclusive).
        end_str: Upper bound of the ``BETWEEN`` filter (inclusive).
        ts_code: When truthy, restricts the query to rows with this ``ts_code``.

    Returns:
        A dict with ``"min"`` and ``"max"`` (smallest/largest matching value,
        ``None`` when no row is returned) and ``"distinct"`` (the
        ``COUNT(DISTINCT ...)`` result, ``0`` when no row is returned).
    """
    # NOTE: table and column names are interpolated into the SQL text and
    # must come from trusted code; the range bounds and ts_code are bound as
    # query parameters.
    sql = (
        f"SELECT MIN({date_col}) AS min_d, MAX({date_col}) AS max_d, "
        f"COUNT(DISTINCT {date_col}) AS distinct_days FROM {table} "
        f"WHERE {date_col} BETWEEN ? AND ?"
    )
    params: List[object] = [start_str, end_str]
    if ts_code:
        sql += " AND ts_code = ?"
        params.append(ts_code)
    with db_session(read_only=True) as conn:
        row = conn.execute(sql, tuple(params)).fetchone()
    return {
        "min": row["min_d"] if row else None,
        "max": row["max_d"] if row else None,
        "distinct": row["distinct_days"] if row else 0,
    }
@ -392,6 +420,17 @@ def _range_needs_refresh(
return False
def _listing_window(ts_code: str) -> Tuple[Optional[str], Optional[str]]:
    """Look up the listing and delisting dates recorded for ``ts_code``.

    Args:
        ts_code: Security code to look up in ``stock_basic``.

    Returns:
        ``(list_date, delist_date)``, each passed through
        ``_normalize_date_str``. ``(None, None)`` when ``stock_basic`` has no
        row for the code.
    """
    with db_session(read_only=True) as conn:
        row = conn.execute(
            "SELECT list_date, delist_date FROM stock_basic WHERE ts_code = ?",
            (ts_code,),
        ).fetchone()
    if not row:
        return None, None
    return _normalize_date_str(row["list_date"]), _normalize_date_str(row["delist_date"])  # type: ignore[index]
def _calendar_needs_refresh(exchange: str, start_str: str, end_str: str) -> bool:
sql = """
SELECT MIN(cal_date) AS min_d, MAX(cal_date) AS max_d, COUNT(*) AS cnt
@ -421,7 +460,12 @@ def _expected_trading_days(start_str: str, end_str: str, exchange: str = "SSE")
def fetch_stock_basic(exchange: Optional[str] = None, list_status: str = "L") -> Iterable[Dict]:
    """Fetch basic security information through the client's ``stock_basic`` call.

    Args:
        exchange: Exchange filter forwarded to the client; ``None`` is logged
            as "全部".
        list_status: Listing-status filter forwarded to the client.

    Returns:
        The fetched frame converted by ``_df_to_records`` using the
        ``stock_basic`` column list.
    """
    client = _ensure_client()
    # Log once, with the structured stage metadata attached.
    LOGGER.info(
        "拉取股票基础信息(交易所:%s,状态:%s",
        exchange or "全部",
        list_status,
        extra=LOG_EXTRA,
    )
    fields = "ts_code,symbol,name,area,industry,market,exchange,list_status,list_date,delist_date"
    df = client.stock_basic(exchange=exchange, list_status=list_status, fields=fields)
    return _df_to_records(df, _TABLE_COLUMNS["stock_basic"])
@ -626,7 +670,13 @@ def fetch_trade_calendar(start: date, end: date, exchange: str = "SSE") -> Itera
client = _ensure_client()
start_date = _format_date(start)
end_date = _format_date(end)
LOGGER.info("拉取交易日历(交易所:%s,区间:%s-%s", exchange, start_date, end_date)
LOGGER.info(
"拉取交易日历(交易所:%s,区间:%s-%s",
exchange,
start_date,
end_date,
extra=LOG_EXTRA,
)
df = client.trade_cal(exchange=exchange, start_date=start_date, end_date=end_date)
if df is not None and not df.empty and "is_open" in df.columns:
df["is_open"] = pd.to_numeric(df["is_open"], errors="coerce").fillna(0).astype(int)
@ -672,7 +722,7 @@ def fetch_stk_limit(
def save_records(table: str, rows: Iterable[Dict]) -> None:
items = list(rows)
if not items:
LOGGER.info("%s 没有新增记录,跳过写入", table)
LOGGER.info("%s 没有新增记录,跳过写入", table, extra=LOG_EXTRA)
return
schema = _TABLE_SCHEMAS.get(table)
@ -683,7 +733,7 @@ def save_records(table: str, rows: Iterable[Dict]) -> None:
placeholders = ",".join([f":{col}" for col in columns])
col_clause = ",".join(columns)
LOGGER.info("%s 写入 %d 条记录", table, len(items))
LOGGER.info("%s 写入 %d 条记录", table, len(items), extra=LOG_EXTRA)
with db_session() as conn:
conn.executescript(schema)
conn.executemany(
@ -700,7 +750,11 @@ def ensure_stock_basic(list_status: str = "L") -> None:
(*exchanges, list_status),
).fetchone()
if row and row["cnt"]:
LOGGER.info("股票基础信息已存在 %d 条记录,跳过拉取", row["cnt"])
LOGGER.info(
"股票基础信息已存在 %d 条记录,跳过拉取",
row["cnt"],
extra=LOG_EXTRA,
)
return
for exch in exchanges:
@ -736,7 +790,7 @@ def ensure_data_coverage(
progress = min(current_step / total_steps, 1.0)
if progress_hook:
progress_hook(message, progress)
LOGGER.info(message)
LOGGER.info(message, extra=LOG_EXTRA)
advance("准备股票基础信息与交易日历")
ensure_stock_basic()
@ -824,6 +878,8 @@ def ensure_data_coverage(
if progress_hook:
progress_hook("数据覆盖检查完成", 1.0)
def collect_data_coverage(start: date, end: date) -> Dict[str, Dict[str, object]]:
start_str = _format_date(start)
end_str = _format_date(end)
@ -876,7 +932,7 @@ def collect_data_coverage(start: date, end: date) -> Dict[str, Dict[str, object]
def run_ingestion(job: FetchJob, include_limits: bool = True) -> None:
LOGGER.info("启动 TuShare 拉取任务:%s", job.name)
LOGGER.info("启动 TuShare 拉取任务:%s", job.name, extra=LOG_EXTRA)
ensure_data_coverage(
job.start,
job.end,
@ -884,4 +940,4 @@ def run_ingestion(job: FetchJob, include_limits: bool = True) -> None:
include_limits=include_limits,
force=True,
)
LOGGER.info("任务 %s 完成", job.name)
LOGGER.info("任务 %s 完成", job.name, extra=LOG_EXTRA)

View File

@ -21,8 +21,11 @@ from app.ingest.tushare import FetchJob, run_ingestion
from app.llm.explain import make_human_card
from app.utils.config import get_config
from app.utils.db import db_session
from app.utils.logging import get_logger
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "ui"}
def _load_stock_options(limit: int = 500) -> list[str]:
@ -32,6 +35,7 @@ def _load_stock_options(limit: int = 500) -> list[str]:
"SELECT ts_code, name FROM stock_basic WHERE list_status = 'L' ORDER BY ts_code"
).fetchall()
except Exception:
LOGGER.exception("加载股票列表失败", extra=LOG_EXTRA)
return []
options: list[str] = []
for row in rows[:limit]:
@ -39,6 +43,7 @@ def _load_stock_options(limit: int = 500) -> list[str]:
name = row["name"] or ""
label = f"{code} | {name}" if name else code
options.append(label)
LOGGER.info("加载股票选项完成,数量=%s", len(options), extra=LOG_EXTRA)
return options
@ -47,6 +52,13 @@ def _parse_ts_code(selection: str) -> str:
def _load_daily_frame(ts_code: str, start: date, end: date) -> pd.DataFrame:
LOGGER.info(
"加载行情数据ts_code=%s start=%s end=%s",
ts_code,
start,
end,
extra=LOG_EXTRA,
)
start_str = start.strftime('%Y%m%d')
end_str = end.strftime('%Y%m%d')
range_query = (
@ -62,24 +74,47 @@ def _load_daily_frame(ts_code: str, start: date, end: date) -> pd.DataFrame:
if df.empty:
df = pd.read_sql_query(fallback_query, conn, params=(ts_code,))
if df.empty:
LOGGER.warning(
"行情数据为空ts_code=%s start=%s end=%s",
ts_code,
start,
end,
extra=LOG_EXTRA,
)
return df
df = df.sort_values('trade_date')
df['trade_date'] = pd.to_datetime(df['trade_date'])
df.set_index('trade_date', inplace=True)
LOGGER.info("行情数据加载完成:条数=%s", len(df), extra=LOG_EXTRA)
return df
def render_today_plan() -> None:
    """Render the "今日计划" page: a header, a placeholder note and a sample card.

    The sample card is built by ``make_human_card`` from fixed demo arguments
    and an empty decision list, then shown as JSON.
    """
    LOGGER.info("渲染今日计划页面", extra=LOG_EXTRA)
    st.header("今日计划")
    st.write("待接入候选池筛选与多智能体决策结果。")
    sample = make_human_card("000001.SZ", "2025-01-01", {"decisions": []})
    LOGGER.debug("示例卡片内容:%s", sample, extra=LOG_EXTRA)
    st.json(sample)
def render_backtest() -> None:
LOGGER.info("渲染回测页面", extra=LOG_EXTRA)
st.header("回测与复盘")
st.write("在此运行回测、展示净值曲线与代理贡献。")
default_start = date(2020, 1, 1)
default_end = date(2020, 3, 31)
LOGGER.debug(
"回测默认参数start=%s end=%s universe=%s target=%s stop=%s hold_days=%s",
default_start,
default_end,
"000001.SZ",
0.035,
-0.015,
10,
extra=LOG_EXTRA,
)
col1, col2 = st.columns(2)
start_date = col1.date_input("开始日期", value=default_start)
@ -88,11 +123,32 @@ def render_backtest() -> None:
target = st.number_input("目标收益0.035 表示 3.5%", value=0.035, step=0.005, format="%.3f")
stop = st.number_input("止损收益(例:-0.015 表示 -1.5%", value=-0.015, step=0.005, format="%.3f")
hold_days = st.number_input("持有期(交易日)", value=10, step=1)
LOGGER.debug(
"当前回测表单输入start=%s end=%s universe_text=%s target=%.3f stop=%.3f hold_days=%s",
start_date,
end_date,
universe_text,
target,
stop,
hold_days,
extra=LOG_EXTRA,
)
if st.button("运行回测"):
LOGGER.info("用户点击运行回测按钮", extra=LOG_EXTRA)
with st.spinner("正在执行回测..."):
try:
universe = [code.strip() for code in universe_text.split(',') if code.strip()]
LOGGER.info(
"回测参数start=%s end=%s universe=%s target=%s stop=%s hold_days=%s",
start_date,
end_date,
universe,
target,
stop,
hold_days,
extra=LOG_EXTRA,
)
cfg = BtConfig(
id="streamlit_demo",
name="Streamlit Demo Strategy",
@ -106,39 +162,55 @@ def render_backtest() -> None:
},
)
result = run_backtest(cfg)
LOGGER.info(
"回测完成nav_records=%s trades=%s",
len(result.nav_series),
len(result.trades),
extra=LOG_EXTRA,
)
st.success("回测执行完成,详见回测结果摘要。")
st.json({"nav_records": result.nav_series, "trades": result.trades})
except Exception as exc: # noqa: BLE001
LOGGER.exception("回测执行失败", extra=LOG_EXTRA)
st.error(f"回测执行失败:{exc}")
def render_settings() -> None:
    """Render the "数据与设置" page with an editable TuShare token field.

    Pressing the save button assigns the stripped input (or ``None`` when it
    is empty) to ``tushare_token`` on the object returned by ``get_config``.
    Nothing in this function writes the value anywhere else.
    """
    LOGGER.info("渲染设置页面", extra=LOG_EXTRA)
    st.header("数据与设置")
    cfg = get_config()
    # Only whether a token is set is logged, never the token itself.
    LOGGER.debug("当前 TuShare Token 是否已配置=%s", bool(cfg.tushare_token), extra=LOG_EXTRA)
    token = st.text_input("TuShare Token", value=cfg.tushare_token or "", type="password")

    if st.button("保存设置"):
        LOGGER.info("保存设置按钮被点击", extra=LOG_EXTRA)
        # A blank input clears the token instead of storing an empty string.
        cfg.tushare_token = token.strip() or None
        LOGGER.info("TuShare Token 更新,是否为空=%s", cfg.tushare_token is None, extra=LOG_EXTRA)
        st.success("设置已保存,仅在当前会话生效。")

    st.write("新闻源开关与数据库备份将在此配置。")
def render_tests() -> None:
LOGGER.info("渲染自检页面", extra=LOG_EXTRA)
st.header("自检测试")
st.write("用于快速检查数据库与数据拉取是否正常工作。")
if st.button("测试数据库初始化"):
LOGGER.info("点击测试数据库初始化按钮", extra=LOG_EXTRA)
with st.spinner("正在检查数据库..."):
result = initialize_database()
if result.skipped:
LOGGER.info("数据库已存在,无需初始化", extra=LOG_EXTRA)
st.success("数据库已存在,检查通过。")
else:
LOGGER.info("数据库初始化完成,执行语句数=%s", result.executed, extra=LOG_EXTRA)
st.success(f"数据库初始化完成,共执行 {result.executed} 条语句。")
st.divider()
if st.button("测试 TuShare 拉取(示例 2024-01-01 至 2024-01-03"):
LOGGER.info("点击示例 TuShare 拉取按钮", extra=LOG_EXTRA)
with st.spinner("正在调用 TuShare 接口..."):
try:
run_ingestion(
@ -150,14 +222,17 @@ def render_tests() -> None:
),
include_limits=False,
)
LOGGER.info("示例 TuShare 拉取成功", extra=LOG_EXTRA)
st.success("TuShare 示例拉取完成,数据已写入数据库。")
except Exception as exc: # noqa: BLE001
LOGGER.exception("示例 TuShare 拉取失败", extra=LOG_EXTRA)
st.error(f"拉取失败:{exc}")
st.info("注意TuShare 拉取依赖网络与 Token若环境未配置将出现错误提示。")
st.divider()
days = int(st.number_input("检查窗口(天数)", min_value=30, max_value=1095, value=365, step=30))
LOGGER.debug("检查窗口天数=%s", days, extra=LOG_EXTRA)
cfg = get_config()
force_refresh = st.checkbox(
"强制刷新数据(关闭增量跳过)",
@ -166,8 +241,10 @@ def render_tests() -> None:
)
if force_refresh != cfg.force_refresh:
cfg.force_refresh = force_refresh
LOGGER.info("更新 force_refresh=%s", force_refresh, extra=LOG_EXTRA)
if st.button("执行开机检查"):
LOGGER.info("点击执行开机检查按钮", extra=LOG_EXTRA)
progress_bar = st.progress(0.0)
status_placeholder = st.empty()
log_placeholder = st.empty()
@ -177,6 +254,7 @@ def render_tests() -> None:
progress_bar.progress(min(max(value, 0.0), 1.0))
status_placeholder.write(message)
messages.append(message)
LOGGER.debug("开机检查进度:%s -> %.2f", message, value, extra=LOG_EXTRA)
with st.spinner("正在执行开机检查..."):
try:
@ -185,11 +263,13 @@ def render_tests() -> None:
progress_hook=hook,
force_refresh=force_refresh,
)
LOGGER.info("开机检查成功", extra=LOG_EXTRA)
st.success("开机检查完成,以下为数据覆盖摘要。")
st.json(report.to_dict())
if messages:
log_placeholder.markdown("\n".join(f"- {msg}" for msg in messages))
except Exception as exc: # noqa: BLE001
LOGGER.exception("开机检查失败", extra=LOG_EXTRA)
st.error(f"开机检查失败:{exc}")
if messages:
log_placeholder.markdown("\n".join(f"- {msg}" for msg in messages))
@ -204,15 +284,19 @@ def render_tests() -> None:
if options:
selection = st.selectbox("选择股票", options, index=0)
ts_code = _parse_ts_code(selection)
LOGGER.debug("选择股票:%s", ts_code, extra=LOG_EXTRA)
else:
ts_code = st.text_input("输入股票代码(如 000001.SZ", value=default_code).strip().upper()
LOGGER.debug("输入股票:%s", ts_code, extra=LOG_EXTRA)
viz_col1, viz_col2 = st.columns(2)
default_start = date.today() - timedelta(days=180)
start_date = viz_col1.date_input("开始日期", value=default_start, key="viz_start")
end_date = viz_col2.date_input("结束日期", value=date.today(), key="viz_end")
LOGGER.debug("行情可视化日期范围:%s-%s", start_date, end_date, extra=LOG_EXTRA)
if start_date > end_date:
LOGGER.warning("无效日期范围:%s>%s", start_date, end_date, extra=LOG_EXTRA)
st.error("开始日期不能晚于结束日期")
return
@ -220,10 +304,12 @@ def render_tests() -> None:
try:
df = _load_daily_frame(ts_code, start_date, end_date)
except Exception as exc: # noqa: BLE001
LOGGER.exception("加载行情数据失败", extra=LOG_EXTRA)
st.error(f"读取数据失败:{exc}")
return
if df.empty:
LOGGER.warning("指定区间无行情数据:%s %s-%s", ts_code, start_date, end_date, extra=LOG_EXTRA)
st.warning("未查询到该区间的交易数据,请确认数据库已拉取对应日线。")
return
@ -309,11 +395,14 @@ def render_tests() -> None:
st.caption("提示:成交量单位为手,成交额以千元显示。箱线图按月展示收盘价分布。")
st.dataframe(df_reset.tail(20), width='stretch')
LOGGER.info("行情可视化完成,展示行数=%s", len(df_reset), extra=LOG_EXTRA)
def main() -> None:
LOGGER.info("初始化 Streamlit UI", extra=LOG_EXTRA)
st.set_page_config(page_title="多智能体投资助理", layout="wide")
tabs = st.tabs(["今日计划", "回测与复盘", "数据与设置", "自检测试"])
LOGGER.debug("Tabs 初始化完成:%s", ["今日计划", "回测与复盘", "数据与设置", "自检测试"], extra=LOG_EXTRA)
with tabs[0]:
render_today_plan()
with tabs[1]: