This commit is contained in:
sam 2025-09-27 19:29:00 +08:00
parent b9a359d501
commit 63dd380a70

View File

@ -27,9 +27,14 @@ LOGGER = get_logger(__name__)
API_DEFAULT_LIMIT = 5000
LOG_EXTRA = {"stage": "data_ingest"}
_CALL_QUEUE = deque()
_CALL_BUCKETS: Dict[str, deque] = defaultdict(deque)
RATE_LIMIT_ERROR_PATTERNS: Tuple[str, ...] = (
"最多访问该接口",
"超过接口限制",
"Frequency limit",
)
API_RATE_LIMITS: Dict[str, int] = {
"stock_basic": 180,
"daily": 480,
@ -115,7 +120,7 @@ def _normalize_date_str(value: Optional[str]) -> Optional[str]:
return text or None
def _respect_rate_limit(endpoint: str | None, cfg) -> None:
def _respect_rate_limit(endpoint: str | None) -> None:
def _throttle(queue: deque, limit: int) -> None:
if limit <= 0:
return
@ -135,12 +140,8 @@ def _respect_rate_limit(endpoint: str | None, cfg) -> None:
time.sleep(max(0.1, sleep_time))
queue.append(time.time())
max_calls = cfg.max_calls_per_minute
if max_calls > 0:
_throttle(_CALL_QUEUE, max_calls)
bucket_key = endpoint or "_default"
endpoint_limit = API_RATE_LIMITS.get(bucket_key, max_calls)
endpoint_limit = API_RATE_LIMITS.get(bucket_key, 60)
_throttle(_CALL_BUCKETS[bucket_key], endpoint_limit or 0)
@ -165,11 +166,26 @@ def _fetch_paginated(endpoint: str, params: Dict[str, object], limit: int | None
extra=LOG_EXTRA,
)
while True:
_respect_rate_limit(endpoint, get_config())
_respect_rate_limit(endpoint)
call = getattr(client, endpoint)
try:
df = call(limit=limit, offset=offset, **clean_params)
except Exception: # noqa: BLE001
except Exception as exc: # noqa: BLE001
message = str(exc)
if any(pattern in message for pattern in RATE_LIMIT_ERROR_PATTERNS):
per_minute = API_RATE_LIMITS.get(endpoint or "", 0)
wait_time = 60.0 / per_minute + 1 if per_minute else 30.0
wait_time = max(wait_time, 30.0)
LOGGER.warning(
"接口限频触发:%s,原因=%s,等待 %.1f 秒后重试",
endpoint,
message,
wait_time,
extra=LOG_EXTRA,
)
time.sleep(wait_time)
continue
LOGGER.exception(
"TuShare 接口调用异常endpoint=%s offset=%s params=%s",
endpoint,
@ -874,7 +890,7 @@ def fetch_stock_basic(exchange: Optional[str] = None, list_status: str = "L") ->
list_status,
extra=LOG_EXTRA,
)
_respect_rate_limit("stock_basic", get_config())
_respect_rate_limit("stock_basic")
fields = "ts_code,symbol,name,area,industry,market,exchange,list_status,list_date,delist_date"
df = client.stock_basic(exchange=exchange, list_status=list_status, fields=fields)
return _df_to_records(df, _TABLE_COLUMNS["stock_basic"])
@ -1113,7 +1129,7 @@ def fetch_trade_calendar(start: date, end: date, exchange: str = "SSE") -> Itera
end_date,
extra=LOG_EXTRA,
)
_respect_rate_limit("trade_cal", get_config())
_respect_rate_limit("trade_cal")
df = client.trade_cal(exchange=exchange, start_date=start_date, end_date=end_date)
if df is not None and not df.empty and "is_open" in df.columns:
df["is_open"] = pd.to_numeric(df["is_open"], errors="coerce").fillna(0).astype(int)
@ -1159,7 +1175,7 @@ def fetch_stk_limit(
def fetch_index_basic(market: Optional[str] = None) -> Iterable[Dict]:
client = _ensure_client()
LOGGER.info("拉取指数基础信息market=%s", market or "all", extra=LOG_EXTRA)
_respect_rate_limit("index_basic", get_config())
_respect_rate_limit("index_basic")
df = client.index_basic(market=market)
return _df_to_records(df, _TABLE_COLUMNS["index_basic"])
@ -1186,7 +1202,7 @@ def fetch_index_daily(start: date, end: date, ts_code: str) -> Iterable[Dict]:
def fetch_fund_basic(asset_class: str = "E", status: str = "L") -> Iterable[Dict]:
client = _ensure_client()
LOGGER.info("拉取基金基础信息asset_class=%s status=%s", asset_class, status, extra=LOG_EXTRA)
_respect_rate_limit("fund_basic", get_config())
_respect_rate_limit("fund_basic")
df = client.fund_basic(market=asset_class, status=status)
return _df_to_records(df, _TABLE_COLUMNS["fund_basic"])
@ -1213,7 +1229,7 @@ def fetch_fund_nav(start: date, end: date, ts_code: str) -> Iterable[Dict]:
def fetch_fut_basic(exchange: Optional[str] = None) -> Iterable[Dict]:
client = _ensure_client()
LOGGER.info("拉取期货基础信息exchange=%s", exchange or "all", extra=LOG_EXTRA)
_respect_rate_limit("fut_basic", get_config())
_respect_rate_limit("fut_basic")
df = client.fut_basic(exchange=exchange)
return _df_to_records(df, _TABLE_COLUMNS["fut_basic"])