This commit is contained in:
sam 2025-09-27 08:48:00 +08:00
parent 5ef90b8de0
commit 774b68de99

View File

@ -313,8 +313,13 @@ def _load_trade_dates(start: date, end: date, exchange: str = "SSE") -> List[str
return [row["cal_date"] for row in rows]
def _daily_basic_exists(trade_date: str, ts_code: Optional[str] = None) -> bool:
query = "SELECT 1 FROM daily_basic WHERE trade_date = ?"
def _record_exists(
table: str,
date_col: str,
trade_date: str,
ts_code: Optional[str] = None,
) -> bool:
query = f"SELECT 1 FROM {table} WHERE {date_col} = ?"
params: Tuple = (trade_date,)
if ts_code:
query += " AND ts_code = ?"
@ -469,7 +474,7 @@ def fetch_daily_basic(
trade_dates = _load_trade_dates(start, end)
frames: List[pd.DataFrame] = []
for trade_date in trade_dates:
if skip_existing and _daily_basic_exists(trade_date):
if skip_existing and _record_exists("daily_basic", "trade_date", trade_date):
LOGGER.info(
"日线基础指标已存在,跳过交易日 %s",
trade_date,
@ -492,17 +497,47 @@ def fetch_daily_basic(
return _df_to_records(merged, _TABLE_COLUMNS["daily_basic"])
def fetch_adj_factor(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
def fetch_adj_factor(
start: date,
end: date,
ts_code: Optional[str] = None,
skip_existing: bool = True,
) -> Iterable[Dict]:
client = _ensure_client()
start_date = _format_date(start)
end_date = _format_date(end)
LOGGER.info("拉取复权因子(%s-%s,股票:%s", start_date, end_date, ts_code or "全部")
df = _fetch_paginated("adj_factor", {
"ts_code": ts_code,
"start_date": start_date,
"end_date": end_date,
})
return _df_to_records(df, _TABLE_COLUMNS["adj_factor"])
LOGGER.info(
"拉取复权因子(%s-%s,股票:%s",
start_date,
end_date,
ts_code or "全部",
extra=LOG_EXTRA,
)
trade_dates = _load_trade_dates(start, end)
frames: List[pd.DataFrame] = []
for trade_date in trade_dates:
if skip_existing and _record_exists("adj_factor", "trade_date", trade_date, ts_code):
LOGGER.debug(
"复权因子已存在,跳过 %s %s",
ts_code or "ALL",
trade_date,
extra=LOG_EXTRA,
)
continue
params = {"trade_date": trade_date}
if ts_code:
params["ts_code"] = ts_code
LOGGER.debug("按交易日拉取复权因子:%s", params, extra=LOG_EXTRA)
df = _fetch_paginated("adj_factor", params)
if not df.empty:
frames.append(df)
if not frames:
return []
merged = pd.concat(frames, ignore_index=True)
return _df_to_records(merged, _TABLE_COLUMNS["adj_factor"])
def fetch_suspensions(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
@ -654,7 +689,7 @@ def ensure_data_coverage(
LOGGER.info("拉取 %s 表数据(股票:%s%s-%s", table, code, start_str, end_str)
try:
kwargs = {"ts_code": code}
if fetch_fn is fetch_daily_basic:
if fetch_fn in (fetch_daily_basic, fetch_adj_factor):
kwargs["skip_existing"] = not force
rows = fetch_fn(start, end, **kwargs)
except Exception:
@ -672,7 +707,7 @@ def ensure_data_coverage(
LOGGER.info("拉取 %s 表数据(全市场)%s-%s", table, start_str, end_str)
try:
kwargs = {}
if fetch_fn is fetch_daily_basic:
if fetch_fn in (fetch_daily_basic, fetch_adj_factor):
kwargs["skip_existing"] = not force
rows = fetch_fn(start, end, **kwargs)
except Exception: