From 6c9c8e3140fae534892644acdc3295d834eea13f Mon Sep 17 00:00:00 2001
From: sam
Date: Sat, 27 Sep 2025 11:59:18 +0800
Subject: [PATCH] ingest: fetch per trade date with incremental skip and
 force-refresh toggle

---
 app/ingest/checker.py   |  17 +++-
 app/ingest/tushare.py   | 212 +++++++++++++++++++++++++++-------
 app/ui/streamlit_app.py |  20 +++-
 app/utils/config.py     |   1 +
 4 files changed, 174 insertions(+), 76 deletions(-)

diff --git a/app/ingest/checker.py b/app/ingest/checker.py
index b160e08..0aa9130 100644
--- a/app/ingest/checker.py
+++ b/app/ingest/checker.py
@@ -1,15 +1,16 @@
 """Boot-time data coverage checker."""
 from __future__ import annotations
 
-import logging
 from dataclasses import dataclass
 from datetime import date, timedelta
 from typing import Callable, Dict
 
 from app.data.schema import initialize_database
 from app.ingest.tushare import collect_data_coverage, ensure_data_coverage
+from app.utils.config import get_config
+from app.utils.logging import get_logger
 
-LOGGER = logging.getLogger(__name__)
+LOGGER = get_logger(__name__)
 
 
 @dataclass
@@ -40,6 +41,7 @@ def run_boot_check(
     days: int = 365,
     auto_fetch: bool = True,
     progress_hook: Callable[[str, float], None] | None = None,
+    force_refresh: bool | None = None,
 ) -> CoverageReport:
     """Run the boot self-check and backfill data when necessary."""
 
@@ -47,8 +49,17 @@ def run_boot_check(
     start, end = _default_window(days)
     LOGGER.info("Boot check coverage window: %s to %s", start, end)
 
+    refresh = force_refresh
+    if refresh is None:
+        refresh = get_config().force_refresh
+
     if auto_fetch:
-        ensure_data_coverage(start, end, progress_hook=progress_hook)
+        ensure_data_coverage(
+            start,
+            end,
+            force=refresh,
+            progress_hook=progress_hook,
+        )
 
     coverage = collect_data_coverage(start, end)
 
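NOTE: in run_boot_check above, an explicit force_refresh argument always wins;
only None falls back to the session-level get_config().force_refresh default.
A minimal usage sketch, not part of the patch, assuming the app package is
importable and a TuShare token is already configured:

    from app.ingest.checker import run_boot_check

    # Explicit argument overrides the config default; force_refresh=None
    # would defer to get_config().force_refresh instead.
    report = run_boot_check(days=90, auto_fetch=True, force_refresh=True)
    print(report.to_dict())
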
diff --git a/app/ingest/tushare.py b/app/ingest/tushare.py
index e6c5981..974ae2b 100644
--- a/app/ingest/tushare.py
+++ b/app/ingest/tushare.py
@@ -1,7 +1,6 @@
 """TuShare data fetching and coverage checking utilities."""
 from __future__ import annotations
 
-import logging
 import os
 from dataclasses import dataclass
 from datetime import date
@@ -16,9 +15,21 @@ except ImportError:  # pragma: no cover - runtime hint
 
 from app.utils.config import get_config
 from app.utils.db import db_session
+from app.data.schema import initialize_database
+from app.utils.logging import get_logger
 
-def _existing_date_range(table: str, date_col: str, ts_code: str | None = None) -> Tuple[str | None, str | None]:
+LOGGER = get_logger(__name__)
+
+API_DEFAULT_LIMIT = 5000
+LOG_EXTRA = {"stage": "data_ingest"}
+
+
+def _existing_date_range(
+    table: str,
+    date_col: str,
+    ts_code: str | None = None,
+) -> Tuple[str | None, str | None]:
     query = f"SELECT MIN({date_col}) AS min_d, MAX({date_col}) AS max_d FROM {table}"
     params: Tuple = ()
     if ts_code:
@@ -31,13 +42,11 @@ def _existing_date_range(table: str, date_col: str, ts_code: str | None = None)
     return row["min_d"], row["max_d"]
 
 
-
-from app.data.schema import initialize_database
-
-LOGGER = logging.getLogger(__name__)
-
-API_DEFAULT_LIMIT = 5000
-LOG_EXTRA = {"stage": "data_ingest"}
+def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]:
+    if df is None or df.empty:
+        return []
+    reindexed = df.reindex(columns=allowed_cols)
+    return reindexed.where(pd.notnull(reindexed), None).to_dict("records")
 
 
 def _fetch_paginated(endpoint: str, params: Dict[str, object], limit: int | None = None) -> pd.DataFrame:
@@ -294,13 +303,6 @@ def _format_date(value: date) -> str:
     return value.strftime("%Y%m%d")
 
 
-def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]:
-    if df is None or df.empty:
-        return []
-    reindexed = df.reindex(columns=allowed_cols)
-    return reindexed.where(pd.notnull(reindexed), None).to_dict("records")
-
-
 def _load_trade_dates(start: date, end: date, exchange: str = "SSE") -> List[str]:
     start_str = _format_date(start)
     end_str = _format_date(end)
@@ -405,34 +407,66 @@ def fetch_stock_basic(exchange: Optional[str] = None, list_status: str = "L") ->
     return _df_to_records(df, _TABLE_COLUMNS["stock_basic"])
 
 
-def fetch_daily_bars(job: FetchJob) -> Iterable[Dict]:
+def fetch_daily_bars(job: FetchJob, skip_existing: bool = True) -> Iterable[Dict]:
     client = _ensure_client()
-    start_date = _format_date(job.start)
-    end_date = _format_date(job.end)
     frames: List[pd.DataFrame] = []
 
     if job.granularity != "daily":
         raise ValueError(f"Unsupported granularity: {job.granularity}")
 
-    params = {
-        "start_date": start_date,
-        "end_date": end_date,
-    }
+    trade_dates = _load_trade_dates(job.start, job.end)
+    if not trade_dates:
+        LOGGER.info("Local trade calendar missing; backfilling it before fetching daily bars", extra=LOG_EXTRA)
+        ensure_trade_calendar(job.start, job.end)
+        trade_dates = _load_trade_dates(job.start, job.end)
 
     if job.ts_codes:
         for code in job.ts_codes:
-            LOGGER.info("Fetching daily bars for %s (%s-%s)", code, start_date, end_date)
-            df = _fetch_paginated("daily", {**params, "ts_code": code})
-            if not df.empty:
-                frames.append(df)
+            for trade_date in trade_dates:
+                if skip_existing and _record_exists("daily", "trade_date", trade_date, code):
+                    LOGGER.debug(
+                        "Daily bars already present, skipping %s %s",
+                        code,
+                        trade_date,
+                        extra=LOG_EXTRA,
+                    )
+                    continue
+                LOGGER.debug(
+                    "Fetching daily bars by trade date: code=%s trade_date=%s",
+                    code,
+                    trade_date,
+                    extra=LOG_EXTRA,
+                )
+                LOGGER.info(
+                    "Trade-date fetch request: endpoint=daily code=%s trade_date=%s",
+                    code,
+                    trade_date,
+                    extra=LOG_EXTRA,
+                )
+                df = _fetch_paginated(
+                    "daily",
+                    {
+                        "trade_date": trade_date,
+                        "ts_code": code,
+                    },
+                )
+                if not df.empty:
+                    frames.append(df)
     else:
-        trade_dates = _load_trade_dates(job.start, job.end)
-        if not trade_dates:
-            LOGGER.info("Local trade calendar missing; backfilling it before fetching daily bars")
-            ensure_trade_calendar(job.start, job.end)
-            trade_dates = _load_trade_dates(job.start, job.end)
         for trade_date in trade_dates:
-            LOGGER.debug("Fetching daily bars by trade date: %s", trade_date)
+            if skip_existing and _record_exists("daily", "trade_date", trade_date):
+                LOGGER.debug(
+                    "Daily bars already present, skipping trade date %s",
+                    trade_date,
+                    extra=LOG_EXTRA,
+                )
+                continue
+            LOGGER.debug("Fetching daily bars by trade date: %s", trade_date, extra=LOG_EXTRA)
+            LOGGER.info(
+                "Trade-date fetch request: endpoint=daily trade_date=%s",
+                trade_date,
+                extra=LOG_EXTRA,
+            )
             df = _fetch_paginated("daily", {"trade_date": trade_date})
             if not df.empty:
                 frames.append(df)
@@ -460,33 +494,25 @@ def fetch_daily_basic(
         extra=LOG_EXTRA,
     )
 
-    if ts_code:
-        df = _fetch_paginated(
-            "daily_basic",
-            {
-                "ts_code": ts_code,
-                "start_date": start_date,
-                "end_date": end_date,
-            },
-        )
-        return _df_to_records(df, _TABLE_COLUMNS["daily_basic"])
-
     trade_dates = _load_trade_dates(start, end)
     frames: List[pd.DataFrame] = []
     for trade_date in trade_dates:
-        if skip_existing and _record_exists("daily_basic", "trade_date", trade_date):
+        if skip_existing and _record_exists("daily_basic", "trade_date", trade_date, ts_code):
             LOGGER.info(
                 "Daily basic indicators already present, skipping trade date %s",
                 trade_date,
                 extra=LOG_EXTRA,
             )
             continue
-        LOGGER.debug(
-            "Fetching daily basic indicators by trade date: %s",
-            trade_date,
+        params = {"trade_date": trade_date}
+        if ts_code:
+            params["ts_code"] = ts_code
+        LOGGER.info(
+            "Trade-date fetch request: endpoint=daily_basic params=%s",
+            params,
             extra=LOG_EXTRA,
         )
-        df = _fetch_paginated("daily_basic", {"trade_date": trade_date})
+        df = _fetch_paginated("daily_basic", params)
         if not df.empty:
             frames.append(df)
 
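NOTE: every skip_existing branch in this file leans on a _record_exists helper
that the patch itself never shows. A plausible sketch, inferred only from the
call sites (_record_exists(table, date_col, value[, ts_code])) and from the
db_session import above; the body, and the assumption that db_session accepts
read_only and yields a DB-API style connection, are guesses rather than the
repository's actual implementation:

    from app.utils.db import db_session

    def _record_exists(table: str, date_col: str, value: str, ts_code: str | None = None) -> bool:
        # Hypothetical helper: True when at least one matching row is stored.
        query = f"SELECT 1 FROM {table} WHERE {date_col} = ? LIMIT 1"
        params: tuple = (value,)
        if ts_code:
            query = f"SELECT 1 FROM {table} WHERE {date_col} = ? AND ts_code = ? LIMIT 1"
            params = (value, ts_code)
        with db_session(read_only=True) as conn:
            return conn.execute(query, params).fetchone() is not None
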
@@ -528,7 +554,7 @@ def fetch_adj_factor(
         params = {"trade_date": trade_date}
         if ts_code:
             params["ts_code"] = ts_code
-        LOGGER.debug("Fetching adjustment factors by trade date: %s", params, extra=LOG_EXTRA)
+        LOGGER.info("Trade-date fetch request: endpoint=adj_factor params=%s", params, extra=LOG_EXTRA)
         df = _fetch_paginated("adj_factor", params)
         if not df.empty:
             frames.append(df)
@@ -540,17 +566,40 @@ def fetch_adj_factor(
     return _df_to_records(merged, _TABLE_COLUMNS["adj_factor"])
 
 
-def fetch_suspensions(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
+def fetch_suspensions(
+    start: date,
+    end: date,
+    ts_code: Optional[str] = None,
+    skip_existing: bool = True,
+) -> Iterable[Dict]:
     client = _ensure_client()
     start_date = _format_date(start)
     end_date = _format_date(end)
-    LOGGER.info("Fetching suspension/resumption data (%s-%s)", start_date, end_date)
-    df = _fetch_paginated("suspend_d", {
-        "ts_code": ts_code,
-        "start_date": start_date,
-        "end_date": end_date,
-    }, limit=2000)
-    return _df_to_records(df, _TABLE_COLUMNS["suspend"])
+    LOGGER.info("Fetching suspension/resumption data (%s-%s)", start_date, end_date, extra=LOG_EXTRA)
+    trade_dates = _load_trade_dates(start, end)
+    frames: List[pd.DataFrame] = []
+    for trade_date in trade_dates:
+        if skip_existing and _record_exists("suspend", "suspend_date", trade_date, ts_code):
+            LOGGER.debug(
+                "Suspension data already present, skipping %s %s",
+                ts_code or "ALL",
+                trade_date,
+                extra=LOG_EXTRA,
+            )
+            continue
+        params = {"trade_date": trade_date}
+        if ts_code:
+            params["ts_code"] = ts_code
+        LOGGER.info("Trade-date fetch request: endpoint=suspend_d params=%s", params, extra=LOG_EXTRA)
+        df = _fetch_paginated("suspend_d", params, limit=2000)
+        if not df.empty:
+            frames.append(df)
+
+    if not frames:
+        return []
+
+    merged = pd.concat(frames, ignore_index=True)
+    return _df_to_records(merged, _TABLE_COLUMNS["suspend"])
 
 
 def fetch_trade_calendar(start: date, end: date, exchange: str = "SSE") -> Iterable[Dict]:
@@ -562,17 +611,40 @@ def fetch_trade_calendar(start: date, end: date, exchange: str = "SSE") -> Itera
     return _df_to_records(df, _TABLE_COLUMNS["trade_calendar"])
 
 
-def fetch_stk_limit(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
+def fetch_stk_limit(
+    start: date,
+    end: date,
+    ts_code: Optional[str] = None,
+    skip_existing: bool = True,
+) -> Iterable[Dict]:
     client = _ensure_client()
     start_date = _format_date(start)
    end_date = _format_date(end)
-    LOGGER.info("Fetching limit-up/limit-down prices (%s-%s)", start_date, end_date)
-    df = _fetch_paginated("stk_limit", {
-        "ts_code": ts_code,
-        "start_date": start_date,
-        "end_date": end_date,
-    })
-    return _df_to_records(df, _TABLE_COLUMNS["stk_limit"])
+    LOGGER.info("Fetching limit-up/limit-down prices (%s-%s)", start_date, end_date, extra=LOG_EXTRA)
+    trade_dates = _load_trade_dates(start, end)
+    frames: List[pd.DataFrame] = []
+    for trade_date in trade_dates:
+        if skip_existing and _record_exists("stk_limit", "trade_date", trade_date, ts_code):
+            LOGGER.debug(
+                "Limit price data already present, skipping %s %s",
+                ts_code or "ALL",
+                trade_date,
+                extra=LOG_EXTRA,
+            )
+            continue
+        params = {"trade_date": trade_date}
+        if ts_code:
+            params["ts_code"] = ts_code
+        LOGGER.info("Trade-date fetch request: endpoint=stk_limit params=%s", params, extra=LOG_EXTRA)
+        df = _fetch_paginated("stk_limit", params)
+        if not df.empty:
+            frames.append(df)
+
+    if not frames:
+        return []
+
+    merged = pd.concat(frames, ignore_index=True)
+    return _df_to_records(merged, _TABLE_COLUMNS["stk_limit"])
 
 
 def save_records(table: str, rows: Iterable[Dict]) -> None:
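NOTE: fetch_daily_bars, fetch_daily_basic, fetch_adj_factor, fetch_suspensions,
and fetch_stk_limit now share one shape: walk the local trade calendar, skip
dates already stored unless skip_existing is False, issue one request per
trade_date, and concat the collected frames. A usage sketch for a targeted
backfill; the ts_code and date values are hypothetical, while the function and
table names come from this patch:

    from datetime import date

    from app.ingest.tushare import (
        FetchJob,
        fetch_daily_bars,
        fetch_suspensions,
        save_records,
    )

    start, end = date(2025, 1, 1), date(2025, 3, 31)

    # Incremental daily bars for one stock: stored trade dates are skipped.
    job = FetchJob("daily_autofill", start=start, end=end, ts_codes=("600000.SH",))
    save_records("daily", fetch_daily_bars(job, skip_existing=True))

    # Clean refetch of suspensions for the same window, ignoring stored rows.
    save_records("suspend", fetch_suspensions(start, end, ts_code="600000.SH", skip_existing=False))
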
@@ -662,7 +734,7 @@ def ensure_data_coverage(
     if pending_codes:
         job = FetchJob("daily_autofill", start=start, end=end, ts_codes=tuple(pending_codes))
         LOGGER.info("Starting daily bar fetch: %s-%s (%d stocks pending backfill)", start_str, end_str, len(pending_codes))
-        save_records("daily", fetch_daily_bars(job))
+        save_records("daily", fetch_daily_bars(job, skip_existing=not force))
     else:
         needs_daily = force or _range_needs_refresh("daily", "trade_date", start_str, end_str, expected_days)
         if not needs_daily:
@@ -670,7 +742,7 @@ def ensure_data_coverage(
         else:
             job = FetchJob("daily_autofill", start=start, end=end)
             LOGGER.info("Starting daily bar fetch: %s-%s", start_str, end_str)
-            save_records("daily", fetch_daily_bars(job))
+            save_records("daily", fetch_daily_bars(job, skip_existing=not force))
 
     date_cols = {
         "daily_basic": "trade_date",
@@ -689,7 +761,7 @@ def ensure_data_coverage(
             LOGGER.info("Fetching %s table data (stock: %s) %s-%s", table, code, start_str, end_str)
             try:
                 kwargs = {"ts_code": code}
-                if fetch_fn in (fetch_daily_basic, fetch_adj_factor):
+                if fetch_fn in (fetch_daily_basic, fetch_adj_factor, fetch_suspensions, fetch_stk_limit):
                     kwargs["skip_existing"] = not force
                 rows = fetch_fn(start, end, **kwargs)
             except Exception:
@@ -707,7 +779,7 @@ def ensure_data_coverage(
             LOGGER.info("Fetching %s table data (full market) %s-%s", table, start_str, end_str)
             try:
                 kwargs = {}
-                if fetch_fn in (fetch_daily_basic, fetch_adj_factor):
+                if fetch_fn in (fetch_daily_basic, fetch_adj_factor, fetch_suspensions, fetch_stk_limit):
                     kwargs["skip_existing"] = not force
                 rows = fetch_fn(start, end, **kwargs)
             except Exception:
diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py
index 79f5140..99b127a 100644
--- a/app/ui/streamlit_app.py
+++ b/app/ui/streamlit_app.py
@@ -114,9 +114,10 @@ def render_settings() -> None:
     st.header("Data & Settings")
     cfg = get_config()
     token = st.text_input("TuShare Token", value=cfg.tushare_token or "", type="password")
-    if st.button("Save Token"):
+
+    if st.button("Save Settings"):
         cfg.tushare_token = token.strip() or None
-        st.success("TuShare Token updated; stored for the current session only.")
+        st.success("Settings saved; they take effect for the current session only.")
 
     st.write("News source toggles and database backups will be configured here.")
 
@@ -155,6 +156,15 @@ def render_tests() -> None:
     st.divider()
     days = int(st.number_input("Check window (days)", min_value=30, max_value=1095, value=365, step=30))
 
+    cfg = get_config()
+    force_refresh = st.checkbox(
+        "Force data refresh (disable incremental skip)",
+        value=cfg.force_refresh,
+        help="When checked, all data in the selected range is re-fetched",
+    )
+    if force_refresh != cfg.force_refresh:
+        cfg.force_refresh = force_refresh
+
     if st.button("Run boot check"):
         progress_bar = st.progress(0.0)
         status_placeholder = st.empty()
@@ -168,7 +178,11 @@ def render_tests() -> None:
 
         with st.spinner("Running boot check..."):
             try:
-                report = run_boot_check(days=days, progress_hook=hook)
+                report = run_boot_check(
+                    days=days,
+                    progress_hook=hook,
+                    force_refresh=force_refresh,
+                )
                 st.success("Boot check complete; data coverage summary below.")
                 st.json(report.to_dict())
                 if messages:
diff --git a/app/utils/config.py b/app/utils/config.py
index fb0a128..007517a 100644
--- a/app/utils/config.py
+++ b/app/utils/config.py
@@ -54,6 +54,7 @@ class AppConfig:
     decision_method: str = "nash"
     data_paths: DataPaths = field(default_factory=DataPaths)
     agent_weights: AgentWeights = field(default_factory=AgentWeights)
+    force_refresh: bool = False
 
 
 CONFIG = AppConfig()
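NOTE: end to end, the new flag travels config -> UI -> checker -> ingest:
AppConfig.force_refresh defaults to False, the render_tests checkbox writes the
toggle back into the session config, run_boot_check forwards it as force to
ensure_data_coverage, and every per-date fetcher then receives
skip_existing=not force. A condensed sketch of that chain; the names come from
this patch, the toggle value is hypothetical:

    from app.ingest.checker import run_boot_check
    from app.utils.config import get_config

    cfg = get_config()
    cfg.force_refresh = True   # what the UI checkbox does when toggled on

    # force_refresh=None defers to cfg.force_refresh, so this run calls
    # ensure_data_coverage(force=True) and the fetchers get skip_existing=False.
    report = run_boot_check(days=365, force_refresh=None)
    print(report.to_dict())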