diff --git a/app/ingest/tushare.py b/app/ingest/tushare.py index 20612ad..ee06f85 100644 --- a/app/ingest/tushare.py +++ b/app/ingest/tushare.py @@ -27,9 +27,14 @@ LOGGER = get_logger(__name__) API_DEFAULT_LIMIT = 5000 LOG_EXTRA = {"stage": "data_ingest"} -_CALL_QUEUE = deque() _CALL_BUCKETS: Dict[str, deque] = defaultdict(deque) +RATE_LIMIT_ERROR_PATTERNS: Tuple[str, ...] = ( + "最多访问该接口", + "超过接口限制", + "Frequency limit", +) + API_RATE_LIMITS: Dict[str, int] = { "stock_basic": 180, "daily": 480, @@ -115,7 +120,7 @@ def _normalize_date_str(value: Optional[str]) -> Optional[str]: return text or None -def _respect_rate_limit(endpoint: str | None, cfg) -> None: +def _respect_rate_limit(endpoint: str | None) -> None: def _throttle(queue: deque, limit: int) -> None: if limit <= 0: return @@ -135,12 +140,8 @@ def _respect_rate_limit(endpoint: str | None, cfg) -> None: time.sleep(max(0.1, sleep_time)) queue.append(time.time()) - max_calls = cfg.max_calls_per_minute - if max_calls > 0: - _throttle(_CALL_QUEUE, max_calls) - bucket_key = endpoint or "_default" - endpoint_limit = API_RATE_LIMITS.get(bucket_key, max_calls) + endpoint_limit = API_RATE_LIMITS.get(bucket_key, 60) _throttle(_CALL_BUCKETS[bucket_key], endpoint_limit or 0) @@ -165,11 +166,26 @@ def _fetch_paginated(endpoint: str, params: Dict[str, object], limit: int | None extra=LOG_EXTRA, ) while True: - _respect_rate_limit(endpoint, get_config()) + _respect_rate_limit(endpoint) call = getattr(client, endpoint) try: df = call(limit=limit, offset=offset, **clean_params) - except Exception: # noqa: BLE001 + except Exception as exc: # noqa: BLE001 + message = str(exc) + if any(pattern in message for pattern in RATE_LIMIT_ERROR_PATTERNS): + per_minute = API_RATE_LIMITS.get(endpoint or "", 0) + wait_time = 60.0 / per_minute + 1 if per_minute else 30.0 + wait_time = max(wait_time, 30.0) + LOGGER.warning( + "接口限频触发:%s,原因=%s,等待 %.1f 秒后重试", + endpoint, + message, + wait_time, + extra=LOG_EXTRA, + ) + time.sleep(wait_time) + continue + LOGGER.exception( "TuShare 接口调用异常:endpoint=%s offset=%s params=%s", endpoint, @@ -874,7 +890,7 @@ def fetch_stock_basic(exchange: Optional[str] = None, list_status: str = "L") -> list_status, extra=LOG_EXTRA, ) - _respect_rate_limit("stock_basic", get_config()) + _respect_rate_limit("stock_basic") fields = "ts_code,symbol,name,area,industry,market,exchange,list_status,list_date,delist_date" df = client.stock_basic(exchange=exchange, list_status=list_status, fields=fields) return _df_to_records(df, _TABLE_COLUMNS["stock_basic"]) @@ -1113,7 +1129,7 @@ def fetch_trade_calendar(start: date, end: date, exchange: str = "SSE") -> Itera end_date, extra=LOG_EXTRA, ) - _respect_rate_limit("trade_cal", get_config()) + _respect_rate_limit("trade_cal") df = client.trade_cal(exchange=exchange, start_date=start_date, end_date=end_date) if df is not None and not df.empty and "is_open" in df.columns: df["is_open"] = pd.to_numeric(df["is_open"], errors="coerce").fillna(0).astype(int) @@ -1159,7 +1175,7 @@ def fetch_stk_limit( def fetch_index_basic(market: Optional[str] = None) -> Iterable[Dict]: client = _ensure_client() LOGGER.info("拉取指数基础信息(market=%s)", market or "all", extra=LOG_EXTRA) - _respect_rate_limit("index_basic", get_config()) + _respect_rate_limit("index_basic") df = client.index_basic(market=market) return _df_to_records(df, _TABLE_COLUMNS["index_basic"]) @@ -1186,7 +1202,7 @@ def fetch_index_daily(start: date, end: date, ts_code: str) -> Iterable[Dict]: def fetch_fund_basic(asset_class: str = "E", status: str = "L") -> Iterable[Dict]: client = _ensure_client() LOGGER.info("拉取基金基础信息:asset_class=%s status=%s", asset_class, status, extra=LOG_EXTRA) - _respect_rate_limit("fund_basic", get_config()) + _respect_rate_limit("fund_basic") df = client.fund_basic(market=asset_class, status=status) return _df_to_records(df, _TABLE_COLUMNS["fund_basic"]) @@ -1213,7 +1229,7 @@ def fetch_fund_nav(start: date, end: date, ts_code: str) -> Iterable[Dict]: def fetch_fut_basic(exchange: Optional[str] = None) -> Iterable[Dict]: client = _ensure_client() LOGGER.info("拉取期货基础信息(exchange=%s)", exchange or "all", extra=LOG_EXTRA) - _respect_rate_limit("fut_basic", get_config()) + _respect_rate_limit("fut_basic") df = client.fut_basic(exchange=exchange) return _df_to_records(df, _TABLE_COLUMNS["fut_basic"])