sam 2025-09-27 11:59:18 +08:00
parent 774b68de99
commit 6c9c8e3140
4 changed files with 174 additions and 76 deletions

View File

@@ -1,15 +1,16 @@
"""数据覆盖开机检查器。"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from datetime import date, timedelta
from typing import Callable, Dict
from app.data.schema import initialize_database
from app.ingest.tushare import collect_data_coverage, ensure_data_coverage
from app.utils.config import get_config
from app.utils.logging import get_logger
LOGGER = logging.getLogger(__name__)
LOGGER = get_logger(__name__)
@dataclass
@@ -40,6 +41,7 @@ def run_boot_check(
days: int = 365,
auto_fetch: bool = True,
progress_hook: Callable[[str, float], None] | None = None,
force_refresh: bool | None = None,
) -> CoverageReport:
"""执行开机自检,必要时自动补数据。"""
@@ -47,8 +49,17 @@ def run_boot_check(
start, end = _default_window(days)
LOGGER.info("开机检查覆盖窗口:%s%s", start, end)
refresh = force_refresh
if refresh is None:
refresh = get_config().force_refresh
if auto_fetch:
ensure_data_coverage(start, end, progress_hook=progress_hook)
ensure_data_coverage(
start,
end,
force=refresh,
progress_hook=progress_hook,
)
coverage = collect_data_coverage(start, end)
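
Taken together, the new `force_refresh` parameter resolves in three steps: an explicit argument wins, `None` defers to the config default, and the result is passed as `force` into `ensure_data_coverage`. A minimal standalone sketch of that resolution (illustrative only; only the names used here appear in the diff):

    # Illustrative sketch of the force_refresh fallback added above:
    # an explicit bool wins; None defers to the shared config default.
    from app.utils.config import get_config

    def _resolve_force_refresh(force_refresh: bool | None) -> bool:
        if force_refresh is None:
            return get_config().force_refresh
        return force_refresh

    # run_boot_check(days=365)                      -> uses the config default
    # run_boot_check(days=365, force_refresh=True)  -> always re-fetches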

View File

@@ -1,7 +1,6 @@
"""TuShare 数据拉取与数据覆盖检查工具。"""
from __future__ import annotations
import logging
import os
from dataclasses import dataclass
from datetime import date
@@ -16,9 +15,21 @@ except ImportError:  # pragma: no cover - runtime hint
from app.utils.config import get_config
from app.utils.db import db_session
from app.data.schema import initialize_database
from app.utils.logging import get_logger
def _existing_date_range(table: str, date_col: str, ts_code: str | None = None) -> Tuple[str | None, str | None]:
LOGGER = get_logger(__name__)
API_DEFAULT_LIMIT = 5000
LOG_EXTRA = {"stage": "data_ingest"}
def _existing_date_range(
table: str,
date_col: str,
ts_code: str | None = None,
) -> Tuple[str | None, str | None]:
query = f"SELECT MIN({date_col}) AS min_d, MAX({date_col}) AS max_d FROM {table}"
params: Tuple = ()
if ts_code:
@@ -31,13 +42,11 @@ def _existing_date_range(table: str, date_col: str, ts_code: str | None = None) -> Tuple[str | None, str | None]:
return row["min_d"], row["max_d"]
from app.data.schema import initialize_database
LOGGER = logging.getLogger(__name__)
API_DEFAULT_LIMIT = 5000
LOG_EXTRA = {"stage": "data_ingest"}
def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]:
if df is None or df.empty:
return []
reindexed = df.reindex(columns=allowed_cols)
return reindexed.where(pd.notnull(reindexed), None).to_dict("records")
def _fetch_paginated(endpoint: str, params: Dict[str, object], limit: int | None = None) -> pd.DataFrame:
@@ -294,13 +303,6 @@ def _format_date(value: date) -> str:
return value.strftime("%Y%m%d")
def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]:
if df is None or df.empty:
return []
reindexed = df.reindex(columns=allowed_cols)
return reindexed.where(pd.notnull(reindexed), None).to_dict("records")
def _load_trade_dates(start: date, end: date, exchange: str = "SSE") -> List[str]:
start_str = _format_date(start)
end_str = _format_date(end)
@@ -405,34 +407,66 @@ def fetch_stock_basic(exchange: Optional[str] = None, list_status: str = "L") -> Iterable[Dict]:
return _df_to_records(df, _TABLE_COLUMNS["stock_basic"])
def fetch_daily_bars(job: FetchJob) -> Iterable[Dict]:
def fetch_daily_bars(job: FetchJob, skip_existing: bool = True) -> Iterable[Dict]:
client = _ensure_client()
start_date = _format_date(job.start)
end_date = _format_date(job.end)
frames: List[pd.DataFrame] = []
if job.granularity != "daily":
raise ValueError(f"暂不支持的粒度:{job.granularity}")
params = {
"start_date": start_date,
"end_date": end_date,
}
trade_dates = _load_trade_dates(job.start, job.end)
if not trade_dates:
LOGGER.info("本地交易日历缺失,尝试补全后再拉取日线行情", extra=LOG_EXTRA)
ensure_trade_calendar(job.start, job.end)
trade_dates = _load_trade_dates(job.start, job.end)
if job.ts_codes:
for code in job.ts_codes:
LOGGER.info("拉取 %s 的日线行情(%s-%s", code, start_date, end_date)
df = _fetch_paginated("daily", {**params, "ts_code": code})
if not df.empty:
frames.append(df)
for trade_date in trade_dates:
if skip_existing and _record_exists("daily", "trade_date", trade_date, code):
LOGGER.debug(
"日线数据已存在,跳过 %s %s",
code,
trade_date,
extra=LOG_EXTRA,
)
continue
LOGGER.debug(
"按交易日拉取日线行情code=%s trade_date=%s",
code,
trade_date,
extra=LOG_EXTRA,
)
LOGGER.info(
"交易日拉取请求endpoint=daily code=%s trade_date=%s",
code,
trade_date,
extra=LOG_EXTRA,
)
df = _fetch_paginated(
"daily",
{
"trade_date": trade_date,
"ts_code": code,
},
)
if not df.empty:
frames.append(df)
else:
trade_dates = _load_trade_dates(job.start, job.end)
if not trade_dates:
LOGGER.info("本地交易日历缺失,尝试补全后再拉取日线行情")
ensure_trade_calendar(job.start, job.end)
trade_dates = _load_trade_dates(job.start, job.end)
for trade_date in trade_dates:
LOGGER.debug("按交易日拉取日线行情:%s", trade_date)
if skip_existing and _record_exists("daily", "trade_date", trade_date):
LOGGER.debug(
"日线数据已存在,跳过交易日 %s",
trade_date,
extra=LOG_EXTRA,
)
continue
LOGGER.debug("按交易日拉取日线行情:%s", trade_date, extra=LOG_EXTRA)
LOGGER.info(
"交易日拉取请求endpoint=daily trade_date=%s",
trade_date,
extra=LOG_EXTRA,
)
df = _fetch_paginated("daily", {"trade_date": trade_date})
if not df.empty:
frames.append(df)
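
The skip checks above all delegate to `_record_exists`, whose body is not part of this diff. A plausible sketch of what such a helper does, inferred from the call sites (the SQL and the `db_session` usage are assumptions, not the actual implementation):

    # Hypothetical sketch of _record_exists (the real helper lives elsewhere
    # in this module): probe whether any row exists for a trade date / stock.
    def _record_exists(
        table: str,
        date_col: str,
        trade_date: str,
        ts_code: str | None = None,
    ) -> bool:
        query = f"SELECT 1 FROM {table} WHERE {date_col} = ?"
        params: tuple = (trade_date,)
        if ts_code:
            query += " AND ts_code = ?"
            params = (trade_date, ts_code)
        with db_session() as conn:  # db_session imported from app.utils.db
            return conn.execute(f"{query} LIMIT 1", params).fetchone() is not None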
@@ -460,33 +494,25 @@ def fetch_daily_basic(
extra=LOG_EXTRA,
)
if ts_code:
df = _fetch_paginated(
"daily_basic",
{
"ts_code": ts_code,
"start_date": start_date,
"end_date": end_date,
},
)
return _df_to_records(df, _TABLE_COLUMNS["daily_basic"])
trade_dates = _load_trade_dates(start, end)
frames: List[pd.DataFrame] = []
for trade_date in trade_dates:
if skip_existing and _record_exists("daily_basic", "trade_date", trade_date):
if skip_existing and _record_exists("daily_basic", "trade_date", trade_date, ts_code):
LOGGER.info(
"日线基础指标已存在,跳过交易日 %s",
trade_date,
extra=LOG_EXTRA,
)
continue
LOGGER.debug(
"按交易日拉取日线基础指标:%s",
trade_date,
params = {"trade_date": trade_date}
if ts_code:
params["ts_code"] = ts_code
LOGGER.info(
"交易日拉取请求endpoint=daily_basic params=%s",
params,
extra=LOG_EXTRA,
)
df = _fetch_paginated("daily_basic", {"trade_date": trade_date})
df = _fetch_paginated("daily_basic", params)
if not df.empty:
frames.append(df)
@@ -528,7 +554,7 @@ def fetch_adj_factor(
params = {"trade_date": trade_date}
if ts_code:
params["ts_code"] = ts_code
LOGGER.debug("按交易日拉取复权因子:%s", params, extra=LOG_EXTRA)
LOGGER.info("交易日拉取请求endpoint=adj_factor params=%s", params, extra=LOG_EXTRA)
df = _fetch_paginated("adj_factor", params)
if not df.empty:
frames.append(df)
@@ -540,17 +566,40 @@ def fetch_adj_factor(
return _df_to_records(merged, _TABLE_COLUMNS["adj_factor"])
def fetch_suspensions(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
def fetch_suspensions(
start: date,
end: date,
ts_code: Optional[str] = None,
skip_existing: bool = True,
) -> Iterable[Dict]:
client = _ensure_client()
start_date = _format_date(start)
end_date = _format_date(end)
LOGGER.info("拉取停复牌信息(%s-%s", start_date, end_date)
df = _fetch_paginated("suspend_d", {
"ts_code": ts_code,
"start_date": start_date,
"end_date": end_date,
}, limit=2000)
return _df_to_records(df, _TABLE_COLUMNS["suspend"])
LOGGER.info("拉取停复牌信息(%s-%s", start_date, end_date, extra=LOG_EXTRA)
trade_dates = _load_trade_dates(start, end)
frames: List[pd.DataFrame] = []
for trade_date in trade_dates:
if skip_existing and _record_exists("suspend", "suspend_date", trade_date, ts_code):
LOGGER.debug(
"停复牌信息已存在,跳过 %s %s",
ts_code or "ALL",
trade_date,
extra=LOG_EXTRA,
)
continue
params = {"trade_date": trade_date}
if ts_code:
params["ts_code"] = ts_code
LOGGER.info("交易日拉取请求endpoint=suspend_d params=%s", params, extra=LOG_EXTRA)
df = _fetch_paginated("suspend_d", params, limit=2000)
if not df.empty:
frames.append(df)
if not frames:
return []
merged = pd.concat(frames, ignore_index=True)
return _df_to_records(merged, _TABLE_COLUMNS["suspend"])
def fetch_trade_calendar(start: date, end: date, exchange: str = "SSE") -> Iterable[Dict]:
@@ -562,17 +611,40 @@ def fetch_trade_calendar(start: date, end: date, exchange: str = "SSE") -> Iterable[Dict]:
return _df_to_records(df, _TABLE_COLUMNS["trade_calendar"])
def fetch_stk_limit(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
def fetch_stk_limit(
start: date,
end: date,
ts_code: Optional[str] = None,
skip_existing: bool = True,
) -> Iterable[Dict]:
client = _ensure_client()
start_date = _format_date(start)
end_date = _format_date(end)
LOGGER.info("拉取涨跌停价格(%s-%s", start_date, end_date)
df = _fetch_paginated("stk_limit", {
"ts_code": ts_code,
"start_date": start_date,
"end_date": end_date,
})
return _df_to_records(df, _TABLE_COLUMNS["stk_limit"])
LOGGER.info("拉取涨跌停价格(%s-%s", start_date, end_date, extra=LOG_EXTRA)
trade_dates = _load_trade_dates(start, end)
frames: List[pd.DataFrame] = []
for trade_date in trade_dates:
if skip_existing and _record_exists("stk_limit", "trade_date", trade_date, ts_code):
LOGGER.debug(
"涨跌停数据已存在,跳过 %s %s",
ts_code or "ALL",
trade_date,
extra=LOG_EXTRA,
)
continue
params = {"trade_date": trade_date}
if ts_code:
params["ts_code"] = ts_code
LOGGER.info("交易日拉取请求endpoint=stk_limit params=%s", params, extra=LOG_EXTRA)
df = _fetch_paginated("stk_limit", params)
if not df.empty:
frames.append(df)
if not frames:
return []
merged = pd.concat(frames, ignore_index=True)
return _df_to_records(merged, _TABLE_COLUMNS["stk_limit"])
def save_records(table: str, rows: Iterable[Dict]) -> None:
@@ -662,7 +734,7 @@ def ensure_data_coverage(
if pending_codes:
job = FetchJob("daily_autofill", start=start, end=end, ts_codes=tuple(pending_codes))
LOGGER.info("开始拉取日线行情:%s-%s(待补股票 %d 支)", start_str, end_str, len(pending_codes))
save_records("daily", fetch_daily_bars(job))
save_records("daily", fetch_daily_bars(job, skip_existing=not force))
else:
needs_daily = force or _range_needs_refresh("daily", "trade_date", start_str, end_str, expected_days)
if not needs_daily:
@@ -670,7 +742,7 @@
else:
job = FetchJob("daily_autofill", start=start, end=end)
LOGGER.info("开始拉取日线行情:%s-%s", start_str, end_str)
save_records("daily", fetch_daily_bars(job))
save_records("daily", fetch_daily_bars(job, skip_existing=not force))
date_cols = {
"daily_basic": "trade_date",
@@ -689,7 +761,7 @@
LOGGER.info("拉取 %s 表数据(股票:%s%s-%s", table, code, start_str, end_str)
try:
kwargs = {"ts_code": code}
if fetch_fn in (fetch_daily_basic, fetch_adj_factor):
if fetch_fn in (fetch_daily_basic, fetch_adj_factor, fetch_suspensions, fetch_stk_limit):
kwargs["skip_existing"] = not force
rows = fetch_fn(start, end, **kwargs)
except Exception:
@@ -707,7 +779,7 @@
LOGGER.info("拉取 %s 表数据(全市场)%s-%s", table, start_str, end_str)
try:
kwargs = {}
if fetch_fn in (fetch_daily_basic, fetch_adj_factor):
if fetch_fn in (fetch_daily_basic, fetch_adj_factor, fetch_suspensions, fetch_stk_limit):
kwargs["skip_existing"] = not force
rows = fetch_fn(start, end, **kwargs)
except Exception:
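
Both call sites above use the same dispatch pattern: only fetchers that accept `skip_existing` receive it, and `force=True` turns the incremental skip off. A condensed illustration (the `_fetch_rows` helper name is invented; the fetcher tuple and the `skip_existing` rule match the diff):

    # Condensed illustration of the dispatch above; _fetch_rows is an
    # invented name used only for this sketch.
    SKIP_AWARE = (fetch_daily_basic, fetch_adj_factor, fetch_suspensions, fetch_stk_limit)

    def _fetch_rows(fetch_fn, start, end, force: bool, ts_code: str | None = None):
        kwargs: dict = {}
        if ts_code:
            kwargs["ts_code"] = ts_code
        if fetch_fn in SKIP_AWARE:
            kwargs["skip_existing"] = not force  # force=True re-fetches everything
        return fetch_fn(start, end, **kwargs)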

View File

@@ -114,9 +114,10 @@ def render_settings() -> None:
st.header("数据与设置")
cfg = get_config()
token = st.text_input("TuShare Token", value=cfg.tushare_token or "", type="password")
if st.button("保存 Token"):
if st.button("保存设置"):
cfg.tushare_token = token.strip() or None
st.success("TuShare Token 已更新,仅保存在当前会话")
st.success("设置已保存,仅在当前会话生效")
st.write("新闻源开关与数据库备份将在此配置。")
@@ -155,6 +156,15 @@ def render_tests() -> None:
st.divider()
days = int(st.number_input("检查窗口(天数)", min_value=30, max_value=1095, value=365, step=30))
cfg = get_config()
force_refresh = st.checkbox(
"Force data refresh (disable incremental skip)",
value=cfg.force_refresh,
help="When checked, all data in the selected window is re-fetched",
)
if force_refresh != cfg.force_refresh:
cfg.force_refresh = force_refresh
if st.button("执行开机检查"):
progress_bar = st.progress(0.0)
status_placeholder = st.empty()
@@ -168,7 +178,11 @@ def render_tests() -> None:
with st.spinner("正在执行开机检查..."):
try:
report = run_boot_check(days=days, progress_hook=hook)
report = run_boot_check(
days=days,
progress_hook=hook,
force_refresh=force_refresh,
)
st.success("开机检查完成,以下为数据覆盖摘要。")
st.json(report.to_dict())
if messages:
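
The `hook` passed above must match the `progress_hook: Callable[[str, float], None]` contract from `run_boot_check`. Its definition falls outside this hunk; a minimal sketch of what such a hook could look like (the `progress_bar`/`status_placeholder` objects are the `st.progress`/`st.empty` widgets created earlier, and `messages` is assumed to be the list checked above):

    # Minimal sketch of a progress hook matching Callable[[str, float], None].
    messages: list[str] = []

    def hook(message: str, fraction: float) -> None:
        messages.append(message)
        status_placeholder.write(message)                     # show the latest stage
        progress_bar.progress(min(max(fraction, 0.0), 1.0))   # clamp to [0, 1]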

View File

@@ -54,6 +54,7 @@ class AppConfig:
decision_method: str = "nash"
data_paths: DataPaths = field(default_factory=DataPaths)
agent_weights: AgentWeights = field(default_factory=AgentWeights)
force_refresh: bool = False
CONFIG = AppConfig()
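
A quick sketch of how the new field flows end to end, assuming `get_config()` returns this shared `CONFIG` instance (which is what the Streamlit code above relies on when it mutates `cfg` in place):

    # Assumes get_config() hands back the module-level CONFIG singleton.
    cfg = get_config()
    cfg.force_refresh = True                    # what the UI checkbox does
    assert get_config().force_refresh is True   # later readers observe it

    # run_boot_check(days=365) now defaults to a forced refresh until reset.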