"""Data coverage orchestration separated from TuShare API calls.""" from __future__ import annotations import sqlite3 from datetime import date from typing import Callable, Dict, List, Optional, Sequence from app.data.schema import initialize_database from app.utils.db import db_session from app.utils.logging import get_logger from .api_client import ( ETF_CODES, FUND_CODES, FUTURE_CODES, FX_CODES, HK_CODES, INDEX_CODES, LOG_EXTRA, US_CODES, _expected_trading_days, _format_date, _listing_window, ensure_stock_basic, ensure_trade_calendar, fetch_adj_factor, fetch_daily_basic, fetch_daily_bars, fetch_fund_basic, fetch_fund_nav, fetch_fut_basic, fetch_fut_daily, fetch_fx_daily, fetch_hk_daily, fetch_index_basic, fetch_index_daily, fetch_index_dailybasic, fetch_index_weight, fetch_suspensions, fetch_stk_limit, fetch_trade_calendar, fetch_us_daily, save_records, ) LOGGER = get_logger(__name__) def _range_stats( table: str, date_col: str, start_str: str, end_str: str, ts_code: str | None = None, ) -> Dict[str, Optional[str]]: sql = ( f"SELECT MIN({date_col}) AS min_d, MAX({date_col}) AS max_d, " f"COUNT(DISTINCT {date_col}) AS distinct_days FROM {table} " f"WHERE {date_col} BETWEEN ? AND ?" ) params: List[object] = [start_str, end_str] if ts_code: sql += " AND ts_code = ?" params.append(ts_code) try: with db_session(read_only=True) as conn: row = conn.execute(sql, tuple(params)).fetchone() except sqlite3.OperationalError: return {"min": None, "max": None, "distinct": 0} return { "min": row["min_d"] if row else None, "max": row["max_d"] if row else None, "distinct": row["distinct_days"] if row else 0, } def _range_needs_refresh( table: str, date_col: str, start_str: str, end_str: str, expected_days: int = 0, **filters: object, ) -> bool: ts_code = filters.get("ts_code") or filters.get("index_code") stats = _range_stats(table, date_col, start_str, end_str, ts_code=ts_code) # type: ignore[arg-type] if stats["min"] is None or stats["max"] is None: return True if stats["min"] > start_str or stats["max"] < end_str: return True if expected_days and (stats["distinct"] or 0) < expected_days: return True return False def _should_skip_range( table: str, date_col: str, start: date, end: date, ts_code: str | None = None, ) -> bool: start_str = _format_date(start) end_str = _format_date(end) effective_start = start_str effective_end = end_str if ts_code: list_date, delist_date = _listing_window(ts_code) if list_date: effective_start = max(effective_start, list_date) if delist_date: effective_end = min(effective_end, delist_date) if effective_start > effective_end: LOGGER.debug( "股票 %s 在目标区间之外,跳过补数", ts_code, extra=LOG_EXTRA, ) return True stats = _range_stats(table, date_col, effective_start, effective_end, ts_code=ts_code) else: stats = _range_stats(table, date_col, effective_start, effective_end) if stats["min"] is None or stats["max"] is None: return False if stats["min"] > effective_start or stats["max"] < effective_end: return False if ts_code is None: expected_days = _expected_trading_days(effective_start, effective_end) if expected_days and (stats["distinct"] or 0) < expected_days: return False return True def ensure_index_weights(start: date, end: date, index_codes: Optional[Sequence[str]] = None) -> None: if index_codes is None: index_codes = [code for code in INDEX_CODES if code.endswith(".SH") or code.endswith(".SZ")] for index_code in index_codes: start_str = _format_date(start) end_str = _format_date(end) if _range_needs_refresh("index_weight", "trade_date", start_str, end_str, index_code=index_code): LOGGER.info("指数 %s 的成分股权重数据不完整,开始拉取 %s-%s", index_code, start_str, end_str) save_records("index_weight", fetch_index_weight(start, end, index_code)) else: LOGGER.info("指数 %s 的成分股权重数据已完整,跳过", index_code) def ensure_index_dailybasic(start: date, end: date, index_codes: Optional[Sequence[str]] = None) -> None: if index_codes is None: index_codes = [code for code in INDEX_CODES if code.endswith(".SH") or code.endswith(".SZ")] for index_code in index_codes: start_str = _format_date(start) end_str = _format_date(end) if _range_needs_refresh("index_dailybasic", "trade_date", start_str, end_str, ts_code=index_code): LOGGER.info("指数 %s 的每日指标数据不完整,开始拉取 %s-%s", index_code, start_str, end_str) save_records("index_dailybasic", fetch_index_dailybasic(start, end, index_code)) else: LOGGER.info("指数 %s 的每日指标数据已完整,跳过", index_code) def ensure_data_coverage( start: date, end: date, ts_codes: Optional[Sequence[str]] = None, include_limits: bool = True, include_extended: bool = True, force: bool = False, progress_hook: Callable[[str, float], None] | None = None, ) -> None: initialize_database() start_str = _format_date(start) end_str = _format_date(end) extra_steps = 0 if include_limits: extra_steps += 1 if include_extended: extra_steps += 4 total_steps = 5 + extra_steps current_step = 0 def advance(message: str) -> None: nonlocal current_step current_step += 1 progress = min(current_step / total_steps, 1.0) if progress_hook: progress_hook(message, progress) LOGGER.info(message, extra=LOG_EXTRA) advance("准备股票基础信息与交易日历") ensure_stock_basic() ensure_trade_calendar(start, end) codes = tuple(dict.fromkeys(ts_codes)) if ts_codes else tuple() expected_days = _expected_trading_days(start_str, end_str) advance("处理日线行情数据") if codes: pending_codes: List[str] = [] for code in codes: if not force and _should_skip_range("daily", "trade_date", start, end, code): LOGGER.info("股票 %s 的日线已覆盖 %s-%s,跳过", code, start_str, end_str) continue pending_codes.append(code) if pending_codes: LOGGER.info("开始拉取日线行情:%s-%s(待补股票 %d 支)", start_str, end_str, len(pending_codes)) save_records( "daily", fetch_daily_bars(start, end, pending_codes, skip_existing=not force), ) else: needs_daily = force or _range_needs_refresh("daily", "trade_date", start_str, end_str, expected_days) if not needs_daily: LOGGER.info("日线数据已覆盖 %s-%s,跳过拉取", start_str, end_str) else: LOGGER.info("开始拉取日线行情:%s-%s", start_str, end_str) save_records( "daily", fetch_daily_bars(start, end, skip_existing=not force), ) advance("处理指数成分股权重数据") ensure_index_weights(start, end) advance("处理指数每日指标数据") ensure_index_dailybasic(start, end) date_cols = { "daily_basic": "trade_date", "adj_factor": "trade_date", "stk_limit": "trade_date", "suspend": "suspend_date", "index_daily": "trade_date", "index_dailybasic": "trade_date", "index_weight": "trade_date", "fund_nav": "nav_date", "fut_daily": "trade_date", "fx_daily": "trade_date", "hk_daily": "trade_date", "us_daily": "trade_date", } def _save_with_codes(table: str, fetch_fn) -> None: date_col = date_cols.get(table, "trade_date") if codes: for code in codes: if not force and _should_skip_range(table, date_col, start, end, code): LOGGER.info("表 %s 股票 %s 已覆盖 %s-%s,跳过", table, code, start_str, end_str) continue LOGGER.info("拉取 %s 表数据(股票:%s)%s-%s", table, code, start_str, end_str) rows = fetch_fn(start, end, ts_code=code, skip_existing=not force) save_records(table, rows) else: needs_refresh = force or table == "suspend" if not force and table != "suspend": expected = expected_days if table in {"daily_basic", "adj_factor", "stk_limit"} else 0 needs_refresh = _range_needs_refresh(table, date_col, start_str, end_str, expected) if not needs_refresh: LOGGER.info("表 %s 已覆盖 %s-%s,跳过", table, start_str, end_str) return LOGGER.info("拉取 %s 表数据(全市场)%s-%s", table, start_str, end_str) rows = fetch_fn(start, end, skip_existing=not force) save_records(table, rows) advance("处理日线基础指标数据") _save_with_codes("daily_basic", fetch_daily_basic) advance("处理复权因子数据") _save_with_codes("adj_factor", fetch_adj_factor) if include_limits: advance("处理涨跌停价格数据") _save_with_codes("stk_limit", fetch_stk_limit) advance("处理停复牌信息") _save_with_codes("suspend", fetch_suspensions) if include_extended: advance("同步指数/基金/期货基础信息") save_records("index_basic", fetch_index_basic()) save_records("fund_basic", fetch_fund_basic()) save_records("fut_basic", fetch_fut_basic()) advance("拉取指数行情数据") for code in INDEX_CODES: if not force and _should_skip_range("index_daily", "trade_date", start, end, code): LOGGER.info("指数 %s 已覆盖 %s-%s,跳过", code, start_str, end_str) continue save_records("index_daily", fetch_index_daily(start, end, code)) advance("拉取基金净值数据") fund_targets = tuple(dict.fromkeys(ETF_CODES + FUND_CODES)) for code in fund_targets: if not force and _should_skip_range("fund_nav", "nav_date", start, end, code): LOGGER.info("基金 %s 净值已覆盖 %s-%s,跳过", code, start_str, end_str) continue save_records("fund_nav", fetch_fund_nav(start, end, code)) advance("拉取期货/外汇行情数据") for code in FUTURE_CODES: if not force and _should_skip_range("fut_daily", "trade_date", start, end, code): LOGGER.info("期货 %s 已覆盖 %s-%s,跳过", code, start_str, end_str) continue save_records("fut_daily", fetch_fut_daily(start, end, code)) for code in FX_CODES: if not force and _should_skip_range("fx_daily", "trade_date", start, end, code): LOGGER.info("外汇 %s 已覆盖 %s-%s,跳过", code, start_str, end_str) continue save_records("fx_daily", fetch_fx_daily(start, end, code)) advance("拉取港/美股行情数据(已暂时关闭)") for code in HK_CODES: if not force and _should_skip_range("hk_daily", "trade_date", start, end, code): LOGGER.info("港股 %s 已覆盖 %s-%s,跳过", code, start_str, end_str) continue save_records("hk_daily", fetch_hk_daily(start, end, code)) for code in US_CODES: if not force and _should_skip_range("us_daily", "trade_date", start, end, code): LOGGER.info("美股 %s 已覆盖 %s-%s,跳过", code, start_str, end_str) continue save_records("us_daily", fetch_us_daily(start, end, code)) if progress_hook: progress_hook("数据覆盖检查完成", 1.0) def collect_data_coverage(start: date, end: date) -> Dict[str, Dict[str, object]]: start_str = _format_date(start) end_str = _format_date(end) expected_days = _expected_trading_days(start_str, end_str) coverage: Dict[str, Dict[str, object]] = { "period": { "start": start_str, "end": end_str, "expected_trading_days": expected_days, } } def add_table(name: str, date_col: str, require_days: bool = True) -> None: stats = _range_stats(name, date_col, start_str, end_str) coverage[name] = { "min": stats["min"], "max": stats["max"], "distinct_days": stats["distinct"], "meets_expectation": ( stats["min"] is not None and stats["max"] is not None and stats["min"] <= start_str and stats["max"] >= end_str and ((not require_days) or (stats["distinct"] or 0) >= expected_days) ), } add_table("daily", "trade_date") add_table("daily_basic", "trade_date") add_table("adj_factor", "trade_date") add_table("stk_limit", "trade_date") add_table("suspend", "suspend_date", require_days=False) add_table("index_daily", "trade_date") add_table("fund_nav", "nav_date", require_days=False) add_table("fut_daily", "trade_date", require_days=False) add_table("fx_daily", "trade_date", require_days=False) add_table("hk_daily", "trade_date", require_days=False) add_table("us_daily", "trade_date", require_days=False) with db_session(read_only=True) as conn: stock_tot = conn.execute("SELECT COUNT(*) AS cnt FROM stock_basic").fetchone() stock_sse = conn.execute( "SELECT COUNT(*) AS cnt FROM stock_basic WHERE exchange = 'SSE' AND list_status = 'L'" ).fetchone() stock_szse = conn.execute( "SELECT COUNT(*) AS cnt FROM stock_basic WHERE exchange = 'SZSE' AND list_status = 'L'" ).fetchone() coverage["stock_basic"] = { "total": stock_tot["cnt"] if stock_tot else 0, "sse_listed": stock_sse["cnt"] if stock_sse else 0, "szse_listed": stock_szse["cnt"] if stock_szse else 0, } return coverage __all__ = [ "collect_data_coverage", "ensure_data_coverage", "ensure_index_dailybasic", "ensure_index_weights", ]