From 774b68de9937cbf0e5c83b8ce5232b42b811008a Mon Sep 17 00:00:00 2001 From: sam Date: Sat, 27 Sep 2025 08:48:00 +0800 Subject: [PATCH] update --- app/ingest/tushare.py | 61 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 48 insertions(+), 13 deletions(-) diff --git a/app/ingest/tushare.py b/app/ingest/tushare.py index 2c897d5..e6c5981 100644 --- a/app/ingest/tushare.py +++ b/app/ingest/tushare.py @@ -313,8 +313,13 @@ def _load_trade_dates(start: date, end: date, exchange: str = "SSE") -> List[str return [row["cal_date"] for row in rows] -def _daily_basic_exists(trade_date: str, ts_code: Optional[str] = None) -> bool: - query = "SELECT 1 FROM daily_basic WHERE trade_date = ?" +def _record_exists( + table: str, + date_col: str, + trade_date: str, + ts_code: Optional[str] = None, +) -> bool: + query = f"SELECT 1 FROM {table} WHERE {date_col} = ?" params: Tuple = (trade_date,) if ts_code: query += " AND ts_code = ?" @@ -469,7 +474,7 @@ def fetch_daily_basic( trade_dates = _load_trade_dates(start, end) frames: List[pd.DataFrame] = [] for trade_date in trade_dates: - if skip_existing and _daily_basic_exists(trade_date): + if skip_existing and _record_exists("daily_basic", "trade_date", trade_date): LOGGER.info( "日线基础指标已存在,跳过交易日 %s", trade_date, @@ -492,17 +497,47 @@ def fetch_daily_basic( return _df_to_records(merged, _TABLE_COLUMNS["daily_basic"]) -def fetch_adj_factor(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]: +def fetch_adj_factor( + start: date, + end: date, + ts_code: Optional[str] = None, + skip_existing: bool = True, +) -> Iterable[Dict]: client = _ensure_client() start_date = _format_date(start) end_date = _format_date(end) - LOGGER.info("拉取复权因子(%s-%s,股票:%s)", start_date, end_date, ts_code or "全部") - df = _fetch_paginated("adj_factor", { - "ts_code": ts_code, - "start_date": start_date, - "end_date": end_date, - }) - return _df_to_records(df, _TABLE_COLUMNS["adj_factor"]) + LOGGER.info( + "拉取复权因子(%s-%s,股票:%s)", + start_date, + end_date, + ts_code or "全部", + extra=LOG_EXTRA, + ) + + trade_dates = _load_trade_dates(start, end) + frames: List[pd.DataFrame] = [] + for trade_date in trade_dates: + if skip_existing and _record_exists("adj_factor", "trade_date", trade_date, ts_code): + LOGGER.debug( + "复权因子已存在,跳过 %s %s", + ts_code or "ALL", + trade_date, + extra=LOG_EXTRA, + ) + continue + params = {"trade_date": trade_date} + if ts_code: + params["ts_code"] = ts_code + LOGGER.debug("按交易日拉取复权因子:%s", params, extra=LOG_EXTRA) + df = _fetch_paginated("adj_factor", params) + if not df.empty: + frames.append(df) + + if not frames: + return [] + + merged = pd.concat(frames, ignore_index=True) + return _df_to_records(merged, _TABLE_COLUMNS["adj_factor"]) def fetch_suspensions(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]: @@ -654,7 +689,7 @@ def ensure_data_coverage( LOGGER.info("拉取 %s 表数据(股票:%s)%s-%s", table, code, start_str, end_str) try: kwargs = {"ts_code": code} - if fetch_fn is fetch_daily_basic: + if fetch_fn in (fetch_daily_basic, fetch_adj_factor): kwargs["skip_existing"] = not force rows = fetch_fn(start, end, **kwargs) except Exception: @@ -672,7 +707,7 @@ def ensure_data_coverage( LOGGER.info("拉取 %s 表数据(全市场)%s-%s", table, start_str, end_str) try: kwargs = {} - if fetch_fn is fetch_daily_basic: + if fetch_fn in (fetch_daily_basic, fetch_adj_factor): kwargs["skip_existing"] = not force rows = fetch_fn(start, end, **kwargs) except Exception: