optimize data fetching and database connection handling
This commit is contained in:
parent
dce49058e5
commit
16112d264b
@ -2,8 +2,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
from inspect import signature
|
||||||
from datetime import date
|
from datetime import date
|
||||||
from typing import Callable, Dict, List, Optional, Sequence
|
from typing import Callable, Dict, Iterable, List, Optional, Sequence
|
||||||
|
|
||||||
from app.data.schema import initialize_database
|
from app.data.schema import initialize_database
|
||||||
from app.utils.db import db_session
|
from app.utils.db import db_session
|
||||||
@ -231,7 +232,7 @@ def ensure_data_coverage(
|
|||||||
advance("处理指数每日指标数据")
|
advance("处理指数每日指标数据")
|
||||||
ensure_index_dailybasic(start, end)
|
ensure_index_dailybasic(start, end)
|
||||||
|
|
||||||
date_cols = {
|
date_cols: Dict[str, str] = {
|
||||||
"daily_basic": "trade_date",
|
"daily_basic": "trade_date",
|
||||||
"adj_factor": "trade_date",
|
"adj_factor": "trade_date",
|
||||||
"stk_limit": "trade_date",
|
"stk_limit": "trade_date",
|
||||||
@ -246,27 +247,69 @@ def ensure_data_coverage(
|
|||||||
"us_daily": "trade_date",
|
"us_daily": "trade_date",
|
||||||
}
|
}
|
||||||
|
|
||||||
def _save_with_codes(table: str, fetch_fn) -> None:
|
incremental_tables = {"daily_basic", "adj_factor", "stk_limit", "suspend"}
|
||||||
|
expected_tables = {"daily_basic", "adj_factor", "stk_limit"}
|
||||||
|
|
||||||
|
def _save_with_codes(
|
||||||
|
table: str,
|
||||||
|
fetch_fn,
|
||||||
|
*,
|
||||||
|
targets: Optional[Iterable[str]] = None,
|
||||||
|
) -> None:
|
||||||
date_col = date_cols.get(table, "trade_date")
|
date_col = date_cols.get(table, "trade_date")
|
||||||
if codes:
|
incremental = table in incremental_tables
|
||||||
for code in codes:
|
sig = signature(fetch_fn)
|
||||||
|
has_ts_code = "ts_code" in sig.parameters
|
||||||
|
has_skip = "skip_existing" in sig.parameters
|
||||||
|
|
||||||
|
def _call_fetch(code: Optional[str] = None):
|
||||||
|
kwargs: Dict[str, object] = {}
|
||||||
|
if has_ts_code and code is not None:
|
||||||
|
kwargs["ts_code"] = code
|
||||||
|
if has_skip:
|
||||||
|
kwargs["skip_existing"] = (not force) and incremental
|
||||||
|
return fetch_fn(start, end, **kwargs)
|
||||||
|
|
||||||
|
if targets is not None:
|
||||||
|
target_iter = list(dict.fromkeys(targets))
|
||||||
|
elif codes:
|
||||||
|
target_iter = list(codes)
|
||||||
|
else:
|
||||||
|
target_iter = []
|
||||||
|
|
||||||
|
if target_iter:
|
||||||
|
if not has_ts_code:
|
||||||
|
LOGGER.info("拉取 %s 表数据(全市场)%s-%s", table, start_str, end_str)
|
||||||
|
rows = _call_fetch()
|
||||||
|
save_records(table, rows)
|
||||||
|
return
|
||||||
|
for code in target_iter:
|
||||||
if not force and _should_skip_range(table, date_col, start, end, code):
|
if not force and _should_skip_range(table, date_col, start, end, code):
|
||||||
LOGGER.info("表 %s 股票 %s 已覆盖 %s-%s,跳过", table, code, start_str, end_str)
|
LOGGER.info("表 %s 股票 %s 已覆盖 %s-%s,跳过", table, code, start_str, end_str)
|
||||||
continue
|
continue
|
||||||
LOGGER.info("拉取 %s 表数据(股票:%s)%s-%s", table, code, start_str, end_str)
|
LOGGER.info("拉取 %s 表数据(股票:%s)%s-%s", table, code, start_str, end_str)
|
||||||
rows = fetch_fn(start, end, ts_code=code, skip_existing=not force)
|
rows = _call_fetch(code)
|
||||||
save_records(table, rows)
|
save_records(table, rows)
|
||||||
else:
|
return
|
||||||
needs_refresh = force or table == "suspend"
|
|
||||||
if not force and table != "suspend":
|
if not force:
|
||||||
expected = expected_days if table in {"daily_basic", "adj_factor", "stk_limit"} else 0
|
if table == "suspend":
|
||||||
needs_refresh = _range_needs_refresh(table, date_col, start_str, end_str, expected)
|
needs_refresh = True
|
||||||
|
else:
|
||||||
|
expected = expected_days if table in expected_tables else 0
|
||||||
|
needs_refresh = _range_needs_refresh(
|
||||||
|
table,
|
||||||
|
date_col,
|
||||||
|
start_str,
|
||||||
|
end_str,
|
||||||
|
expected_days=expected,
|
||||||
|
)
|
||||||
if not needs_refresh:
|
if not needs_refresh:
|
||||||
LOGGER.info("表 %s 已覆盖 %s-%s,跳过", table, start_str, end_str)
|
LOGGER.info("表 %s 已覆盖 %s-%s,跳过", table, start_str, end_str)
|
||||||
return
|
return
|
||||||
LOGGER.info("拉取 %s 表数据(全市场)%s-%s", table, start_str, end_str)
|
LOGGER.info("拉取 %s 表数据(全市场)%s-%s", table, start_str, end_str)
|
||||||
rows = fetch_fn(start, end, skip_existing=not force)
|
rows = _call_fetch()
|
||||||
save_records(table, rows)
|
save_records(table, rows)
|
||||||
|
|
||||||
advance("处理日线基础指标数据")
|
advance("处理日线基础指标数据")
|
||||||
_save_with_codes("daily_basic", fetch_daily_basic)
|
_save_with_codes("daily_basic", fetch_daily_basic)
|
||||||
@ -288,43 +331,19 @@ def ensure_data_coverage(
|
|||||||
save_records("fut_basic", fetch_fut_basic())
|
save_records("fut_basic", fetch_fut_basic())
|
||||||
|
|
||||||
advance("拉取指数行情数据")
|
advance("拉取指数行情数据")
|
||||||
for code in INDEX_CODES:
|
_save_with_codes("index_daily", fetch_index_daily, targets=INDEX_CODES)
|
||||||
if not force and _should_skip_range("index_daily", "trade_date", start, end, code):
|
|
||||||
LOGGER.info("指数 %s 已覆盖 %s-%s,跳过", code, start_str, end_str)
|
|
||||||
continue
|
|
||||||
save_records("index_daily", fetch_index_daily(start, end, code))
|
|
||||||
|
|
||||||
advance("拉取基金净值数据")
|
advance("拉取基金净值数据")
|
||||||
fund_targets = tuple(dict.fromkeys(ETF_CODES + FUND_CODES))
|
fund_targets = tuple(dict.fromkeys(ETF_CODES + FUND_CODES))
|
||||||
for code in fund_targets:
|
_save_with_codes("fund_nav", fetch_fund_nav, targets=fund_targets)
|
||||||
if not force and _should_skip_range("fund_nav", "nav_date", start, end, code):
|
|
||||||
LOGGER.info("基金 %s 净值已覆盖 %s-%s,跳过", code, start_str, end_str)
|
|
||||||
continue
|
|
||||||
save_records("fund_nav", fetch_fund_nav(start, end, code))
|
|
||||||
|
|
||||||
advance("拉取期货/外汇行情数据")
|
advance("拉取期货/外汇行情数据")
|
||||||
for code in FUTURE_CODES:
|
_save_with_codes("fut_daily", fetch_fut_daily, targets=FUTURE_CODES)
|
||||||
if not force and _should_skip_range("fut_daily", "trade_date", start, end, code):
|
_save_with_codes("fx_daily", fetch_fx_daily, targets=FX_CODES)
|
||||||
LOGGER.info("期货 %s 已覆盖 %s-%s,跳过", code, start_str, end_str)
|
|
||||||
continue
|
|
||||||
save_records("fut_daily", fetch_fut_daily(start, end, code))
|
|
||||||
for code in FX_CODES:
|
|
||||||
if not force and _should_skip_range("fx_daily", "trade_date", start, end, code):
|
|
||||||
LOGGER.info("外汇 %s 已覆盖 %s-%s,跳过", code, start_str, end_str)
|
|
||||||
continue
|
|
||||||
save_records("fx_daily", fetch_fx_daily(start, end, code))
|
|
||||||
|
|
||||||
advance("拉取港/美股行情数据(已暂时关闭)")
|
advance("拉取港/美股行情数据(已暂时关闭)")
|
||||||
for code in HK_CODES:
|
_save_with_codes("hk_daily", fetch_hk_daily, targets=HK_CODES)
|
||||||
if not force and _should_skip_range("hk_daily", "trade_date", start, end, code):
|
_save_with_codes("us_daily", fetch_us_daily, targets=US_CODES)
|
||||||
LOGGER.info("港股 %s 已覆盖 %s-%s,跳过", code, start_str, end_str)
|
|
||||||
continue
|
|
||||||
save_records("hk_daily", fetch_hk_daily(start, end, code))
|
|
||||||
for code in US_CODES:
|
|
||||||
if not force and _should_skip_range("us_daily", "trade_date", start, end, code):
|
|
||||||
LOGGER.info("美股 %s 已覆盖 %s-%s,跳过", code, start_str, end_str)
|
|
||||||
continue
|
|
||||||
save_records("us_daily", fetch_us_daily(start, end, code))
|
|
||||||
|
|
||||||
if progress_hook:
|
if progress_hook:
|
||||||
progress_hook("数据覆盖检查完成", 1.0)
|
progress_hook("数据覆盖检查完成", 1.0)
|
||||||
|
|||||||
@ -14,12 +14,14 @@ def get_connection(read_only: bool = False) -> sqlite3.Connection:
|
|||||||
|
|
||||||
db_path: Path = get_config().data_paths.database
|
db_path: Path = get_config().data_paths.database
|
||||||
uri = f"file:{db_path}?mode={'ro' if read_only else 'rw'}"
|
uri = f"file:{db_path}?mode={'ro' if read_only else 'rw'}"
|
||||||
|
connect_kwargs = {"timeout": 30.0}
|
||||||
if not db_path.exists() and not read_only:
|
if not db_path.exists() and not read_only:
|
||||||
# Ensure directory exists before first write.
|
# Ensure directory exists before first write.
|
||||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(db_path, **connect_kwargs)
|
||||||
else:
|
else:
|
||||||
conn = sqlite3.connect(uri, uri=True)
|
conn = sqlite3.connect(uri, uri=True, **connect_kwargs)
|
||||||
|
conn.execute("PRAGMA busy_timeout = 30000")
|
||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user