sam 2025-09-27 11:59:18 +08:00
parent 774b68de99
commit 6c9c8e3140
4 changed files with 174 additions and 76 deletions

View File

@@ -1,15 +1,16 @@
 """Boot-time data coverage checker."""
 from __future__ import annotations
-import logging
 from dataclasses import dataclass
 from datetime import date, timedelta
 from typing import Callable, Dict
 from app.data.schema import initialize_database
 from app.ingest.tushare import collect_data_coverage, ensure_data_coverage
+from app.utils.config import get_config
+from app.utils.logging import get_logger
-LOGGER = logging.getLogger(__name__)
+LOGGER = get_logger(__name__)
 @dataclass
@@ -40,6 +41,7 @@ def run_boot_check(
     days: int = 365,
     auto_fetch: bool = True,
     progress_hook: Callable[[str, float], None] | None = None,
+    force_refresh: bool | None = None,
 ) -> CoverageReport:
     """Run the boot self-check and backfill data when needed."""
@@ -47,8 +49,17 @@
     start, end = _default_window(days)
     LOGGER.info("Boot-check coverage window: %s-%s", start, end)
+    refresh = force_refresh
+    if refresh is None:
+        refresh = get_config().force_refresh
     if auto_fetch:
-        ensure_data_coverage(start, end, progress_hook=progress_hook)
+        ensure_data_coverage(
+            start,
+            end,
+            force=refresh,
+            progress_hook=progress_hook,
+        )
     coverage = collect_data_coverage(start, end)
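The new force_refresh parameter resolves in two steps: an explicit argument wins, and None falls back to the session config flag. A minimal self-contained sketch of that resolution (the config shape is assumed, not taken from this diff):

    from dataclasses import dataclass

    @dataclass
    class _Cfg:
        force_refresh: bool = False  # mirrors the new AppConfig field below

    def resolve_refresh(force_refresh: bool | None, cfg: _Cfg) -> bool:
        # Explicit argument wins; None falls back to the session-wide flag.
        return cfg.force_refresh if force_refresh is None else force_refresh

    assert resolve_refresh(None, _Cfg()) is False                    # default: incremental mode
    assert resolve_refresh(True, _Cfg()) is True                     # caller forces a re-fetch
    assert resolve_refresh(None, _Cfg(force_refresh=True)) is True   # config opts in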

View File

@@ -1,7 +1,6 @@
 """TuShare data-fetching and coverage-check utilities."""
 from __future__ import annotations
-import logging
 import os
 from dataclasses import dataclass
 from datetime import date
@@ -16,9 +15,21 @@ except ImportError:  # pragma: no cover - runtime hint
 from app.utils.config import get_config
 from app.utils.db import db_session
+from app.data.schema import initialize_database
+from app.utils.logging import get_logger
+LOGGER = get_logger(__name__)
+API_DEFAULT_LIMIT = 5000
+LOG_EXTRA = {"stage": "data_ingest"}
-def _existing_date_range(table: str, date_col: str, ts_code: str | None = None) -> Tuple[str | None, str | None]:
+def _existing_date_range(
+    table: str,
+    date_col: str,
+    ts_code: str | None = None,
+) -> Tuple[str | None, str | None]:
     query = f"SELECT MIN({date_col}) AS min_d, MAX({date_col}) AS max_d FROM {table}"
     params: Tuple = ()
     if ts_code:
@@ -31,13 +42,11 @@ def _existing_date_range(
     return row["min_d"], row["max_d"]
-from app.data.schema import initialize_database
-LOGGER = logging.getLogger(__name__)
-API_DEFAULT_LIMIT = 5000
-LOG_EXTRA = {"stage": "data_ingest"}
+def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]:
+    if df is None or df.empty:
+        return []
+    reindexed = df.reindex(columns=allowed_cols)
+    return reindexed.where(pd.notnull(reindexed), None).to_dict("records")
 def _fetch_paginated(endpoint: str, params: Dict[str, object], limit: int | None = None) -> pd.DataFrame:
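The relocated _df_to_records normalizes a frame to a table's column list and swaps NaN for None so rows bind cleanly as SQL parameters. A quick illustration of the same two pandas calls (values invented):

    import pandas as pd

    df = pd.DataFrame({"ts_code": ["000001.SZ"], "close": [float("nan")]})
    # Reindex to the allowed columns (missing ones appear as NaN) ...
    reindexed = df.reindex(columns=["ts_code", "close", "open"])
    # ... then replace every NaN with None before building records.
    print(reindexed.where(pd.notnull(reindexed), None).to_dict("records"))
    # [{'ts_code': '000001.SZ', 'close': None, 'open': None}]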
@@ -294,13 +303,6 @@ def _format_date(value: date) -> str:
     return value.strftime("%Y%m%d")
-def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]:
-    if df is None or df.empty:
-        return []
-    reindexed = df.reindex(columns=allowed_cols)
-    return reindexed.where(pd.notnull(reindexed), None).to_dict("records")
 def _load_trade_dates(start: date, end: date, exchange: str = "SSE") -> List[str]:
     start_str = _format_date(start)
     end_str = _format_date(end)
@@ -405,34 +407,66 @@ def fetch_stock_basic(exchange: Optional[str] = None, list_status: str = "L") ->
     return _df_to_records(df, _TABLE_COLUMNS["stock_basic"])
-def fetch_daily_bars(job: FetchJob) -> Iterable[Dict]:
+def fetch_daily_bars(job: FetchJob, skip_existing: bool = True) -> Iterable[Dict]:
     client = _ensure_client()
-    start_date = _format_date(job.start)
-    end_date = _format_date(job.end)
     frames: List[pd.DataFrame] = []
     if job.granularity != "daily":
         raise ValueError(f"Unsupported granularity: {job.granularity}")
-    params = {
-        "start_date": start_date,
-        "end_date": end_date,
-    }
+    trade_dates = _load_trade_dates(job.start, job.end)
+    if not trade_dates:
+        LOGGER.info("Local trade calendar missing; backfilling before fetching daily bars", extra=LOG_EXTRA)
+        ensure_trade_calendar(job.start, job.end)
+        trade_dates = _load_trade_dates(job.start, job.end)
     if job.ts_codes:
         for code in job.ts_codes:
-            LOGGER.info("Fetching daily bars for %s (%s-%s)", code, start_date, end_date)
-            df = _fetch_paginated("daily", {**params, "ts_code": code})
-            if not df.empty:
-                frames.append(df)
+            for trade_date in trade_dates:
+                if skip_existing and _record_exists("daily", "trade_date", trade_date, code):
+                    LOGGER.debug(
+                        "Daily bars already exist; skipping %s %s",
+                        code,
+                        trade_date,
+                        extra=LOG_EXTRA,
+                    )
+                    continue
+                LOGGER.debug(
+                    "Fetching daily bars by trade date: code=%s trade_date=%s",
+                    code,
+                    trade_date,
+                    extra=LOG_EXTRA,
+                )
+                LOGGER.info(
+                    "Trade-date fetch request: endpoint=daily code=%s trade_date=%s",
+                    code,
+                    trade_date,
+                    extra=LOG_EXTRA,
+                )
+                df = _fetch_paginated(
+                    "daily",
+                    {
+                        "trade_date": trade_date,
+                        "ts_code": code,
+                    },
+                )
+                if not df.empty:
+                    frames.append(df)
     else:
-        trade_dates = _load_trade_dates(job.start, job.end)
-        if not trade_dates:
-            LOGGER.info("Local trade calendar missing; backfilling before fetching daily bars")
-            ensure_trade_calendar(job.start, job.end)
-            trade_dates = _load_trade_dates(job.start, job.end)
         for trade_date in trade_dates:
-            LOGGER.debug("Fetching daily bars by trade date: %s", trade_date)
+            if skip_existing and _record_exists("daily", "trade_date", trade_date):
+                LOGGER.debug(
+                    "Daily bars already exist; skipping trade date %s",
+                    trade_date,
+                    extra=LOG_EXTRA,
+                )
+                continue
+            LOGGER.debug("Fetching daily bars by trade date: %s", trade_date, extra=LOG_EXTRA)
+            LOGGER.info(
+                "Trade-date fetch request: endpoint=daily trade_date=%s",
+                trade_date,
+                extra=LOG_EXTRA,
+            )
             df = _fetch_paginated("daily", {"trade_date": trade_date})
             if not df.empty:
                 frames.append(df)
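The new skip logic leans on _record_exists, which the diff calls but never shows. A plausible shape, assuming db_session yields a DB-API connection over SQLite (hypothetical, for orientation only):

    def _record_exists(table: str, date_col: str, trade_date: str,
                       ts_code: str | None = None) -> bool:
        # Hypothetical probe: is there any row for this trade date (and code)?
        query = f"SELECT 1 FROM {table} WHERE {date_col} = ?"
        params: tuple = (trade_date,)
        if ts_code:
            query += " AND ts_code = ?"
            params = (trade_date, ts_code)
        with db_session() as conn:
            return conn.execute(query + " LIMIT 1", params).fetchone() is not None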
@@ -460,33 +494,25 @@ def fetch_daily_basic(
         extra=LOG_EXTRA,
     )
-    if ts_code:
-        df = _fetch_paginated(
-            "daily_basic",
-            {
-                "ts_code": ts_code,
-                "start_date": start_date,
-                "end_date": end_date,
-            },
-        )
-        return _df_to_records(df, _TABLE_COLUMNS["daily_basic"])
     trade_dates = _load_trade_dates(start, end)
     frames: List[pd.DataFrame] = []
     for trade_date in trade_dates:
-        if skip_existing and _record_exists("daily_basic", "trade_date", trade_date):
+        if skip_existing and _record_exists("daily_basic", "trade_date", trade_date, ts_code):
             LOGGER.info(
                 "Daily basic indicators already exist; skipping trade date %s",
                 trade_date,
                 extra=LOG_EXTRA,
             )
             continue
-        LOGGER.debug(
-            "Fetching daily basic indicators by trade date: %s",
-            trade_date,
+        params = {"trade_date": trade_date}
+        if ts_code:
+            params["ts_code"] = ts_code
+        LOGGER.info(
+            "Trade-date fetch request: endpoint=daily_basic params=%s",
+            params,
             extra=LOG_EXTRA,
         )
-        df = _fetch_paginated("daily_basic", {"trade_date": trade_date})
+        df = _fetch_paginated("daily_basic", params)
         if not df.empty:
             frames.append(df)
@@ -528,7 +554,7 @@ def fetch_adj_factor(
         params = {"trade_date": trade_date}
         if ts_code:
             params["ts_code"] = ts_code
-        LOGGER.debug("Fetching adjustment factors by trade date: %s", params, extra=LOG_EXTRA)
+        LOGGER.info("Trade-date fetch request: endpoint=adj_factor params=%s", params, extra=LOG_EXTRA)
         df = _fetch_paginated("adj_factor", params)
         if not df.empty:
             frames.append(df)
@@ -540,17 +566,40 @@ def fetch_adj_factor(
     return _df_to_records(merged, _TABLE_COLUMNS["adj_factor"])
-def fetch_suspensions(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
+def fetch_suspensions(
+    start: date,
+    end: date,
+    ts_code: Optional[str] = None,
+    skip_existing: bool = True,
+) -> Iterable[Dict]:
     client = _ensure_client()
     start_date = _format_date(start)
     end_date = _format_date(end)
-    LOGGER.info("Fetching suspension data (%s-%s)", start_date, end_date)
-    df = _fetch_paginated("suspend_d", {
-        "ts_code": ts_code,
-        "start_date": start_date,
-        "end_date": end_date,
-    }, limit=2000)
-    return _df_to_records(df, _TABLE_COLUMNS["suspend"])
+    LOGGER.info("Fetching suspension data (%s-%s)", start_date, end_date, extra=LOG_EXTRA)
+    trade_dates = _load_trade_dates(start, end)
+    frames: List[pd.DataFrame] = []
+    for trade_date in trade_dates:
+        if skip_existing and _record_exists("suspend", "suspend_date", trade_date, ts_code):
+            LOGGER.debug(
+                "Suspension data already exist; skipping %s %s",
+                ts_code or "ALL",
+                trade_date,
+                extra=LOG_EXTRA,
+            )
+            continue
+        params = {"trade_date": trade_date}
+        if ts_code:
+            params["ts_code"] = ts_code
+        LOGGER.info("Trade-date fetch request: endpoint=suspend_d params=%s", params, extra=LOG_EXTRA)
+        df = _fetch_paginated("suspend_d", params, limit=2000)
+        if not df.empty:
+            frames.append(df)
+    if not frames:
+        return []
+    merged = pd.concat(frames, ignore_index=True)
+    return _df_to_records(merged, _TABLE_COLUMNS["suspend"])
 def fetch_trade_calendar(start: date, end: date, exchange: str = "SSE") -> Iterable[Dict]:
@@ -562,17 +611,40 @@ def fetch_trade_calendar(start: date, end: date, exchange: str = "SSE") -> Itera
     return _df_to_records(df, _TABLE_COLUMNS["trade_calendar"])
-def fetch_stk_limit(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
+def fetch_stk_limit(
+    start: date,
+    end: date,
+    ts_code: Optional[str] = None,
+    skip_existing: bool = True,
+) -> Iterable[Dict]:
     client = _ensure_client()
     start_date = _format_date(start)
     end_date = _format_date(end)
-    LOGGER.info("Fetching limit-up/limit-down prices (%s-%s)", start_date, end_date)
-    df = _fetch_paginated("stk_limit", {
-        "ts_code": ts_code,
-        "start_date": start_date,
-        "end_date": end_date,
-    })
-    return _df_to_records(df, _TABLE_COLUMNS["stk_limit"])
+    LOGGER.info("Fetching limit-up/limit-down prices (%s-%s)", start_date, end_date, extra=LOG_EXTRA)
+    trade_dates = _load_trade_dates(start, end)
+    frames: List[pd.DataFrame] = []
+    for trade_date in trade_dates:
+        if skip_existing and _record_exists("stk_limit", "trade_date", trade_date, ts_code):
+            LOGGER.debug(
+                "Limit price data already exist; skipping %s %s",
+                ts_code or "ALL",
+                trade_date,
+                extra=LOG_EXTRA,
+            )
+            continue
+        params = {"trade_date": trade_date}
+        if ts_code:
+            params["ts_code"] = ts_code
+        LOGGER.info("Trade-date fetch request: endpoint=stk_limit params=%s", params, extra=LOG_EXTRA)
+        df = _fetch_paginated("stk_limit", params)
+        if not df.empty:
+            frames.append(df)
+    if not frames:
+        return []
+    merged = pd.concat(frames, ignore_index=True)
+    return _df_to_records(merged, _TABLE_COLUMNS["stk_limit"])
 def save_records(table: str, rows: Iterable[Dict]) -> None:
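fetch_suspensions and fetch_stk_limit now repeat the same per-trade-date loop as fetch_daily_basic. One way to fold the three together is a shared helper along these lines (a sketch built on the module's own helpers; not part of this commit):

    def _fetch_by_trade_date(endpoint: str, table: str, date_col: str,
                             start: date, end: date, ts_code: str | None = None,
                             skip_existing: bool = True, limit: int | None = None):
        frames = []
        for trade_date in _load_trade_dates(start, end):
            if skip_existing and _record_exists(table, date_col, trade_date, ts_code):
                continue  # incremental mode: keep rows already stored
            params = {"trade_date": trade_date}
            if ts_code:
                params["ts_code"] = ts_code
            df = _fetch_paginated(endpoint, params, limit=limit)
            if not df.empty:
                frames.append(df)
        if not frames:
            return []
        return _df_to_records(pd.concat(frames, ignore_index=True), _TABLE_COLUMNS[table])

With that helper, fetch_suspensions would reduce to a delegation with endpoint="suspend_d", table="suspend", date_col="suspend_date", limit=2000.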
@@ -662,7 +734,7 @@ def ensure_data_coverage(
     if pending_codes:
         job = FetchJob("daily_autofill", start=start, end=end, ts_codes=tuple(pending_codes))
         LOGGER.info("Start fetching daily bars: %s-%s (%d stocks to backfill)", start_str, end_str, len(pending_codes))
-        save_records("daily", fetch_daily_bars(job))
+        save_records("daily", fetch_daily_bars(job, skip_existing=not force))
     else:
         needs_daily = force or _range_needs_refresh("daily", "trade_date", start_str, end_str, expected_days)
         if not needs_daily:
@@ -670,7 +742,7 @@ def ensure_data_coverage(
         else:
             job = FetchJob("daily_autofill", start=start, end=end)
             LOGGER.info("Start fetching daily bars: %s-%s", start_str, end_str)
-            save_records("daily", fetch_daily_bars(job))
+            save_records("daily", fetch_daily_bars(job, skip_existing=not force))
     date_cols = {
         "daily_basic": "trade_date",
@@ -689,7 +761,7 @@ def ensure_data_coverage(
         LOGGER.info("Fetching %s table data (stock: %s) %s-%s", table, code, start_str, end_str)
         try:
             kwargs = {"ts_code": code}
-            if fetch_fn in (fetch_daily_basic, fetch_adj_factor):
+            if fetch_fn in (fetch_daily_basic, fetch_adj_factor, fetch_suspensions, fetch_stk_limit):
                 kwargs["skip_existing"] = not force
             rows = fetch_fn(start, end, **kwargs)
         except Exception:
@@ -707,7 +779,7 @@ def ensure_data_coverage(
         LOGGER.info("Fetching %s table data (whole market) %s-%s", table, start_str, end_str)
         try:
             kwargs = {}
-            if fetch_fn in (fetch_daily_basic, fetch_adj_factor):
+            if fetch_fn in (fetch_daily_basic, fetch_adj_factor, fetch_suspensions, fetch_stk_limit):
                 kwargs["skip_existing"] = not force
             rows = fetch_fn(start, end, **kwargs)
         except Exception:
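End to end, force now reaches every fetcher as skip_existing=not force, so a forced run re-downloads each trade date while the default path stays incremental. A hedged usage sketch (dates illustrative):

    from datetime import date

    # force=True disables all incremental skips downstream.
    ensure_data_coverage(date(2024, 1, 1), date(2024, 12, 31), force=True)
    # force=False (the default path) skips any trade date already stored.
    ensure_data_coverage(date(2024, 1, 1), date(2024, 12, 31), force=False)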

View File

@@ -114,9 +114,10 @@ def render_settings() -> None:
     st.header("Data & Settings")
     cfg = get_config()
     token = st.text_input("TuShare Token", value=cfg.tushare_token or "", type="password")
-    if st.button("Save Token"):
+    if st.button("Save Settings"):
         cfg.tushare_token = token.strip() or None
-        st.success("TuShare Token updated; kept for the current session only")
+        st.success("Settings saved; effective for the current session only")
     st.write("News source toggles and database backups will be configured here.")
@@ -155,6 +156,15 @@ def render_tests() -> None:
     st.divider()
     days = int(st.number_input("Check window (days)", min_value=30, max_value=1095, value=365, step=30))
+    cfg = get_config()
+    force_refresh = st.checkbox(
+        "Force data refresh (disable incremental skip)",
+        value=cfg.force_refresh,
+        help="When checked, all data in the selected window is re-fetched",
+    )
+    if force_refresh != cfg.force_refresh:
+        cfg.force_refresh = force_refresh
     if st.button("Run boot check"):
         progress_bar = st.progress(0.0)
         status_placeholder = st.empty()
@@ -168,7 +178,11 @@ def render_tests() -> None:
         with st.spinner("Running boot check..."):
             try:
-                report = run_boot_check(days=days, progress_hook=hook)
+                report = run_boot_check(
+                    days=days,
+                    progress_hook=hook,
+                    force_refresh=force_refresh,
+                )
                 st.success("Boot check complete; data-coverage summary below.")
                 st.json(report.to_dict())
                 if messages:
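This hunk references hook, messages, progress_bar, and status_placeholder without showing their setup. One plausible wiring that matches the Callable[[str, float], None] hook signature from run_boot_check (hypothetical, not shown in this commit):

    messages: list[str] = []

    def hook(stage: str, fraction: float) -> None:
        # Clamp to the 0.0-1.0 range st.progress accepts, then surface progress.
        progress_bar.progress(min(max(fraction, 0.0), 1.0))
        status_placeholder.text(stage)
        messages.append(f"{stage}: {fraction:.0%}")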

View File

@@ -54,6 +54,7 @@ class AppConfig:
     decision_method: str = "nash"
     data_paths: DataPaths = field(default_factory=DataPaths)
     agent_weights: AgentWeights = field(default_factory=AgentWeights)
+    force_refresh: bool = False
 CONFIG = AppConfig()
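The flag defaults to False, so existing callers keep incremental behavior until the settings page flips it. Assumed interplay with the boot check (get_config presumably returns this module-level CONFIG):

    cfg = get_config()
    cfg.force_refresh = True    # opt this session into full re-fetches
    run_boot_check(days=365)    # force_refresh=None falls back to cfg.force_refresh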