diff --git a/README.md b/README.md index b8bd6ca..f1a08ec 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,7 @@ Streamlit `自检测试` 页签提供: - TuShare 小范围拉取测试; - 一键开机检查(可自动补数并展示覆盖摘要); - 股票行情可视化(自动加载近段时间价格、成交量,并展示核心指标)。 +- 开机检查带进度指示与详细日志,便于排查 TuShare 拉取问题。 `回测与复盘` 页签提供快速回测表单,可调整时间区间、股票池与参数并即时查看回测输出。 diff --git a/app/ingest/checker.py b/app/ingest/checker.py index 4e42daf..b160e08 100644 --- a/app/ingest/checker.py +++ b/app/ingest/checker.py @@ -4,7 +4,7 @@ from __future__ import annotations import logging from dataclasses import dataclass from datetime import date, timedelta -from typing import Dict +from typing import Callable, Dict from app.data.schema import initialize_database from app.ingest.tushare import collect_data_coverage, ensure_data_coverage @@ -36,7 +36,11 @@ def _default_window(days: int = 365) -> tuple[date, date]: return start, end -def run_boot_check(days: int = 365, auto_fetch: bool = True) -> CoverageReport: +def run_boot_check( + days: int = 365, + auto_fetch: bool = True, + progress_hook: Callable[[str, float], None] | None = None, +) -> CoverageReport: """执行开机自检,必要时自动补数据。""" initialize_database() @@ -44,7 +48,7 @@ def run_boot_check(days: int = 365, auto_fetch: bool = True) -> CoverageReport: LOGGER.info("开机检查覆盖窗口:%s 至 %s", start, end) if auto_fetch: - ensure_data_coverage(start, end) + ensure_data_coverage(start, end, progress_hook=progress_hook) coverage = collect_data_coverage(start, end) @@ -63,5 +67,7 @@ def run_boot_check(days: int = 365, auto_fetch: bool = True) -> CoverageReport: report.tables["daily"].get("distinct_days"), report.expected_trading_days, ) + if progress_hook: + progress_hook("数据覆盖检查完成", 1.0) return report diff --git a/app/ingest/tushare.py b/app/ingest/tushare.py index a69cd97..1f9985f 100644 --- a/app/ingest/tushare.py +++ b/app/ingest/tushare.py @@ -5,7 +5,7 @@ import logging import os from dataclasses import dataclass from datetime import date -from typing import Dict, Iterable, List, Optional, Sequence +from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple import pandas as pd @@ -16,10 +16,57 @@ except ImportError: # pragma: no cover - 运行时提示 from app.utils.config import get_config from app.utils.db import db_session + + +def _existing_date_range(table: str, date_col: str, ts_code: str | None = None) -> Tuple[str | None, str | None]: + query = f"SELECT MIN({date_col}) AS min_d, MAX({date_col}) AS max_d FROM {table}" + params: Tuple = () + if ts_code: + query += " WHERE ts_code = ?" + params = (ts_code,) + with db_session(read_only=True) as conn: + row = conn.execute(query, params).fetchone() + if row is None: + return None, None + return row["min_d"], row["max_d"] + + + from app.data.schema import initialize_database LOGGER = logging.getLogger(__name__) +API_DEFAULT_LIMIT = 5000 + + +def _fetch_paginated(endpoint: str, params: Dict[str, object], limit: int | None = None) -> pd.DataFrame: + client = _ensure_client() + limit = limit or API_DEFAULT_LIMIT + frames: List[pd.DataFrame] = [] + offset = 0 + clean_params = {k: v for k, v in params.items() if v is not None} + LOGGER.info("开始调用 TuShare 接口:%s,参数=%s,limit=%s", endpoint, clean_params, limit) + while True: + call = getattr(client, endpoint) + try: + df = call(limit=limit, offset=offset, **clean_params) + except Exception: # noqa: BLE001 + LOGGER.exception("TuShare 接口调用异常:endpoint=%s offset=%s params=%s", endpoint, offset, clean_params) + raise + if df is None or df.empty: + LOGGER.info("TuShare 返回空数据:endpoint=%s offset=%s", endpoint, offset) + break + LOGGER.info("TuShare 返回 %s 行:endpoint=%s offset=%s", len(df), endpoint, offset) + frames.append(df) + if len(df) < limit: + break + offset += limit + if not frames: + return pd.DataFrame() + merged = pd.concat(frames, ignore_index=True) + LOGGER.info("TuShare 调用完成:endpoint=%s 总行数=%s", endpoint, len(merged)) + return merged + @dataclass class FetchJob: @@ -218,6 +265,37 @@ def _format_date(value: date) -> str: return value.strftime("%Y%m%d") +def _load_trade_dates(start: date, end: date, exchange: str = "SSE") -> List[str]: + start_str = _format_date(start) + end_str = _format_date(end) + query = ( + "SELECT cal_date FROM trade_calendar " + "WHERE exchange = ? AND cal_date BETWEEN ? AND ? AND is_open = 1 ORDER BY cal_date" + ) + with db_session(read_only=True) as conn: + rows = conn.execute(query, (exchange, start_str, end_str)).fetchall() + return [row["cal_date"] for row in rows] + + + +def _should_skip_range(table: str, date_col: str, start: date, end: date, ts_code: str | None = None) -> bool: + min_d, max_d = _existing_date_range(table, date_col, ts_code) + if min_d is None or max_d is None: + return False + start_str = _format_date(start) + end_str = _format_date(end) + return min_d <= start_str and max_d >= end_str + + +def _should_skip_range(table: str, date_col: str, start: date, end: date, ts_code: str | None = None) -> bool: + min_d, max_d = _existing_date_range(table, date_col, ts_code) + if min_d is None or max_d is None: + return False + start_str = _format_date(start) + end_str = _format_date(end) + return min_d <= start_str and max_d >= end_str + + def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]: if df is None or df.empty: return [] @@ -301,17 +379,32 @@ def fetch_daily_bars(job: FetchJob) -> Iterable[Dict]: if job.granularity != "daily": raise ValueError(f"暂不支持的粒度:{job.granularity}") + params = { + "start_date": start_date, + "end_date": end_date, + } + if job.ts_codes: for code in job.ts_codes: LOGGER.info("拉取 %s 的日线行情(%s-%s)", code, start_date, end_date) - frames.append(client.daily(ts_code=code, start_date=start_date, end_date=end_date)) + df = _fetch_paginated("daily", {**params, "ts_code": code}) + if not df.empty: + frames.append(df) else: - LOGGER.info("按全市场拉取日线行情(%s-%s)", start_date, end_date) - frames.append(client.daily(start_date=start_date, end_date=end_date)) + trade_dates = _load_trade_dates(job.start, job.end) + if not trade_dates: + LOGGER.info("本地交易日历缺失,尝试补全后再拉取日线行情") + ensure_trade_calendar(job.start, job.end) + trade_dates = _load_trade_dates(job.start, job.end) + for trade_date in trade_dates: + LOGGER.debug("按交易日拉取日线行情:%s", trade_date) + df = _fetch_paginated("daily", {"trade_date": trade_date}) + if not df.empty: + frames.append(df) if not frames: return [] - df = pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0] + df = pd.concat(frames, ignore_index=True) return _df_to_records(df, _TABLE_COLUMNS["daily"]) @@ -320,7 +413,11 @@ def fetch_daily_basic(start: date, end: date, ts_code: Optional[str] = None) -> start_date = _format_date(start) end_date = _format_date(end) LOGGER.info("拉取日线基础指标(%s-%s,股票:%s)", start_date, end_date, ts_code or "全部") - df = client.daily_basic(ts_code=ts_code, start_date=start_date, end_date=end_date) + df = _fetch_paginated("daily_basic", { + "ts_code": ts_code, + "start_date": start_date, + "end_date": end_date, + }) return _df_to_records(df, _TABLE_COLUMNS["daily_basic"]) @@ -329,7 +426,11 @@ def fetch_adj_factor(start: date, end: date, ts_code: Optional[str] = None) -> I start_date = _format_date(start) end_date = _format_date(end) LOGGER.info("拉取复权因子(%s-%s,股票:%s)", start_date, end_date, ts_code or "全部") - df = client.adj_factor(ts_code=ts_code, start_date=start_date, end_date=end_date) + df = _fetch_paginated("adj_factor", { + "ts_code": ts_code, + "start_date": start_date, + "end_date": end_date, + }) return _df_to_records(df, _TABLE_COLUMNS["adj_factor"]) @@ -338,7 +439,11 @@ def fetch_suspensions(start: date, end: date, ts_code: Optional[str] = None) -> start_date = _format_date(start) end_date = _format_date(end) LOGGER.info("拉取停复牌信息(%s-%s)", start_date, end_date) - df = client.suspend_d(ts_code=ts_code, start_date=start_date, end_date=end_date) + df = _fetch_paginated("suspend_d", { + "ts_code": ts_code, + "start_date": start_date, + "end_date": end_date, + }, limit=2000) return _df_to_records(df, _TABLE_COLUMNS["suspend"]) @@ -356,7 +461,11 @@ def fetch_stk_limit(start: date, end: date, ts_code: Optional[str] = None) -> It start_date = _format_date(start) end_date = _format_date(end) LOGGER.info("拉取涨跌停价格(%s-%s)", start_date, end_date) - df = client.stk_limit(ts_code=ts_code, start_date=start_date, end_date=end_date) + df = _fetch_paginated("stk_limit", { + "ts_code": ts_code, + "start_date": start_date, + "end_date": end_date, + }) return _df_to_records(df, _TABLE_COLUMNS["stk_limit"]) @@ -412,41 +521,103 @@ def ensure_data_coverage( ts_codes: Optional[Sequence[str]] = None, include_limits: bool = True, force: bool = False, + progress_hook: Callable[[str, float], None] | None = None, ) -> None: initialize_database() start_str = _format_date(start) end_str = _format_date(end) + total_steps = 5 + (1 if include_limits else 0) + current_step = 0 + + def advance(message: str) -> None: + nonlocal current_step + current_step += 1 + progress = min(current_step / total_steps, 1.0) + if progress_hook: + progress_hook(message, progress) + LOGGER.info(message) + + advance("准备股票基础信息与交易日历") ensure_stock_basic() ensure_trade_calendar(start, end) codes = tuple(dict.fromkeys(ts_codes)) if ts_codes else tuple() expected_days = _expected_trading_days(start_str, end_str) - job = FetchJob("daily_autofill", start=start, end=end, ts_codes=codes) - if force or _range_needs_refresh("daily", "trade_date", start_str, end_str, expected_days): - save_records("daily", fetch_daily_bars(job)) + advance("处理日线行情数据") + if codes: + pending_codes: List[str] = [] + for code in codes: + if not force and _should_skip_range("daily", "trade_date", start, end, code): + LOGGER.info("股票 %s 的日线已覆盖 %s-%s,跳过", code, start_str, end_str) + continue + pending_codes.append(code) + if pending_codes: + job = FetchJob("daily_autofill", start=start, end=end, ts_codes=tuple(pending_codes)) + LOGGER.info("开始拉取日线行情:%s-%s(待补股票 %d 支)", start_str, end_str, len(pending_codes)) + save_records("daily", fetch_daily_bars(job)) + else: + needs_daily = force or _range_needs_refresh("daily", "trade_date", start_str, end_str, expected_days) + if not needs_daily: + LOGGER.info("日线数据已覆盖 %s-%s,跳过拉取", start_str, end_str) + else: + job = FetchJob("daily_autofill", start=start, end=end) + LOGGER.info("开始拉取日线行情:%s-%s", start_str, end_str) + save_records("daily", fetch_daily_bars(job)) + + date_cols = { + "daily_basic": "trade_date", + "adj_factor": "trade_date", + "stk_limit": "trade_date", + "suspend": "suspend_date", + } def _save_with_codes(table: str, fetch_fn) -> None: + date_col = date_cols.get(table, "trade_date") if codes: for code in codes: - save_records(table, fetch_fn(start, end, ts_code=code)) + if not force and _should_skip_range(table, date_col, start, end, code): + LOGGER.info("表 %s 股票 %s 已覆盖 %s-%s,跳过", table, code, start_str, end_str) + continue + LOGGER.info("拉取 %s 表数据(股票:%s)%s-%s", table, code, start_str, end_str) + try: + rows = fetch_fn(start, end, ts_code=code) + except Exception: + LOGGER.exception("TuShare 拉取失败:table=%s code=%s", table, code) + raise + save_records(table, rows) else: - save_records(table, fetch_fn(start, end)) + needs_refresh = force + if not force: + expected = expected_days if table in {"daily_basic", "adj_factor", "stk_limit"} else 0 + needs_refresh = _range_needs_refresh(table, date_col, start_str, end_str, expected) + if not needs_refresh: + LOGGER.info("表 %s 已覆盖 %s-%s,跳过", table, start_str, end_str) + return + LOGGER.info("拉取 %s 表数据(全市场)%s-%s", table, start_str, end_str) + try: + rows = fetch_fn(start, end) + except Exception: + LOGGER.exception("TuShare 拉取失败:table=%s code=全部", table) + raise + save_records(table, rows) - if force or _range_needs_refresh("daily_basic", "trade_date", start_str, end_str, expected_days): - _save_with_codes("daily_basic", fetch_daily_basic) + advance("处理日线基础指标数据") + _save_with_codes("daily_basic", fetch_daily_basic) - if force or _range_needs_refresh("adj_factor", "trade_date", start_str, end_str, expected_days): - _save_with_codes("adj_factor", fetch_adj_factor) + advance("处理复权因子数据") + _save_with_codes("adj_factor", fetch_adj_factor) - if include_limits and (force or _range_needs_refresh("stk_limit", "trade_date", start_str, end_str, expected_days)): + if include_limits: + advance("处理涨跌停价格数据") _save_with_codes("stk_limit", fetch_stk_limit) - if force or _range_needs_refresh("suspend", "suspend_date", start_str, end_str): - _save_with_codes("suspend", fetch_suspensions) - + advance("处理停复牌信息") + _save_with_codes("suspend", fetch_suspensions) + if progress_hook: + progress_hook("数据覆盖检查完成", 1.0) def collect_data_coverage(start: date, end: date) -> Dict[str, Dict[str, object]]: start_str = _format_date(start) end_str = _format_date(end) @@ -500,5 +671,11 @@ def collect_data_coverage(start: date, end: date) -> Dict[str, Dict[str, object] def run_ingestion(job: FetchJob, include_limits: bool = True) -> None: LOGGER.info("启动 TuShare 拉取任务:%s", job.name) - ensure_data_coverage(job.start, job.end, ts_codes=job.ts_codes, include_limits=include_limits, force=True) + ensure_data_coverage( + job.start, + job.end, + ts_codes=job.ts_codes, + include_limits=include_limits, + force=True, + ) LOGGER.info("任务 %s 完成", job.name) diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py index 6146c20..79f5140 100644 --- a/app/ui/streamlit_app.py +++ b/app/ui/streamlit_app.py @@ -156,13 +156,29 @@ def render_tests() -> None: st.divider() days = int(st.number_input("检查窗口(天数)", min_value=30, max_value=1095, value=365, step=30)) if st.button("执行开机检查"): + progress_bar = st.progress(0.0) + status_placeholder = st.empty() + log_placeholder = st.empty() + messages: list[str] = [] + + def hook(message: str, value: float) -> None: + progress_bar.progress(min(max(value, 0.0), 1.0)) + status_placeholder.write(message) + messages.append(message) + with st.spinner("正在执行开机检查..."): try: - report = run_boot_check(days=days) + report = run_boot_check(days=days, progress_hook=hook) st.success("开机检查完成,以下为数据覆盖摘要。") st.json(report.to_dict()) + if messages: + log_placeholder.markdown("\n".join(f"- {msg}" for msg in messages)) except Exception as exc: # noqa: BLE001 st.error(f"开机检查失败:{exc}") + if messages: + log_placeholder.markdown("\n".join(f"- {msg}" for msg in messages)) + finally: + progress_bar.progress(1.0) st.divider() st.subheader("股票行情可视化")