"""TuShare 数据拉取与数据覆盖检查工具。"""
|
||
from __future__ import annotations
|
||
|
||
import os
|
||
import time
|
||
from collections import deque
|
||
from dataclasses import dataclass
|
||
from datetime import date
|
||
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple
|
||
|
||
import pandas as pd
|
||
|
||
try:
|
||
import tushare as ts
|
||
except ImportError: # pragma: no cover - 运行时提示
|
||
ts = None # type: ignore[assignment]
|
||
|
||
from app.utils.config import get_config
|
||
from app.utils.db import db_session
|
||
from app.data.schema import initialize_database
|
||
from app.utils.logging import get_logger
|
||
|
||
|
||
LOGGER = get_logger(__name__)
|
||
|
||
API_DEFAULT_LIMIT = 5000
|
||
LOG_EXTRA = {"stage": "data_ingest"}
|
||
|
||
_CALL_QUEUE = deque()
|
||
|
||
|
||


def _normalize_date_str(value: Optional[str]) -> Optional[str]:
    if value is None:
        return None
    text = str(value).strip()
    return text or None


def _respect_rate_limit(cfg) -> None:
    max_calls = cfg.max_calls_per_minute
    if max_calls <= 0:
        return
    now = time.time()
    window = 60.0
    while _CALL_QUEUE and now - _CALL_QUEUE[0] > window:
        _CALL_QUEUE.popleft()
    if len(_CALL_QUEUE) >= max_calls:
        sleep_time = window - (now - _CALL_QUEUE[0]) + 0.1
        LOGGER.debug("Rate limit reached; sleeping %.2f s", sleep_time, extra=LOG_EXTRA)
        time.sleep(max(0.1, sleep_time))
    _CALL_QUEUE.append(time.time())
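

# The sliding-window limiter above can be exercised in isolation. A minimal
# sketch, assuming a stub config object (`SimpleNamespace` stands in for the
# real config; not part of this module):
#
#     from types import SimpleNamespace
#     cfg = SimpleNamespace(max_calls_per_minute=2)
#     for _ in range(3):
#         _respect_rate_limit(cfg)  # the third call sleeps until the window frees up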


def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]:
    if df is None or df.empty:
        return []
    reindexed = df.reindex(columns=allowed_cols)
    return reindexed.where(pd.notnull(reindexed), None).to_dict("records")
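

# Illustration of the NaN-to-None normalisation (a sketch, not part of the
# module's API): reindex() adds any missing column as NaN, and
# where(notnull, None) converts every NaN to None so rows bind cleanly as
# SQL parameters.
#
#     df = pd.DataFrame({"ts_code": ["000001.SZ"], "close": [float("nan")]})
#     _df_to_records(df, ["ts_code", "close", "vol"])
#     # -> [{"ts_code": "000001.SZ", "close": None, "vol": None}]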


def _fetch_paginated(endpoint: str, params: Dict[str, object], limit: int | None = None) -> pd.DataFrame:
    client = _ensure_client()
    limit = limit or API_DEFAULT_LIMIT
    frames: List[pd.DataFrame] = []
    offset = 0
    clean_params = {k: v for k, v in params.items() if v is not None}
    LOGGER.info(
        "Calling TuShare endpoint %s, params=%s, limit=%s",
        endpoint,
        clean_params,
        limit,
        extra=LOG_EXTRA,
    )
    while True:
        _respect_rate_limit(get_config())
        call = getattr(client, endpoint)
        try:
            df = call(limit=limit, offset=offset, **clean_params)
        except Exception:  # noqa: BLE001
            LOGGER.exception(
                "TuShare call failed: endpoint=%s offset=%s params=%s",
                endpoint,
                offset,
                clean_params,
                extra=LOG_EXTRA,
            )
            raise
        if df is None or df.empty:
            LOGGER.debug(
                "TuShare returned no data: endpoint=%s offset=%s",
                endpoint,
                offset,
                extra=LOG_EXTRA,
            )
            break
        LOGGER.debug(
            "TuShare returned %s rows: endpoint=%s offset=%s",
            len(df),
            endpoint,
            offset,
            extra=LOG_EXTRA,
        )
        frames.append(df)
        if len(df) < limit:
            break
        offset += limit
    if not frames:
        return pd.DataFrame()
    merged = pd.concat(frames, ignore_index=True)
    LOGGER.info(
        "TuShare call finished: endpoint=%s total_rows=%s",
        endpoint,
        len(merged),
        extra=LOG_EXTRA,
    )
    return merged
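

# Pagination usage sketch (assumes a configured token; endpoint and parameter
# names follow how this module calls the TuShare pro API):
#
#     df = _fetch_paginated("daily", {"trade_date": "20240102"})
#     df = _fetch_paginated("suspend_d", {"trade_date": "20240102"}, limit=2000)
#
# The loop keeps requesting `limit`-sized pages at increasing offsets until a
# page comes back short, so a full trading day merges into one DataFrame.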


@dataclass
class FetchJob:
    name: str
    start: date
    end: date
    granularity: str = "daily"
    ts_codes: Optional[Sequence[str]] = None
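

# Constructing a job (illustrative values only):
#
#     job = FetchJob(
#         name="daily_backfill",
#         start=date(2024, 1, 1),
#         end=date(2024, 1, 31),
#         ts_codes=("000001.SZ", "600000.SH"),
#     )
#
# Leaving ts_codes as None means "whole market" in fetch_daily_bars below.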


_TABLE_SCHEMAS: Dict[str, str] = {
    "stock_basic": """
        CREATE TABLE IF NOT EXISTS stock_basic (
            ts_code TEXT PRIMARY KEY,
            symbol TEXT,
            name TEXT,
            area TEXT,
            industry TEXT,
            market TEXT,
            exchange TEXT,
            list_status TEXT,
            list_date TEXT,
            delist_date TEXT
        );
    """,
    "daily": """
        CREATE TABLE IF NOT EXISTS daily (
            ts_code TEXT,
            trade_date TEXT,
            open REAL,
            high REAL,
            low REAL,
            close REAL,
            pre_close REAL,
            change REAL,
            pct_chg REAL,
            vol REAL,
            amount REAL,
            PRIMARY KEY (ts_code, trade_date)
        );
    """,
    "daily_basic": """
        CREATE TABLE IF NOT EXISTS daily_basic (
            ts_code TEXT,
            trade_date TEXT,
            close REAL,
            turnover_rate REAL,
            turnover_rate_f REAL,
            volume_ratio REAL,
            pe REAL,
            pe_ttm REAL,
            pb REAL,
            ps REAL,
            ps_ttm REAL,
            dv_ratio REAL,
            dv_ttm REAL,
            total_share REAL,
            float_share REAL,
            free_share REAL,
            total_mv REAL,
            circ_mv REAL,
            PRIMARY KEY (ts_code, trade_date)
        );
    """,
    "adj_factor": """
        CREATE TABLE IF NOT EXISTS adj_factor (
            ts_code TEXT,
            trade_date TEXT,
            adj_factor REAL,
            PRIMARY KEY (ts_code, trade_date)
        );
    """,
    "suspend": """
        CREATE TABLE IF NOT EXISTS suspend (
            ts_code TEXT,
            suspend_date TEXT,
            resume_date TEXT,
            suspend_type TEXT,
            ann_date TEXT,
            suspend_timing TEXT,
            resume_timing TEXT,
            reason TEXT,
            PRIMARY KEY (ts_code, suspend_date)
        );
    """,
    "trade_calendar": """
        CREATE TABLE IF NOT EXISTS trade_calendar (
            exchange TEXT,
            cal_date TEXT,
            is_open INTEGER,
            pretrade_date TEXT,
            PRIMARY KEY (exchange, cal_date)
        );
    """,
    "stk_limit": """
        CREATE TABLE IF NOT EXISTS stk_limit (
            ts_code TEXT,
            trade_date TEXT,
            up_limit REAL,
            down_limit REAL,
            PRIMARY KEY (ts_code, trade_date)
        );
    """,
}

_TABLE_COLUMNS: Dict[str, List[str]] = {
    "stock_basic": [
        "ts_code",
        "symbol",
        "name",
        "area",
        "industry",
        "market",
        "exchange",
        "list_status",
        "list_date",
        "delist_date",
    ],
    "daily": [
        "ts_code",
        "trade_date",
        "open",
        "high",
        "low",
        "close",
        "pre_close",
        "change",
        "pct_chg",
        "vol",
        "amount",
    ],
    "daily_basic": [
        "ts_code",
        "trade_date",
        "close",
        "turnover_rate",
        "turnover_rate_f",
        "volume_ratio",
        "pe",
        "pe_ttm",
        "pb",
        "ps",
        "ps_ttm",
        "dv_ratio",
        "dv_ttm",
        "total_share",
        "float_share",
        "free_share",
        "total_mv",
        "circ_mv",
    ],
    "adj_factor": [
        "ts_code",
        "trade_date",
        "adj_factor",
    ],
    "suspend": [
        "ts_code",
        "suspend_date",
        "resume_date",
        "suspend_type",
        "ann_date",
        "suspend_timing",
        "resume_timing",
        "reason",
    ],
    "trade_calendar": [
        "exchange",
        "cal_date",
        "is_open",
        "pretrade_date",
    ],
    "stk_limit": [
        "ts_code",
        "trade_date",
        "up_limit",
        "down_limit",
    ],
}


def _ensure_client():
    if ts is None:
        raise RuntimeError("tushare is not installed; install the tushare package first")
    token = get_config().tushare_token or os.getenv("TUSHARE_TOKEN")
    if not token:
        raise RuntimeError("TuShare token not configured; set it in the config file or the TUSHARE_TOKEN environment variable")
    if not hasattr(_ensure_client, "_client") or _ensure_client._client is None:  # type: ignore[attr-defined]
        ts.set_token(token)
        _ensure_client._client = ts.pro_api(token)  # type: ignore[attr-defined]
        LOGGER.info("TuShare client initialized", extra=LOG_EXTRA)
    return _ensure_client._client  # type: ignore[attr-defined]


def _format_date(value: date) -> str:
    return value.strftime("%Y%m%d")


def _load_trade_dates(start: date, end: date, exchange: str = "SSE") -> List[str]:
    start_str = _format_date(start)
    end_str = _format_date(end)
    query = (
        "SELECT cal_date FROM trade_calendar "
        "WHERE exchange = ? AND cal_date BETWEEN ? AND ? AND is_open = 1 ORDER BY cal_date"
    )
    with db_session(read_only=True) as conn:
        rows = conn.execute(query, (exchange, start_str, end_str)).fetchall()
    return [row["cal_date"] for row in rows]


def _record_exists(
    table: str,
    date_col: str,
    trade_date: str,
    ts_code: Optional[str] = None,
) -> bool:
    query = f"SELECT 1 FROM {table} WHERE {date_col} = ?"
    params: Tuple = (trade_date,)
    if ts_code:
        query += " AND ts_code = ?"
        params = (trade_date, ts_code)
    with db_session(read_only=True) as conn:
        row = conn.execute(query, params).fetchone()
    return row is not None


def _should_skip_range(table: str, date_col: str, start: date, end: date, ts_code: str | None = None) -> bool:
    start_str = _format_date(start)
    end_str = _format_date(end)

    effective_start = start_str
    effective_end = end_str

    if ts_code:
        list_date, delist_date = _listing_window(ts_code)
        if list_date:
            effective_start = max(effective_start, list_date)
        if delist_date:
            effective_end = min(effective_end, delist_date)
        if effective_start > effective_end:
            LOGGER.debug(
                "Stock %s is outside the target window; skipping backfill",
                ts_code,
                extra=LOG_EXTRA,
            )
            return True
        stats = _range_stats(table, date_col, effective_start, effective_end, ts_code=ts_code)
    else:
        stats = _range_stats(table, date_col, effective_start, effective_end)

    if stats["min"] is None or stats["max"] is None:
        return False
    if stats["min"] > effective_start or stats["max"] < effective_end:
        return False

    if ts_code is None:
        expected_days = _expected_trading_days(effective_start, effective_end)
        if expected_days and (stats["distinct"] or 0) < expected_days:
            return False

    return True
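

# The coverage heuristic in brief: a (table, range) pair counts as covered when
# the stored min/max dates enclose the requested window and, for whole-market
# checks, the number of distinct stored trade dates matches the calendar.
# Illustrative call (dates hypothetical):
#
#     _should_skip_range("daily", "trade_date", date(2024, 1, 2), date(2024, 1, 31))
#     # -> True only if every trading day of that January is already present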


def _range_stats(
    table: str,
    date_col: str,
    start_str: str,
    end_str: str,
    ts_code: str | None = None,
) -> Dict[str, object]:  # values are date strings for min/max and an int for distinct
    sql = (
        f"SELECT MIN({date_col}) AS min_d, MAX({date_col}) AS max_d, "
        f"COUNT(DISTINCT {date_col}) AS distinct_days FROM {table} "
        f"WHERE {date_col} BETWEEN ? AND ?"
    )
    params: List[object] = [start_str, end_str]
    if ts_code:
        sql += " AND ts_code = ?"
        params.append(ts_code)
    with db_session(read_only=True) as conn:
        row = conn.execute(sql, tuple(params)).fetchone()
    return {
        "min": row["min_d"] if row else None,
        "max": row["max_d"] if row else None,
        "distinct": row["distinct_days"] if row else 0,
    }


def _range_needs_refresh(
    table: str,
    date_col: str,
    start_str: str,
    end_str: str,
    expected_days: int = 0,
) -> bool:
    stats = _range_stats(table, date_col, start_str, end_str)
    if stats["min"] is None or stats["max"] is None:
        return True
    if stats["min"] > start_str or stats["max"] < end_str:
        return True
    if expected_days and (stats["distinct"] or 0) < expected_days:
        return True
    return False


def _listing_window(ts_code: str) -> Tuple[Optional[str], Optional[str]]:
    with db_session(read_only=True) as conn:
        row = conn.execute(
            "SELECT list_date, delist_date FROM stock_basic WHERE ts_code = ?",
            (ts_code,),
        ).fetchone()
    if not row:
        return None, None
    return _normalize_date_str(row["list_date"]), _normalize_date_str(row["delist_date"])  # type: ignore[index]


def _calendar_needs_refresh(exchange: str, start_str: str, end_str: str) -> bool:
    sql = """
        SELECT MIN(cal_date) AS min_d, MAX(cal_date) AS max_d, COUNT(*) AS cnt
        FROM trade_calendar
        WHERE exchange = ? AND cal_date BETWEEN ? AND ?
    """
    with db_session(read_only=True) as conn:
        row = conn.execute(sql, (exchange, start_str, end_str)).fetchone()
    if row is None or row["min_d"] is None:
        return True
    if row["min_d"] > start_str or row["max_d"] < end_str:
        return True
    # The trade calendar is legitimately sparse (holidays), so no day-count comparison here.
    return False


def _expected_trading_days(start_str: str, end_str: str, exchange: str = "SSE") -> int:
    sql = """
        SELECT COUNT(*) AS cnt
        FROM trade_calendar
        WHERE exchange = ? AND cal_date BETWEEN ? AND ? AND is_open = 1
    """
    with db_session(read_only=True) as conn:
        row = conn.execute(sql, (exchange, start_str, end_str)).fetchone()
    return int(row["cnt"]) if row and row["cnt"] is not None else 0


def fetch_stock_basic(exchange: Optional[str] = None, list_status: str = "L") -> Iterable[Dict]:
    client = _ensure_client()
    LOGGER.info(
        "Fetching stock basics (exchange: %s, status: %s)",
        exchange or "all",
        list_status,
        extra=LOG_EXTRA,
    )
    fields = "ts_code,symbol,name,area,industry,market,exchange,list_status,list_date,delist_date"
    df = client.stock_basic(exchange=exchange, list_status=list_status, fields=fields)
    return _df_to_records(df, _TABLE_COLUMNS["stock_basic"])


def fetch_daily_bars(job: FetchJob, skip_existing: bool = True) -> Iterable[Dict]:
    _ensure_client()  # fail fast if the client/token is not configured
    frames: List[pd.DataFrame] = []

    if job.granularity != "daily":
        raise ValueError(f"Unsupported granularity: {job.granularity}")

    trade_dates = _load_trade_dates(job.start, job.end)
    if not trade_dates:
        LOGGER.info("Local trade calendar missing; backfilling it before pulling daily bars", extra=LOG_EXTRA)
        ensure_trade_calendar(job.start, job.end)
        trade_dates = _load_trade_dates(job.start, job.end)

    if job.ts_codes:
        for code in job.ts_codes:
            for trade_date in trade_dates:
                if skip_existing and _record_exists("daily", "trade_date", trade_date, code):
                    LOGGER.debug(
                        "Daily bars already stored; skipping %s %s",
                        code,
                        trade_date,
                        extra=LOG_EXTRA,
                    )
                    continue
                LOGGER.info(
                    "Per-day fetch request: endpoint=daily code=%s trade_date=%s",
                    code,
                    trade_date,
                    extra=LOG_EXTRA,
                )
                df = _fetch_paginated(
                    "daily",
                    {
                        "trade_date": trade_date,
                        "ts_code": code,
                    },
                )
                if not df.empty:
                    frames.append(df)
    else:
        for trade_date in trade_dates:
            if skip_existing and _record_exists("daily", "trade_date", trade_date):
                LOGGER.debug(
                    "Daily bars already stored; skipping trade date %s",
                    trade_date,
                    extra=LOG_EXTRA,
                )
                continue
            LOGGER.info(
                "Per-day fetch request: endpoint=daily trade_date=%s",
                trade_date,
                extra=LOG_EXTRA,
            )
            df = _fetch_paginated("daily", {"trade_date": trade_date})
            if not df.empty:
                frames.append(df)

    if not frames:
        return []
    df = pd.concat(frames, ignore_index=True)
    return _df_to_records(df, _TABLE_COLUMNS["daily"])
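

# Usage sketch, assuming the token and trade calendar are available (dates are
# illustrative): a whole-market backfill for one month.
#
#     job = FetchJob("daily_demo", start=date(2024, 1, 2), end=date(2024, 1, 31))
#     save_records("daily", fetch_daily_bars(job))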


def fetch_daily_basic(
    start: date,
    end: date,
    ts_code: Optional[str] = None,
    skip_existing: bool = True,
) -> Iterable[Dict]:
    _ensure_client()  # fail fast if the client/token is not configured
    start_date = _format_date(start)
    end_date = _format_date(end)
    LOGGER.info(
        "Fetching daily basic indicators (%s-%s, codes: %s)",
        start_date,
        end_date,
        ts_code or "all",
        extra=LOG_EXTRA,
    )

    trade_dates = _load_trade_dates(start, end)
    frames: List[pd.DataFrame] = []
    for trade_date in trade_dates:
        if skip_existing and _record_exists("daily_basic", "trade_date", trade_date, ts_code):
            LOGGER.debug(
                "Daily basic indicators already stored; skipping trade date %s",
                trade_date,
                extra=LOG_EXTRA,
            )
            continue
        params = {"trade_date": trade_date}
        if ts_code:
            params["ts_code"] = ts_code
        LOGGER.info(
            "Per-day fetch request: endpoint=daily_basic params=%s",
            params,
            extra=LOG_EXTRA,
        )
        df = _fetch_paginated("daily_basic", params)
        if not df.empty:
            frames.append(df)

    if not frames:
        return []

    merged = pd.concat(frames, ignore_index=True)
    return _df_to_records(merged, _TABLE_COLUMNS["daily_basic"])


def fetch_adj_factor(
    start: date,
    end: date,
    ts_code: Optional[str] = None,
    skip_existing: bool = True,
) -> Iterable[Dict]:
    _ensure_client()  # fail fast if the client/token is not configured
    start_date = _format_date(start)
    end_date = _format_date(end)
    LOGGER.info(
        "Fetching adjustment factors (%s-%s, codes: %s)",
        start_date,
        end_date,
        ts_code or "all",
        extra=LOG_EXTRA,
    )

    trade_dates = _load_trade_dates(start, end)
    frames: List[pd.DataFrame] = []
    for trade_date in trade_dates:
        if skip_existing and _record_exists("adj_factor", "trade_date", trade_date, ts_code):
            LOGGER.debug(
                "Adjustment factors already stored; skipping %s %s",
                ts_code or "ALL",
                trade_date,
                extra=LOG_EXTRA,
            )
            continue
        params = {"trade_date": trade_date}
        if ts_code:
            params["ts_code"] = ts_code
        LOGGER.info("Per-day fetch request: endpoint=adj_factor params=%s", params, extra=LOG_EXTRA)
        df = _fetch_paginated("adj_factor", params)
        if not df.empty:
            frames.append(df)

    if not frames:
        return []

    merged = pd.concat(frames, ignore_index=True)
    return _df_to_records(merged, _TABLE_COLUMNS["adj_factor"])


def fetch_suspensions(
    start: date,
    end: date,
    ts_code: Optional[str] = None,
    skip_existing: bool = True,
) -> Iterable[Dict]:
    _ensure_client()  # fail fast if the client/token is not configured
    start_date = _format_date(start)
    end_date = _format_date(end)
    LOGGER.info("Fetching suspension/resumption records (%s-%s)", start_date, end_date, extra=LOG_EXTRA)
    trade_dates = _load_trade_dates(start, end)
    frames: List[pd.DataFrame] = []
    for trade_date in trade_dates:
        if skip_existing and _record_exists("suspend", "suspend_date", trade_date, ts_code):
            LOGGER.debug(
                "Suspension records already stored; skipping %s %s",
                ts_code or "ALL",
                trade_date,
                extra=LOG_EXTRA,
            )
            continue
        params = {"trade_date": trade_date}
        if ts_code:
            params["ts_code"] = ts_code
        LOGGER.info("Per-day fetch request: endpoint=suspend_d params=%s", params, extra=LOG_EXTRA)
        df = _fetch_paginated("suspend_d", params, limit=2000)
        if not df.empty:
            frames.append(df)

    if not frames:
        return []

    merged = pd.concat(frames, ignore_index=True)
    return _df_to_records(merged, _TABLE_COLUMNS["suspend"])


def fetch_trade_calendar(start: date, end: date, exchange: str = "SSE") -> Iterable[Dict]:
    client = _ensure_client()
    start_date = _format_date(start)
    end_date = _format_date(end)
    LOGGER.info(
        "Fetching trade calendar (exchange: %s, range: %s-%s)",
        exchange,
        start_date,
        end_date,
        extra=LOG_EXTRA,
    )
    df = client.trade_cal(exchange=exchange, start_date=start_date, end_date=end_date)
    if df is not None and not df.empty and "is_open" in df.columns:
        df["is_open"] = pd.to_numeric(df["is_open"], errors="coerce").fillna(0).astype(int)
    return _df_to_records(df, _TABLE_COLUMNS["trade_calendar"])


def fetch_stk_limit(
    start: date,
    end: date,
    ts_code: Optional[str] = None,
    skip_existing: bool = True,
) -> Iterable[Dict]:
    _ensure_client()  # fail fast if the client/token is not configured
    start_date = _format_date(start)
    end_date = _format_date(end)
    LOGGER.info("Fetching limit-up/limit-down prices (%s-%s)", start_date, end_date, extra=LOG_EXTRA)
    trade_dates = _load_trade_dates(start, end)
    frames: List[pd.DataFrame] = []
    for trade_date in trade_dates:
        if skip_existing and _record_exists("stk_limit", "trade_date", trade_date, ts_code):
            LOGGER.debug(
                "Limit prices already stored; skipping %s %s",
                ts_code or "ALL",
                trade_date,
                extra=LOG_EXTRA,
            )
            continue
        params = {"trade_date": trade_date}
        if ts_code:
            params["ts_code"] = ts_code
        LOGGER.info("Per-day fetch request: endpoint=stk_limit params=%s", params, extra=LOG_EXTRA)
        df = _fetch_paginated("stk_limit", params)
        if not df.empty:
            frames.append(df)

    if not frames:
        return []

    merged = pd.concat(frames, ignore_index=True)
    return _df_to_records(merged, _TABLE_COLUMNS["stk_limit"])


def save_records(table: str, rows: Iterable[Dict]) -> None:
    items = list(rows)
    if not items:
        LOGGER.info("Table %s has no new records; skipping write", table, extra=LOG_EXTRA)
        return

    schema = _TABLE_SCHEMAS.get(table)
    columns = _TABLE_COLUMNS.get(table)
    if not schema or not columns:
        raise ValueError(f"Unsupported table for writes: {table}")

    placeholders = ",".join([f":{col}" for col in columns])
    col_clause = ",".join(columns)

    LOGGER.info("Writing %d records into table %s", len(items), table, extra=LOG_EXTRA)
    with db_session() as conn:
        conn.executescript(schema)
        conn.executemany(
            f"INSERT OR REPLACE INTO {table} ({col_clause}) VALUES ({placeholders})",
            items,
        )
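

# INSERT OR REPLACE keyed on each table's primary key makes writes idempotent:
# re-running a backfill simply overwrites rows with the same key, e.g.
# (ts_code, trade_date). Sketch with an illustrative row:
#
#     save_records("adj_factor", [
#         {"ts_code": "000001.SZ", "trade_date": "20240102", "adj_factor": 1.234},
#     ])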


def ensure_stock_basic(list_status: str = "L") -> None:
    exchanges = ("SSE", "SZSE")
    with db_session(read_only=True) as conn:
        row = conn.execute(
            "SELECT COUNT(*) AS cnt FROM stock_basic WHERE exchange IN (?, ?) AND list_status = ?",
            (*exchanges, list_status),
        ).fetchone()
    if row and row["cnt"]:
        LOGGER.info(
            "stock_basic already holds %d records; skipping fetch",
            row["cnt"],
            extra=LOG_EXTRA,
        )
        return

    for exch in exchanges:
        save_records("stock_basic", fetch_stock_basic(exchange=exch, list_status=list_status))


def ensure_trade_calendar(start: date, end: date, exchanges: Sequence[str] = ("SSE", "SZSE")) -> None:
    start_str = _format_date(start)
    end_str = _format_date(end)
    for exch in exchanges:
        if _calendar_needs_refresh(exch, start_str, end_str):
            save_records("trade_calendar", fetch_trade_calendar(start, end, exchange=exch))


def ensure_data_coverage(
    start: date,
    end: date,
    ts_codes: Optional[Sequence[str]] = None,
    include_limits: bool = True,
    force: bool = False,
    progress_hook: Callable[[str, float], None] | None = None,
) -> None:
    initialize_database()
    start_str = _format_date(start)
    end_str = _format_date(end)

    total_steps = 5 + (1 if include_limits else 0)
    current_step = 0

    def advance(message: str) -> None:
        nonlocal current_step
        current_step += 1
        progress = min(current_step / total_steps, 1.0)
        if progress_hook:
            progress_hook(message, progress)
        LOGGER.info(message, extra=LOG_EXTRA)

    advance("Preparing stock basics and trade calendar")
    ensure_stock_basic()
    ensure_trade_calendar(start, end)

    codes = tuple(dict.fromkeys(ts_codes)) if ts_codes else tuple()
    expected_days = _expected_trading_days(start_str, end_str)

    advance("Processing daily bars")
    if codes:
        pending_codes: List[str] = []
        for code in codes:
            if not force and _should_skip_range("daily", "trade_date", start, end, code):
                LOGGER.info("Daily bars for %s already cover %s-%s; skipping", code, start_str, end_str, extra=LOG_EXTRA)
                continue
            pending_codes.append(code)
        if pending_codes:
            job = FetchJob("daily_autofill", start=start, end=end, ts_codes=tuple(pending_codes))
            LOGGER.info("Fetching daily bars %s-%s (%d codes pending)", start_str, end_str, len(pending_codes), extra=LOG_EXTRA)
            save_records("daily", fetch_daily_bars(job, skip_existing=not force))
    else:
        needs_daily = force or _range_needs_refresh("daily", "trade_date", start_str, end_str, expected_days)
        if not needs_daily:
            LOGGER.info("Daily bars already cover %s-%s; skipping fetch", start_str, end_str, extra=LOG_EXTRA)
        else:
            job = FetchJob("daily_autofill", start=start, end=end)
            LOGGER.info("Fetching daily bars %s-%s", start_str, end_str, extra=LOG_EXTRA)
            save_records("daily", fetch_daily_bars(job, skip_existing=not force))

    date_cols = {
        "daily_basic": "trade_date",
        "adj_factor": "trade_date",
        "stk_limit": "trade_date",
        "suspend": "suspend_date",
    }

    def _save_with_codes(table: str, fetch_fn) -> None:
        date_col = date_cols.get(table, "trade_date")
        if codes:
            for code in codes:
                if not force and _should_skip_range(table, date_col, start, end, code):
                    LOGGER.info("Table %s for %s already covers %s-%s; skipping", table, code, start_str, end_str, extra=LOG_EXTRA)
                    continue
                LOGGER.info("Fetching table %s (code: %s) for %s-%s", table, code, start_str, end_str, extra=LOG_EXTRA)
                try:
                    kwargs = {"ts_code": code}
                    if fetch_fn in (fetch_daily_basic, fetch_adj_factor, fetch_suspensions, fetch_stk_limit):
                        kwargs["skip_existing"] = not force
                    rows = fetch_fn(start, end, **kwargs)
                except Exception:
                    LOGGER.exception("TuShare fetch failed: table=%s code=%s", table, code, extra=LOG_EXTRA)
                    raise
                save_records(table, rows)
        else:
            needs_refresh = force
            if not force:
                expected = expected_days if table in {"daily_basic", "adj_factor", "stk_limit"} else 0
                needs_refresh = _range_needs_refresh(table, date_col, start_str, end_str, expected)
            if not needs_refresh:
                LOGGER.info("Table %s already covers %s-%s; skipping", table, start_str, end_str, extra=LOG_EXTRA)
                return
            LOGGER.info("Fetching table %s (whole market) for %s-%s", table, start_str, end_str, extra=LOG_EXTRA)
            try:
                kwargs = {}
                if fetch_fn in (fetch_daily_basic, fetch_adj_factor, fetch_suspensions, fetch_stk_limit):
                    kwargs["skip_existing"] = not force
                rows = fetch_fn(start, end, **kwargs)
            except Exception:
                LOGGER.exception("TuShare fetch failed: table=%s code=all", table, extra=LOG_EXTRA)
                raise
            save_records(table, rows)

    advance("Processing daily basic indicators")
    _save_with_codes("daily_basic", fetch_daily_basic)

    advance("Processing adjustment factors")
    _save_with_codes("adj_factor", fetch_adj_factor)

    if include_limits:
        advance("Processing limit-up/limit-down prices")
        _save_with_codes("stk_limit", fetch_stk_limit)

    advance("Processing suspension records")
    _save_with_codes("suspend", fetch_suspensions)

    if progress_hook:
        progress_hook("Data coverage check complete", 1.0)
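

# Typical call with a progress callback, a sketch with illustrative dates:
#
#     def on_progress(message: str, fraction: float) -> None:
#         print(f"[{fraction:>6.1%}] {message}")
#
#     ensure_data_coverage(date(2024, 1, 2), date(2024, 3, 29), progress_hook=on_progress)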


def collect_data_coverage(start: date, end: date) -> Dict[str, Dict[str, object]]:
    start_str = _format_date(start)
    end_str = _format_date(end)
    expected_days = _expected_trading_days(start_str, end_str)

    coverage: Dict[str, Dict[str, object]] = {
        "period": {
            "start": start_str,
            "end": end_str,
            "expected_trading_days": expected_days,
        }
    }

    def add_table(name: str, date_col: str, require_days: bool = True) -> None:
        stats = _range_stats(name, date_col, start_str, end_str)
        coverage[name] = {
            "min": stats["min"],
            "max": stats["max"],
            "distinct_days": stats["distinct"],
            "meets_expectation": (
                stats["min"] is not None
                and stats["max"] is not None
                and stats["min"] <= start_str
                and stats["max"] >= end_str
                and ((not require_days) or (stats["distinct"] or 0) >= expected_days)
            ),
        }

    add_table("daily", "trade_date")
    add_table("daily_basic", "trade_date")
    add_table("adj_factor", "trade_date")
    add_table("stk_limit", "trade_date")
    add_table("suspend", "suspend_date", require_days=False)

    with db_session(read_only=True) as conn:
        stock_tot = conn.execute("SELECT COUNT(*) AS cnt FROM stock_basic").fetchone()
        stock_sse = conn.execute(
            "SELECT COUNT(*) AS cnt FROM stock_basic WHERE exchange = 'SSE' AND list_status = 'L'"
        ).fetchone()
        stock_szse = conn.execute(
            "SELECT COUNT(*) AS cnt FROM stock_basic WHERE exchange = 'SZSE' AND list_status = 'L'"
        ).fetchone()
    coverage["stock_basic"] = {
        "total": stock_tot["cnt"] if stock_tot else 0,
        "sse_listed": stock_sse["cnt"] if stock_sse else 0,
        "szse_listed": stock_szse["cnt"] if stock_szse else 0,
    }

    return coverage
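

# The returned mapping is keyed by table name; shape sketch (all values
# illustrative, not real counts):
#
#     {
#         "period": {"start": "20240102", "end": "20240131", "expected_trading_days": 22},
#         "daily": {"min": "20240102", "max": "20240131", "distinct_days": 22, "meets_expectation": True},
#         ...
#         "stock_basic": {"total": 5321, "sse_listed": 2280, "szse_listed": 2810},
#     }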


def run_ingestion(job: FetchJob, include_limits: bool = True) -> None:
    LOGGER.info("Starting TuShare ingestion job: %s", job.name, extra=LOG_EXTRA)
    ensure_data_coverage(
        job.start,
        job.end,
        ts_codes=job.ts_codes,
        include_limits=include_limits,
        force=True,
    )
    LOGGER.info("Job %s finished", job.name, extra=LOG_EXTRA)
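

# Minimal manual smoke test: a sketch that assumes a valid TUSHARE_TOKEN and
# the app package on PYTHONPATH, guarded so it never runs on import. The job
# values are illustrative only.
if __name__ == "__main__":
    demo_job = FetchJob(
        name="manual_demo",
        start=date(2024, 1, 2),
        end=date(2024, 1, 5),
        ts_codes=("000001.SZ",),
    )
    run_ingestion(demo_job, include_limits=False)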