From 15a50cad93fb857815356918ef8c727a0eb027a6 Mon Sep 17 00:00:00 2001 From: sam Date: Sat, 27 Sep 2025 18:03:29 +0800 Subject: [PATCH] update --- app/ingest/tushare.py | 54 ++++++++++++++++++++++++++++++++++++------- 1 file changed, 46 insertions(+), 8 deletions(-) diff --git a/app/ingest/tushare.py b/app/ingest/tushare.py index 55a0274..814b729 100644 --- a/app/ingest/tushare.py +++ b/app/ingest/tushare.py @@ -6,7 +6,7 @@ import time from collections import deque from dataclasses import dataclass from datetime import date -from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple +from typing import Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple import pandas as pd @@ -420,6 +420,17 @@ def _range_needs_refresh( return False +def _existing_suspend_dates(start_str: str, end_str: str, ts_code: str | None = None) -> Set[str]: + sql = "SELECT DISTINCT suspend_date FROM suspend WHERE suspend_date BETWEEN ? AND ?" + params: List[object] = [start_str, end_str] + if ts_code: + sql += " AND ts_code = ?" + params.append(ts_code) + with db_session(read_only=True) as conn: + rows = conn.execute(sql, tuple(params)).fetchall() + return {row["suspend_date"] for row in rows if row["suspend_date"]} + + def _listing_window(ts_code: str) -> Tuple[Optional[str], Optional[str]]: with db_session(read_only=True) as conn: row = conn.execute( @@ -639,11 +650,27 @@ def fetch_suspensions( client = _ensure_client() start_date = _format_date(start) end_date = _format_date(end) - LOGGER.info("拉取停复牌信息(%s-%s)", start_date, end_date, extra=LOG_EXTRA) + LOGGER.info( + "拉取停复牌信息(逐日循环)%s-%s 股票=%s", + start_date, + end_date, + ts_code or "全部", + extra=LOG_EXTRA, + ) trade_dates = _load_trade_dates(start, end) + existing_dates: Set[str] = set() + if skip_existing: + existing_dates = _existing_suspend_dates(start_date, end_date, ts_code) + if existing_dates: + LOGGER.debug( + "停复牌已有覆盖日期数量=%s 示例=%s", + len(existing_dates), + sorted(existing_dates)[:5], + extra=LOG_EXTRA, + ) frames: List[pd.DataFrame] = [] for trade_date in trade_dates: - if skip_existing and _record_exists("suspend", "suspend_date", trade_date, ts_code): + if skip_existing and trade_date in existing_dates: LOGGER.debug( "停复牌信息已存在,跳过 %s %s", ts_code or "ALL", @@ -651,19 +678,30 @@ def fetch_suspensions( extra=LOG_EXTRA, ) continue - params = {"trade_date": trade_date} + params: Dict[str, object] = {"trade_date": trade_date} if ts_code: params["ts_code"] = ts_code - LOGGER.info("交易日拉取请求:endpoint=suspend_d params=%s", params, extra=LOG_EXTRA) + LOGGER.info( + "交易日拉取请求:endpoint=suspend_d params=%s", + params, + extra=LOG_EXTRA, + ) df = _fetch_paginated("suspend_d", params, limit=2000) if not df.empty: + if "suspend_date" not in df.columns and "trade_date" in df.columns: + df = df.rename(columns={"trade_date": "suspend_date"}) frames.append(df) if not frames: + LOGGER.info("停复牌接口未返回数据", extra=LOG_EXTRA) return [] merged = pd.concat(frames, ignore_index=True) - return _df_to_records(merged, _TABLE_COLUMNS["suspend"]) + missing_cols = [col for col in _TABLE_COLUMNS["suspend"] if col not in merged.columns] + for col in missing_cols: + merged[col] = None + ordered = merged[_TABLE_COLUMNS["suspend"]] + return _df_to_records(ordered, _TABLE_COLUMNS["suspend"]) def fetch_trade_calendar(start: date, end: date, exchange: str = "SSE") -> Iterable[Dict]: @@ -845,8 +883,8 @@ def ensure_data_coverage( raise save_records(table, rows) else: - needs_refresh = force - if not force: + needs_refresh = force or table == "suspend" + if not force and table != "suspend": expected = expected_days if table in {"daily_basic", "adj_factor", "stk_limit"} else 0 needs_refresh = _range_needs_refresh(table, date_col, start_str, end_str, expected) if not needs_refresh: