This commit is contained in:
sam 2025-09-27 08:03:33 +08:00
parent 2dc2753827
commit 36322f66db
4 changed files with 227 additions and 27 deletions

View File

@ -23,6 +23,7 @@ Streamlit `自检测试` 页签提供:
- TuShare 小范围拉取测试;
- 一键开机检查(可自动补数并展示覆盖摘要);
- 股票行情可视化(自动加载近段时间价格、成交量,并展示核心指标)。
- 开机检查带进度指示与详细日志,便于排查 TuShare 拉取问题。
`回测与复盘` 页签提供快速回测表单,可调整时间区间、股票池与参数并即时查看回测输出。

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import logging
from dataclasses import dataclass
from datetime import date, timedelta
from typing import Dict
from typing import Callable, Dict
from app.data.schema import initialize_database
from app.ingest.tushare import collect_data_coverage, ensure_data_coverage
@ -36,7 +36,11 @@ def _default_window(days: int = 365) -> tuple[date, date]:
return start, end
def run_boot_check(days: int = 365, auto_fetch: bool = True) -> CoverageReport:
def run_boot_check(
days: int = 365,
auto_fetch: bool = True,
progress_hook: Callable[[str, float], None] | None = None,
) -> CoverageReport:
"""执行开机自检,必要时自动补数据。"""
initialize_database()
@ -44,7 +48,7 @@ def run_boot_check(days: int = 365, auto_fetch: bool = True) -> CoverageReport:
LOGGER.info("开机检查覆盖窗口:%s%s", start, end)
if auto_fetch:
ensure_data_coverage(start, end)
ensure_data_coverage(start, end, progress_hook=progress_hook)
coverage = collect_data_coverage(start, end)
@ -63,5 +67,7 @@ def run_boot_check(days: int = 365, auto_fetch: bool = True) -> CoverageReport:
report.tables["daily"].get("distinct_days"),
report.expected_trading_days,
)
if progress_hook:
progress_hook("数据覆盖检查完成", 1.0)
return report

View File

@ -5,7 +5,7 @@ import logging
import os
from dataclasses import dataclass
from datetime import date
from typing import Dict, Iterable, List, Optional, Sequence
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple
import pandas as pd
@ -16,10 +16,57 @@ except ImportError: # pragma: no cover - 运行时提示
from app.utils.config import get_config
from app.utils.db import db_session
def _existing_date_range(table: str, date_col: str, ts_code: str | None = None) -> Tuple[str | None, str | None]:
    """Return the stored (min, max) of *date_col* in *table*.

    When *ts_code* is given the range is restricted to that stock.
    Yields ``(None, None)`` when no matching rows exist.
    """
    # NOTE: table/date_col are interpolated directly into the SQL text --
    # callers must only pass trusted internal identifiers, never user input.
    sql = f"SELECT MIN({date_col}) AS min_d, MAX({date_col}) AS max_d FROM {table}"
    bind: Tuple = ()
    if ts_code:
        sql = f"{sql} WHERE ts_code = ?"
        bind = (ts_code,)
    with db_session(read_only=True) as conn:
        record = conn.execute(sql, bind).fetchone()
    return (None, None) if record is None else (record["min_d"], record["max_d"])
from app.data.schema import initialize_database
LOGGER = logging.getLogger(__name__)
API_DEFAULT_LIMIT = 5000
def _fetch_paginated(endpoint: str, params: Dict[str, object], limit: int | None = None) -> pd.DataFrame:
    """Call a TuShare endpoint page by page (limit/offset) and concat the results.

    ``None``-valued params are dropped before the call. Returns an empty
    DataFrame when the endpoint yields no rows at all.
    """
    api = _ensure_client()
    page_size = limit or API_DEFAULT_LIMIT
    filtered = {key: value for key, value in params.items() if value is not None}
    LOGGER.info("开始调用 TuShare 接口:%s,参数=%slimit=%s", endpoint, filtered, page_size)
    pages: List[pd.DataFrame] = []
    cursor = 0
    while True:
        method = getattr(api, endpoint)
        try:
            page = method(limit=page_size, offset=cursor, **filtered)
        except Exception:  # noqa: BLE001
            LOGGER.exception("TuShare 接口调用异常endpoint=%s offset=%s params=%s", endpoint, cursor, filtered)
            raise
        if page is None or page.empty:
            LOGGER.info("TuShare 返回空数据endpoint=%s offset=%s", endpoint, cursor)
            break
        LOGGER.info("TuShare 返回 %sendpoint=%s offset=%s", len(page), endpoint, cursor)
        pages.append(page)
        # A short page means the server has no further rows.
        if len(page) < page_size:
            break
        cursor += page_size
    if not pages:
        return pd.DataFrame()
    combined = pd.concat(pages, ignore_index=True)
    LOGGER.info("TuShare 调用完成endpoint=%s 总行数=%s", endpoint, len(combined))
    return combined
@dataclass
class FetchJob:
@ -218,6 +265,37 @@ def _format_date(value: date) -> str:
return value.strftime("%Y%m%d")
def _load_trade_dates(start: date, end: date, exchange: str = "SSE") -> List[str]:
    """Return open trading dates (YYYYMMDD strings) for *exchange* within [start, end]."""
    window = (_format_date(start), _format_date(end))
    sql = (
        "SELECT cal_date FROM trade_calendar "
        "WHERE exchange = ? AND cal_date BETWEEN ? AND ? AND is_open = 1 ORDER BY cal_date"
    )
    with db_session(read_only=True) as conn:
        result = conn.execute(sql, (exchange, *window)).fetchall()
    return [record["cal_date"] for record in result]
def _should_skip_range(table: str, date_col: str, start: date, end: date, ts_code: str | None = None) -> bool:
    """Return True when *table* already covers [start, end] (optionally per stock).

    Coverage is judged only from the stored min/max of *date_col*; interior
    gaps inside the range are NOT detected by this cheap check.
    """
    # Fix: this function was defined twice verbatim; keep a single definition.
    min_d, max_d = _existing_date_range(table, date_col, ts_code)
    if min_d is None or max_d is None:
        # Empty table (or no rows for this stock): nothing cached yet.
        return False
    start_str = _format_date(start)
    end_str = _format_date(end)
    return min_d <= start_str and max_d >= end_str
def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]:
if df is None or df.empty:
return []
@ -301,17 +379,32 @@ def fetch_daily_bars(job: FetchJob) -> Iterable[Dict]:
if job.granularity != "daily":
raise ValueError(f"暂不支持的粒度:{job.granularity}")
params = {
"start_date": start_date,
"end_date": end_date,
}
if job.ts_codes:
for code in job.ts_codes:
LOGGER.info("拉取 %s 的日线行情(%s-%s", code, start_date, end_date)
frames.append(client.daily(ts_code=code, start_date=start_date, end_date=end_date))
df = _fetch_paginated("daily", {**params, "ts_code": code})
if not df.empty:
frames.append(df)
else:
LOGGER.info("按全市场拉取日线行情(%s-%s", start_date, end_date)
frames.append(client.daily(start_date=start_date, end_date=end_date))
trade_dates = _load_trade_dates(job.start, job.end)
if not trade_dates:
LOGGER.info("本地交易日历缺失,尝试补全后再拉取日线行情")
ensure_trade_calendar(job.start, job.end)
trade_dates = _load_trade_dates(job.start, job.end)
for trade_date in trade_dates:
LOGGER.debug("按交易日拉取日线行情:%s", trade_date)
df = _fetch_paginated("daily", {"trade_date": trade_date})
if not df.empty:
frames.append(df)
if not frames:
return []
df = pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0]
df = pd.concat(frames, ignore_index=True)
return _df_to_records(df, _TABLE_COLUMNS["daily"])
@ -320,7 +413,11 @@ def fetch_daily_basic(start: date, end: date, ts_code: Optional[str] = None) ->
start_date = _format_date(start)
end_date = _format_date(end)
LOGGER.info("拉取日线基础指标(%s-%s,股票:%s", start_date, end_date, ts_code or "全部")
df = client.daily_basic(ts_code=ts_code, start_date=start_date, end_date=end_date)
df = _fetch_paginated("daily_basic", {
"ts_code": ts_code,
"start_date": start_date,
"end_date": end_date,
})
return _df_to_records(df, _TABLE_COLUMNS["daily_basic"])
@ -329,7 +426,11 @@ def fetch_adj_factor(start: date, end: date, ts_code: Optional[str] = None) -> I
start_date = _format_date(start)
end_date = _format_date(end)
LOGGER.info("拉取复权因子(%s-%s,股票:%s", start_date, end_date, ts_code or "全部")
df = client.adj_factor(ts_code=ts_code, start_date=start_date, end_date=end_date)
df = _fetch_paginated("adj_factor", {
"ts_code": ts_code,
"start_date": start_date,
"end_date": end_date,
})
return _df_to_records(df, _TABLE_COLUMNS["adj_factor"])
@ -338,7 +439,11 @@ def fetch_suspensions(start: date, end: date, ts_code: Optional[str] = None) ->
start_date = _format_date(start)
end_date = _format_date(end)
LOGGER.info("拉取停复牌信息(%s-%s", start_date, end_date)
df = client.suspend_d(ts_code=ts_code, start_date=start_date, end_date=end_date)
df = _fetch_paginated("suspend_d", {
"ts_code": ts_code,
"start_date": start_date,
"end_date": end_date,
}, limit=2000)
return _df_to_records(df, _TABLE_COLUMNS["suspend"])
@ -356,7 +461,11 @@ def fetch_stk_limit(start: date, end: date, ts_code: Optional[str] = None) -> It
start_date = _format_date(start)
end_date = _format_date(end)
LOGGER.info("拉取涨跌停价格(%s-%s", start_date, end_date)
df = client.stk_limit(ts_code=ts_code, start_date=start_date, end_date=end_date)
df = _fetch_paginated("stk_limit", {
"ts_code": ts_code,
"start_date": start_date,
"end_date": end_date,
})
return _df_to_records(df, _TABLE_COLUMNS["stk_limit"])
@ -412,41 +521,103 @@ def ensure_data_coverage(
ts_codes: Optional[Sequence[str]] = None,
include_limits: bool = True,
force: bool = False,
progress_hook: Callable[[str, float], None] | None = None,
) -> None:
initialize_database()
start_str = _format_date(start)
end_str = _format_date(end)
total_steps = 5 + (1 if include_limits else 0)
current_step = 0
def advance(message: str) -> None:
    # Bump the shared step counter and report progress both to the optional
    # UI hook and to the module logger.
    nonlocal current_step
    current_step += 1
    # Clamp: guards against reporting > 100% if steps ever exceed the estimate.
    progress = min(current_step / total_steps, 1.0)
    if progress_hook:
        progress_hook(message, progress)
    LOGGER.info(message)
advance("准备股票基础信息与交易日历")
ensure_stock_basic()
ensure_trade_calendar(start, end)
codes = tuple(dict.fromkeys(ts_codes)) if ts_codes else tuple()
expected_days = _expected_trading_days(start_str, end_str)
job = FetchJob("daily_autofill", start=start, end=end, ts_codes=codes)
if force or _range_needs_refresh("daily", "trade_date", start_str, end_str, expected_days):
advance("处理日线行情数据")
if codes:
pending_codes: List[str] = []
for code in codes:
if not force and _should_skip_range("daily", "trade_date", start, end, code):
LOGGER.info("股票 %s 的日线已覆盖 %s-%s,跳过", code, start_str, end_str)
continue
pending_codes.append(code)
if pending_codes:
job = FetchJob("daily_autofill", start=start, end=end, ts_codes=tuple(pending_codes))
LOGGER.info("开始拉取日线行情:%s-%s(待补股票 %d 支)", start_str, end_str, len(pending_codes))
save_records("daily", fetch_daily_bars(job))
else:
needs_daily = force or _range_needs_refresh("daily", "trade_date", start_str, end_str, expected_days)
if not needs_daily:
LOGGER.info("日线数据已覆盖 %s-%s,跳过拉取", start_str, end_str)
else:
job = FetchJob("daily_autofill", start=start, end=end)
LOGGER.info("开始拉取日线行情:%s-%s", start_str, end_str)
save_records("daily", fetch_daily_bars(job))
date_cols = {
"daily_basic": "trade_date",
"adj_factor": "trade_date",
"stk_limit": "trade_date",
"suspend": "suspend_date",
}
def _save_with_codes(table: str, fetch_fn) -> None:
    """Fetch and persist *table* rows via *fetch_fn*, skipping already-covered ranges.

    Per-stock mode iterates the outer ``codes`` tuple; full-market mode uses
    a single range-refresh check. Closes over ``codes``/``force``/``start``/
    ``end`` and the formatted date strings from the enclosing scope.
    """
    # Fix: diff residue left the old unconditional save_records calls in place
    # alongside the new coverage-aware logic, fetching every table twice.
    date_col = date_cols.get(table, "trade_date")
    if codes:
        for code in codes:
            if not force and _should_skip_range(table, date_col, start, end, code):
                LOGGER.info("%s 股票 %s 已覆盖 %s-%s,跳过", table, code, start_str, end_str)
                continue
            LOGGER.info("拉取 %s 表数据(股票:%s%s-%s", table, code, start_str, end_str)
            try:
                rows = fetch_fn(start, end, ts_code=code)
            except Exception:
                LOGGER.exception("TuShare 拉取失败table=%s code=%s", table, code)
                raise
            save_records(table, rows)
    else:
        needs_refresh = force
        if not force:
            # suspend data has no per-trading-day expectation, so expected=0 there.
            expected = expected_days if table in {"daily_basic", "adj_factor", "stk_limit"} else 0
            needs_refresh = _range_needs_refresh(table, date_col, start_str, end_str, expected)
        if not needs_refresh:
            LOGGER.info("%s 已覆盖 %s-%s,跳过", table, start_str, end_str)
            return
        LOGGER.info("拉取 %s 表数据(全市场)%s-%s", table, start_str, end_str)
        try:
            rows = fetch_fn(start, end)
        except Exception:
            LOGGER.exception("TuShare 拉取失败table=%s code=全部", table)
            raise
        save_records(table, rows)
