optimize data fetching and database connection handling

sam 2025-10-18 10:00:03 +08:00
parent dce49058e5
commit 16112d264b
2 changed files with 67 additions and 46 deletions

View File

@@ -2,8 +2,9 @@
from __future__ import annotations
import sqlite3
from inspect import signature
from datetime import date
from typing import Callable, Dict, List, Optional, Sequence
from typing import Callable, Dict, Iterable, List, Optional, Sequence
from app.data.schema import initialize_database
from app.utils.db import db_session
@@ -231,7 +232,7 @@ def ensure_data_coverage(
advance("处理指数每日指标数据")
ensure_index_dailybasic(start, end)
date_cols = {
date_cols: Dict[str, str] = {
"daily_basic": "trade_date",
"adj_factor": "trade_date",
"stk_limit": "trade_date",
@@ -246,26 +247,68 @@ def ensure_data_coverage(
"us_daily": "trade_date",
}
def _save_with_codes(table: str, fetch_fn) -> None:
incremental_tables = {"daily_basic", "adj_factor", "stk_limit", "suspend"}
expected_tables = {"daily_basic", "adj_factor", "stk_limit"}
def _save_with_codes(
table: str,
fetch_fn,
*,
targets: Optional[Iterable[str]] = None,
) -> None:
date_col = date_cols.get(table, "trade_date")
if codes:
for code in codes:
incremental = table in incremental_tables
sig = signature(fetch_fn)
has_ts_code = "ts_code" in sig.parameters
has_skip = "skip_existing" in sig.parameters
def _call_fetch(code: Optional[str] = None):
kwargs: Dict[str, object] = {}
if has_ts_code and code is not None:
kwargs["ts_code"] = code
if has_skip:
kwargs["skip_existing"] = (not force) and incremental
return fetch_fn(start, end, **kwargs)
if targets is not None:
target_iter = list(dict.fromkeys(targets))
elif codes:
target_iter = list(codes)
else:
target_iter = []
if target_iter:
if not has_ts_code:
LOGGER.info("拉取 %s 表数据(全市场)%s-%s", table, start_str, end_str)
rows = _call_fetch()
save_records(table, rows)
return
for code in target_iter:
if not force and _should_skip_range(table, date_col, start, end, code):
LOGGER.info("%s 股票 %s 已覆盖 %s-%s,跳过", table, code, start_str, end_str)
continue
LOGGER.info("拉取 %s 表数据(股票:%s%s-%s", table, code, start_str, end_str)
rows = fetch_fn(start, end, ts_code=code, skip_existing=not force)
rows = _call_fetch(code)
save_records(table, rows)
return
if not force:
if table == "suspend":
needs_refresh = True
else:
needs_refresh = force or table == "suspend"
if not force and table != "suspend":
expected = expected_days if table in {"daily_basic", "adj_factor", "stk_limit"} else 0
needs_refresh = _range_needs_refresh(table, date_col, start_str, end_str, expected)
expected = expected_days if table in expected_tables else 0
needs_refresh = _range_needs_refresh(
table,
date_col,
start_str,
end_str,
expected_days=expected,
)
if not needs_refresh:
LOGGER.info("%s 已覆盖 %s-%s,跳过", table, start_str, end_str)
return
LOGGER.info("拉取 %s 表数据(全市场)%s-%s", table, start_str, end_str)
rows = fetch_fn(start, end, skip_existing=not force)
rows = _call_fetch()
save_records(table, rows)
advance("处理日线基础指标数据")
@@ -288,43 +331,19 @@ def ensure_data_coverage(
save_records("fut_basic", fetch_fut_basic())
advance("拉取指数行情数据")
for code in INDEX_CODES:
if not force and _should_skip_range("index_daily", "trade_date", start, end, code):
LOGGER.info("指数 %s 已覆盖 %s-%s,跳过", code, start_str, end_str)
continue
save_records("index_daily", fetch_index_daily(start, end, code))
_save_with_codes("index_daily", fetch_index_daily, targets=INDEX_CODES)
advance("拉取基金净值数据")
fund_targets = tuple(dict.fromkeys(ETF_CODES + FUND_CODES))
for code in fund_targets:
if not force and _should_skip_range("fund_nav", "nav_date", start, end, code):
LOGGER.info("基金 %s 净值已覆盖 %s-%s,跳过", code, start_str, end_str)
continue
save_records("fund_nav", fetch_fund_nav(start, end, code))
_save_with_codes("fund_nav", fetch_fund_nav, targets=fund_targets)
advance("拉取期货/外汇行情数据")
for code in FUTURE_CODES:
if not force and _should_skip_range("fut_daily", "trade_date", start, end, code):
LOGGER.info("期货 %s 已覆盖 %s-%s,跳过", code, start_str, end_str)
continue
save_records("fut_daily", fetch_fut_daily(start, end, code))
for code in FX_CODES:
if not force and _should_skip_range("fx_daily", "trade_date", start, end, code):
LOGGER.info("外汇 %s 已覆盖 %s-%s,跳过", code, start_str, end_str)
continue
save_records("fx_daily", fetch_fx_daily(start, end, code))
_save_with_codes("fut_daily", fetch_fut_daily, targets=FUTURE_CODES)
_save_with_codes("fx_daily", fetch_fx_daily, targets=FX_CODES)
advance("拉取港/美股行情数据(已暂时关闭)")
for code in HK_CODES:
if not force and _should_skip_range("hk_daily", "trade_date", start, end, code):
LOGGER.info("港股 %s 已覆盖 %s-%s,跳过", code, start_str, end_str)
continue
save_records("hk_daily", fetch_hk_daily(start, end, code))
for code in US_CODES:
if not force and _should_skip_range("us_daily", "trade_date", start, end, code):
LOGGER.info("美股 %s 已覆盖 %s-%s,跳过", code, start_str, end_str)
continue
save_records("us_daily", fetch_us_daily(start, end, code))
_save_with_codes("hk_daily", fetch_hk_daily, targets=HK_CODES)
_save_with_codes("us_daily", fetch_us_daily, targets=US_CODES)
if progress_hook:
progress_hook("数据覆盖检查完成", 1.0)

View File

@@ -14,12 +14,14 @@ def get_connection(read_only: bool = False) -> sqlite3.Connection:
db_path: Path = get_config().data_paths.database
uri = f"file:{db_path}?mode={'ro' if read_only else 'rw'}"
connect_kwargs = {"timeout": 30.0}
if not db_path.exists() and not read_only:
# Ensure directory exists before first write.
db_path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(db_path)
conn = sqlite3.connect(db_path, **connect_kwargs)
else:
conn = sqlite3.connect(uri, uri=True)
conn = sqlite3.connect(uri, uri=True, **connect_kwargs)
conn.execute("PRAGMA busy_timeout = 30000")
conn.row_factory = sqlite3.Row
return conn
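
Note: the second file hardens SQLite connection handling with a Python-level connect timeout plus a busy_timeout PRAGMA, so concurrent readers and writers retry on a locked database instead of failing immediately. A simplified standalone sketch under those assumptions (plain string path, no config lookup; open_connection is illustrative, not the project's get_connection):

    import sqlite3

    def open_connection(db_path: str, read_only: bool = False) -> sqlite3.Connection:
        # sqlite3-level timeout: how long the driver waits on a locked database.
        connect_kwargs = {"timeout": 30.0}
        if read_only:
            conn = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True, **connect_kwargs)
        else:
            conn = sqlite3.connect(db_path, **connect_kwargs)
        # Ask SQLite itself to retry for up to 30 000 ms before raising SQLITE_BUSY.
        conn.execute("PRAGMA busy_timeout = 30000")
        conn.row_factory = sqlite3.Row
        return conn
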