def _save_with_codes(
    table: str,
    fetch_fn,
    *,
    targets: Optional[Iterable[str]] = None,
) -> None:
    """Fetch rows for ``table`` through ``fetch_fn`` and store them with ``save_records``.

    This helper is a closure: ``start``, ``end``, ``start_str``, ``end_str``,
    ``force``, ``codes``, ``expected_days``, ``date_cols``,
    ``incremental_tables`` and ``expected_tables`` are read from the
    enclosing scope.

    Args:
        table: Destination table name. It is also the key used to look up the
            date column in ``date_cols`` (default ``"trade_date"``).
        fetch_fn: Callable invoked as ``fetch_fn(start, end, **kwargs)``.
            ``ts_code`` and ``skip_existing`` are passed only when the
            callable's signature declares parameters with those names.
        targets: Explicit codes to process, de-duplicated in first-seen
            order. ``None`` means "use ``codes`` from the enclosing scope,
            or fetch the whole market when ``codes`` is empty". An explicitly
            empty iterable means there is nothing to fetch, so the call is a
            no-op.
    """
    date_col = date_cols.get(table, "trade_date")
    incremental = table in incremental_tables

    # Inspect the fetcher once so only the keywords it accepts are passed.
    params = signature(fetch_fn).parameters
    has_ts_code = "ts_code" in params
    has_skip = "skip_existing" in params

    def _call_fetch(code: Optional[str] = None):
        """Call ``fetch_fn`` for the current range with the supported keywords."""
        kwargs: Dict[str, object] = {}
        if has_ts_code and code is not None:
            kwargs["ts_code"] = code
        if has_skip:
            # Skipping existing rows only applies to incremental tables and
            # is disabled entirely when a forced refresh was requested.
            kwargs["skip_existing"] = (not force) and incremental
        return fetch_fn(start, end, **kwargs)

    if targets is not None:
        target_iter = list(dict.fromkeys(targets))
        if not target_iter:
            # An explicit but empty target list must not fall through to the
            # whole-market fetch below: there is simply nothing to do.
            return
    elif codes:
        target_iter = list(codes)
    else:
        target_iter = []

    if target_iter:
        if not has_ts_code:
            # NOTE(review): a fetcher whose code parameter is not literally
            # named ``ts_code`` ends up here and is called once without any
            # code -- confirm that every fetcher passed in uses that name.
            LOGGER.info("拉取 %s 表数据(全市场)%s-%s", table, start_str, end_str)
            rows = _call_fetch()
            save_records(table, rows)
            return
        for code in target_iter:
            if not force and _should_skip_range(table, date_col, start, end, code):
                LOGGER.info("表 %s 股票 %s 已覆盖 %s-%s,跳过", table, code, start_str, end_str)
                continue
            LOGGER.info("拉取 %s 表数据(股票:%s)%s-%s", table, code, start_str, end_str)
            rows = _call_fetch(code)
            save_records(table, rows)
        return

    # Whole-market path. "suspend" is always refreshed; every other table is
    # skipped when the coverage check reports the range as already present.
    if not force and table != "suspend":
        expected = expected_days if table in expected_tables else 0
        needs_refresh = _range_needs_refresh(
            table,
            date_col,
            start_str,
            end_str,
            expected_days=expected,
        )
        if not needs_refresh:
            LOGGER.info("表 %s 已覆盖 %s-%s,跳过", table, start_str, end_str)
            return

    LOGGER.info("拉取 %s 表数据(全市场)%s-%s", table, start_str, end_str)
    rows = _call_fetch()
    save_records(table, rows)
def get_connection(read_only: bool = False) -> sqlite3.Connection:
    """Open a SQLite connection to the configured database file.

    The database location is read from ``get_config().data_paths.database``.
    When the file does not exist yet and a writable connection is requested,
    the parent directory is created and the database is opened by path, so
    SQLite creates the file. In every other case the database is opened
    through a ``file:`` URI with ``mode=ro`` or ``mode=rw``.

    Args:
        read_only: Open the database with ``mode=ro`` instead of ``mode=rw``.

    Returns:
        A connection whose ``row_factory`` is ``sqlite3.Row`` and whose busy
        timeout is 30 seconds.
    """
    db_path: Path = get_config().data_paths.database
    # Single source of truth for both the connect() timeout and the PRAGMA.
    timeout_seconds = 30.0

    if not read_only and not db_path.exists():
        # Ensure directory exists before first write.
        db_path.parent.mkdir(parents=True, exist_ok=True)
        conn = sqlite3.connect(db_path, timeout=timeout_seconds)
    else:
        mode = "ro" if read_only else "rw"
        # as_uri() percent-escapes characters such as "?", "#" and "%" that
        # would otherwise be parsed as URI syntax rather than as part of the
        # file name; it needs an absolute path, hence resolve().
        uri = f"{db_path.resolve().as_uri()}?mode={mode}"
        conn = sqlite3.connect(uri, uri=True, timeout=timeout_seconds)

    # Kept from the original code: the busy timeout is also set explicitly,
    # in milliseconds, on the connection itself.
    conn.execute(f"PRAGMA busy_timeout = {int(timeout_seconds * 1000)}")
    conn.row_factory = sqlite3.Row
    return conn