if force or _range_needs_refresh("daily_basic", "trade_date", start_str, end_str, expected_days):
advance("处理日线基础指标数据")
_save_with_codes("daily_basic", fetch_daily_basic)
if force or _range_needs_refresh("adj_factor", "trade_date", start_str, end_str, expected_days):
advance("处理复权因子数据")
_save_with_codes("adj_factor", fetch_adj_factor)
if include_limits and (force or _range_needs_refresh("stk_limit", "trade_date", start_str, end_str, expected_days)):
if include_limits:
advance("处理涨跌停价格数据")
_save_with_codes("stk_limit", fetch_stk_limit)
if force or _range_needs_refresh("suspend", "suspend_date", start_str, end_str):
advance("处理停复牌信息")
_save_with_codes("suspend", fetch_suspensions)
if progress_hook:
progress_hook("数据覆盖检查完成", 1.0)
def collect_data_coverage(start: date, end: date) -> Dict[str, Dict[str, object]]:
start_str = _format_date(start)
end_str = _format_date(end)
@ -500,5 +671,11 @@ def collect_data_coverage(start: date, end: date) -> Dict[str, Dict[str, object]
def run_ingestion(job: FetchJob, include_limits: bool = True) -> None:
    """Run a full TuShare ingestion for *job*, always refreshing the window.

    ``force=True`` bypasses the coverage short-circuit so an explicit manual
    run re-fetches even data that looks already covered.
    """
    # Fix: diff residue left both the old one-line call and the new multi-line
    # call in place, which would execute the ingestion twice per run.
    LOGGER.info("启动 TuShare 拉取任务:%s", job.name)
    ensure_data_coverage(
        job.start,
        job.end,
        ts_codes=job.ts_codes,
        include_limits=include_limits,
        force=True,
    )
    LOGGER.info("任务 %s 完成", job.name)

View File

@ -156,13 +156,29 @@ def render_tests() -> None:
st.divider()
days = int(st.number_input("检查窗口(天数)", min_value=30, max_value=1095, value=365, step=30))
if st.button("执行开机检查"):
progress_bar = st.progress(0.0)
status_placeholder = st.empty()
log_placeholder = st.empty()
messages: list[str] = []
def hook(message: str, value: float) -> None:
    # Progress callback handed to run_boot_check: clamp value into [0, 1]
    # for the Streamlit progress bar, surface the latest status message,
    # and keep a running log trail for display after completion/failure.
    progress_bar.progress(min(max(value, 0.0), 1.0))
    status_placeholder.write(message)
    messages.append(message)
with st.spinner("正在执行开机检查..."):
try:
report = run_boot_check(days=days)
report = run_boot_check(days=days, progress_hook=hook)
st.success("开机检查完成,以下为数据覆盖摘要。")
st.json(report.to_dict())
if messages:
log_placeholder.markdown("\n".join(f"- {msg}" for msg in messages))
except Exception as exc: # noqa: BLE001
st.error(f"开机检查失败:{exc}")
if messages:
log_placeholder.markdown("\n".join(f"- {msg}" for msg in messages))
finally:
progress_bar.progress(1.0)
st.divider()
st.subheader("股票行情可视化")