This commit is contained in:
sam 2025-09-27 18:03:29 +08:00
parent 9c7a68d313
commit 15a50cad93

View File

@ -6,7 +6,7 @@ import time
from collections import deque
from dataclasses import dataclass
from datetime import date
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple
import pandas as pd
@ -420,6 +420,17 @@ def _range_needs_refresh(
return False
def _existing_suspend_dates(start_str: str, end_str: str, ts_code: str | None = None) -> Set[str]:
sql = "SELECT DISTINCT suspend_date FROM suspend WHERE suspend_date BETWEEN ? AND ?"
params: List[object] = [start_str, end_str]
if ts_code:
sql += " AND ts_code = ?"
params.append(ts_code)
with db_session(read_only=True) as conn:
rows = conn.execute(sql, tuple(params)).fetchall()
return {row["suspend_date"] for row in rows if row["suspend_date"]}
def _listing_window(ts_code: str) -> Tuple[Optional[str], Optional[str]]:
with db_session(read_only=True) as conn:
row = conn.execute(
@ -639,11 +650,27 @@ def fetch_suspensions(
client = _ensure_client()
start_date = _format_date(start)
end_date = _format_date(end)
LOGGER.info("拉取停复牌信息(%s-%s", start_date, end_date, extra=LOG_EXTRA)
LOGGER.info(
"拉取停复牌信息(逐日循环)%s-%s 股票=%s",
start_date,
end_date,
ts_code or "全部",
extra=LOG_EXTRA,
)
trade_dates = _load_trade_dates(start, end)
existing_dates: Set[str] = set()
if skip_existing:
existing_dates = _existing_suspend_dates(start_date, end_date, ts_code)
if existing_dates:
LOGGER.debug(
"停复牌已有覆盖日期数量=%s 示例=%s",
len(existing_dates),
sorted(existing_dates)[:5],
extra=LOG_EXTRA,
)
frames: List[pd.DataFrame] = []
for trade_date in trade_dates:
if skip_existing and _record_exists("suspend", "suspend_date", trade_date, ts_code):
if skip_existing and trade_date in existing_dates:
LOGGER.debug(
"停复牌信息已存在,跳过 %s %s",
ts_code or "ALL",
@ -651,19 +678,30 @@ def fetch_suspensions(
extra=LOG_EXTRA,
)
continue
params = {"trade_date": trade_date}
params: Dict[str, object] = {"trade_date": trade_date}
if ts_code:
params["ts_code"] = ts_code
LOGGER.info("交易日拉取请求endpoint=suspend_d params=%s", params, extra=LOG_EXTRA)
LOGGER.info(
"交易日拉取请求endpoint=suspend_d params=%s",
params,
extra=LOG_EXTRA,
)
df = _fetch_paginated("suspend_d", params, limit=2000)
if not df.empty:
if "suspend_date" not in df.columns and "trade_date" in df.columns:
df = df.rename(columns={"trade_date": "suspend_date"})
frames.append(df)
if not frames:
LOGGER.info("停复牌接口未返回数据", extra=LOG_EXTRA)
return []
merged = pd.concat(frames, ignore_index=True)
return _df_to_records(merged, _TABLE_COLUMNS["suspend"])
missing_cols = [col for col in _TABLE_COLUMNS["suspend"] if col not in merged.columns]
for col in missing_cols:
merged[col] = None
ordered = merged[_TABLE_COLUMNS["suspend"]]
return _df_to_records(ordered, _TABLE_COLUMNS["suspend"])
def fetch_trade_calendar(start: date, end: date, exchange: str = "SSE") -> Iterable[Dict]:
@ -845,8 +883,8 @@ def ensure_data_coverage(
raise
save_records(table, rows)
else:
needs_refresh = force
if not force:
needs_refresh = force or table == "suspend"
if not force and table != "suspend":
expected = expected_days if table in {"daily_basic", "adj_factor", "stk_limit"} else 0
needs_refresh = _range_needs_refresh(table, date_col, start_str, end_str, expected)
if not needs_refresh: