optimize data fetching and database connection handling
This commit is contained in:
parent dce49058e5
commit 16112d264b
@@ -2,8 +2,9 @@
 from __future__ import annotations
 
 import sqlite3
+from inspect import signature
 from datetime import date
-from typing import Callable, Dict, List, Optional, Sequence
+from typing import Callable, Dict, Iterable, List, Optional, Sequence
 
 from app.data.schema import initialize_database
 from app.utils.db import db_session
@@ -231,7 +232,7 @@ def ensure_data_coverage(
     advance("处理指数每日指标数据")
     ensure_index_dailybasic(start, end)
 
-    date_cols = {
+    date_cols: Dict[str, str] = {
         "daily_basic": "trade_date",
         "adj_factor": "trade_date",
         "stk_limit": "trade_date",
@@ -246,27 +247,69 @@ def ensure_data_coverage(
         "us_daily": "trade_date",
     }
 
-    def _save_with_codes(table: str, fetch_fn) -> None:
+    incremental_tables = {"daily_basic", "adj_factor", "stk_limit", "suspend"}
+    expected_tables = {"daily_basic", "adj_factor", "stk_limit"}
+
+    def _save_with_codes(
+        table: str,
+        fetch_fn,
+        *,
+        targets: Optional[Iterable[str]] = None,
+    ) -> None:
         date_col = date_cols.get(table, "trade_date")
-        if codes:
-            for code in codes:
+        incremental = table in incremental_tables
+        sig = signature(fetch_fn)
+        has_ts_code = "ts_code" in sig.parameters
+        has_skip = "skip_existing" in sig.parameters
+
+        def _call_fetch(code: Optional[str] = None):
+            kwargs: Dict[str, object] = {}
+            if has_ts_code and code is not None:
+                kwargs["ts_code"] = code
+            if has_skip:
+                kwargs["skip_existing"] = (not force) and incremental
+            return fetch_fn(start, end, **kwargs)
+
+        if targets is not None:
+            target_iter = list(dict.fromkeys(targets))
+        elif codes:
+            target_iter = list(codes)
+        else:
+            target_iter = []
+
+        if target_iter:
+            if not has_ts_code:
+                LOGGER.info("拉取 %s 表数据(全市场)%s-%s", table, start_str, end_str)
+                rows = _call_fetch()
+                save_records(table, rows)
+                return
+            for code in target_iter:
                 if not force and _should_skip_range(table, date_col, start, end, code):
                     LOGGER.info("表 %s 股票 %s 已覆盖 %s-%s,跳过", table, code, start_str, end_str)
                     continue
                 LOGGER.info("拉取 %s 表数据(股票:%s)%s-%s", table, code, start_str, end_str)
-                rows = fetch_fn(start, end, ts_code=code, skip_existing=not force)
+                rows = _call_fetch(code)
                 save_records(table, rows)
-        else:
-            needs_refresh = force or table == "suspend"
-            if not force and table != "suspend":
-                expected = expected_days if table in {"daily_basic", "adj_factor", "stk_limit"} else 0
-                needs_refresh = _range_needs_refresh(table, date_col, start_str, end_str, expected)
+            return
+
+        if not force:
+            if table == "suspend":
+                needs_refresh = True
+            else:
+                expected = expected_days if table in expected_tables else 0
+                needs_refresh = _range_needs_refresh(
+                    table,
+                    date_col,
+                    start_str,
+                    end_str,
+                    expected_days=expected,
+                )
             if not needs_refresh:
                 LOGGER.info("表 %s 已覆盖 %s-%s,跳过", table, start_str, end_str)
                 return
-            LOGGER.info("拉取 %s 表数据(全市场)%s-%s", table, start_str, end_str)
-            rows = fetch_fn(start, end, skip_existing=not force)
-            save_records(table, rows)
+        LOGGER.info("拉取 %s 表数据(全市场)%s-%s", table, start_str, end_str)
+        rows = _call_fetch()
+        save_records(table, rows)
 
     advance("处理日线基础指标数据")
     _save_with_codes("daily_basic", fetch_daily_basic)
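The core of this hunk is capability detection with inspect.signature: the rewritten _save_with_codes inspects each fetcher's parameters and forwards ts_code / skip_existing only to fetchers that actually declare them, so per-code and whole-market fetch functions share one dispatch path. A minimal, self-contained sketch of the pattern (the fetcher names below are hypothetical, not from the commit):

from inspect import signature
from typing import Dict, Optional

# Two hypothetical fetchers with different signatures.
def fetch_full_market(start: str, end: str):
    return f"all rows {start}-{end}"

def fetch_per_code(start: str, end: str, *, ts_code: str, skip_existing: bool = True):
    return f"{ts_code} rows {start}-{end} (skip_existing={skip_existing})"

def call_fetch(fetch_fn, start: str, end: str, code: Optional[str] = None, force: bool = False):
    """Forward only the keyword arguments fetch_fn actually declares."""
    params = signature(fetch_fn).parameters
    kwargs: Dict[str, object] = {}
    if "ts_code" in params and code is not None:
        kwargs["ts_code"] = code
    if "skip_existing" in params:
        kwargs["skip_existing"] = not force
    return fetch_fn(start, end, **kwargs)

print(call_fetch(fetch_full_market, "20240101", "20240131"))
print(call_fetch(fetch_per_code, "20240101", "20240131", code="600000.SH"))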
@@ -288,43 +331,19 @@ def ensure_data_coverage(
     save_records("fut_basic", fetch_fut_basic())
 
     advance("拉取指数行情数据")
-    for code in INDEX_CODES:
-        if not force and _should_skip_range("index_daily", "trade_date", start, end, code):
-            LOGGER.info("指数 %s 已覆盖 %s-%s,跳过", code, start_str, end_str)
-            continue
-        save_records("index_daily", fetch_index_daily(start, end, code))
+    _save_with_codes("index_daily", fetch_index_daily, targets=INDEX_CODES)
 
     advance("拉取基金净值数据")
     fund_targets = tuple(dict.fromkeys(ETF_CODES + FUND_CODES))
-    for code in fund_targets:
-        if not force and _should_skip_range("fund_nav", "nav_date", start, end, code):
-            LOGGER.info("基金 %s 净值已覆盖 %s-%s,跳过", code, start_str, end_str)
-            continue
-        save_records("fund_nav", fetch_fund_nav(start, end, code))
+    _save_with_codes("fund_nav", fetch_fund_nav, targets=fund_targets)
 
     advance("拉取期货/外汇行情数据")
-    for code in FUTURE_CODES:
-        if not force and _should_skip_range("fut_daily", "trade_date", start, end, code):
-            LOGGER.info("期货 %s 已覆盖 %s-%s,跳过", code, start_str, end_str)
-            continue
-        save_records("fut_daily", fetch_fut_daily(start, end, code))
-    for code in FX_CODES:
-        if not force and _should_skip_range("fx_daily", "trade_date", start, end, code):
-            LOGGER.info("外汇 %s 已覆盖 %s-%s,跳过", code, start_str, end_str)
-            continue
-        save_records("fx_daily", fetch_fx_daily(start, end, code))
+    _save_with_codes("fut_daily", fetch_fut_daily, targets=FUTURE_CODES)
+    _save_with_codes("fx_daily", fetch_fx_daily, targets=FX_CODES)
 
     advance("拉取港/美股行情数据(已暂时关闭)")
-    for code in HK_CODES:
-        if not force and _should_skip_range("hk_daily", "trade_date", start, end, code):
-            LOGGER.info("港股 %s 已覆盖 %s-%s,跳过", code, start_str, end_str)
-            continue
-        save_records("hk_daily", fetch_hk_daily(start, end, code))
-    for code in US_CODES:
-        if not force and _should_skip_range("us_daily", "trade_date", start, end, code):
-            LOGGER.info("美股 %s 已覆盖 %s-%s,跳过", code, start_str, end_str)
-            continue
-        save_records("us_daily", fetch_us_daily(start, end, code))
+    _save_with_codes("hk_daily", fetch_hk_daily, targets=HK_CODES)
+    _save_with_codes("us_daily", fetch_us_daily, targets=US_CODES)
 
     if progress_hook:
         progress_hook("数据覆盖检查完成", 1.0)
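Each of the replaced loops carried the same skip/log/save boilerplate; routing them through _save_with_codes(..., targets=...) leaves one place to change that logic. The targets path also de-duplicates with dict.fromkeys, the same order-preserving trick already used for fund_targets (Python dicts keep insertion order, so the first occurrence of each code wins). A tiny illustration with made-up sample codes:

# Sample codes for illustration only; any iterable of strings works.
ETF_CODES = ("510300.SH", "510500.SH")
FUND_CODES = ("510300.SH", "159915.SZ")  # first entry overlaps ETF_CODES

# dict.fromkeys drops duplicates while preserving first-seen order.
fund_targets = tuple(dict.fromkeys(ETF_CODES + FUND_CODES))
print(fund_targets)  # ('510300.SH', '510500.SH', '159915.SZ')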
@@ -14,12 +14,14 @@ def get_connection(read_only: bool = False) -> sqlite3.Connection:
 
     db_path: Path = get_config().data_paths.database
     uri = f"file:{db_path}?mode={'ro' if read_only else 'rw'}"
+    connect_kwargs = {"timeout": 30.0}
     if not db_path.exists() and not read_only:
         # Ensure directory exists before first write.
         db_path.parent.mkdir(parents=True, exist_ok=True)
-        conn = sqlite3.connect(db_path)
+        conn = sqlite3.connect(db_path, **connect_kwargs)
     else:
-        conn = sqlite3.connect(uri, uri=True)
+        conn = sqlite3.connect(uri, uri=True, **connect_kwargs)
+    conn.execute("PRAGMA busy_timeout = 30000")
     conn.row_factory = sqlite3.Row
     return conn
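Both connect paths now pass timeout=30.0, and PRAGMA busy_timeout = 30000 restates the same 30-second budget on the connection itself: while another connection holds the write lock, SQLite retries for up to 30 seconds instead of raising "database is locked" immediately. A small sketch of the failure mode this guards against (the demo.db filename is illustrative):

import sqlite3

# isolation_level=None puts the driver in autocommit mode so we can
# issue BEGIN/COMMIT ourselves and observe the locking directly.
writer = sqlite3.connect("demo.db", isolation_level=None, timeout=30.0)
writer.execute("CREATE TABLE IF NOT EXISTS t (x INTEGER)")
writer.execute("BEGIN IMMEDIATE")           # take the write lock and hold it
writer.execute("INSERT INTO t VALUES (1)")

# timeout=0 means "fail fast"; with timeout=30.0 (or the busy_timeout
# PRAGMA) this second connection would instead retry for up to 30 s.
other = sqlite3.connect("demo.db", isolation_level=None, timeout=0)
try:
    other.execute("BEGIN IMMEDIATE")
except sqlite3.OperationalError as exc:
    print("locked:", exc)                   # -> locked: database is locked

writer.execute("COMMIT")                    # release the lock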