505 lines
16 KiB
Python
505 lines
16 KiB
Python
"""TuShare 数据拉取与数据覆盖检查工具。"""
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
import os
|
||
from dataclasses import dataclass
|
||
from datetime import date
|
||
from typing import Dict, Iterable, List, Optional, Sequence
|
||
|
||
import pandas as pd
|
||
|
||
try:
|
||
import tushare as ts
|
||
except ImportError: # pragma: no cover - 运行时提示
|
||
ts = None # type: ignore[assignment]
|
||
|
||
from app.utils.config import get_config
|
||
from app.utils.db import db_session
|
||
from app.data.schema import initialize_database
|
||
|
||
LOGGER = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass
class FetchJob:
    """Describes one TuShare ingestion job: an inclusive date range plus scope."""

    # Human-readable job name; used only in log messages.
    name: str
    # Inclusive start date of the range to fetch.
    start: date
    # Inclusive end date of the range to fetch.
    end: date
    # Bar granularity; only "daily" is accepted by fetch_daily_bars.
    granularity: str = "daily"
    # Optional subset of TuShare codes; None/empty means the whole market.
    ts_codes: Optional[Sequence[str]] = None
|
||
|
||
|
||
# DDL for every table this module writes. Each entry is a single
# CREATE TABLE IF NOT EXISTS statement that save_records() executes before
# inserting, so writes work against a fresh database.
_TABLE_SCHEMAS: Dict[str, str] = {
    # Stock master data, one row per security.
    "stock_basic": """
        CREATE TABLE IF NOT EXISTS stock_basic (
            ts_code TEXT PRIMARY KEY,
            symbol TEXT,
            name TEXT,
            area TEXT,
            industry TEXT,
            market TEXT,
            exchange TEXT,
            list_status TEXT,
            list_date TEXT,
            delist_date TEXT
        );
    """,
    # Daily OHLCV bars keyed by (ts_code, trade_date).
    "daily": """
        CREATE TABLE IF NOT EXISTS daily (
            ts_code TEXT,
            trade_date TEXT,
            open REAL,
            high REAL,
            low REAL,
            close REAL,
            pre_close REAL,
            change REAL,
            pct_chg REAL,
            vol REAL,
            amount REAL,
            PRIMARY KEY (ts_code, trade_date)
        );
    """,
    # Daily valuation / liquidity indicators.
    "daily_basic": """
        CREATE TABLE IF NOT EXISTS daily_basic (
            ts_code TEXT,
            trade_date TEXT,
            close REAL,
            turnover_rate REAL,
            turnover_rate_f REAL,
            volume_ratio REAL,
            pe REAL,
            pe_ttm REAL,
            pb REAL,
            ps REAL,
            ps_ttm REAL,
            dv_ratio REAL,
            dv_ttm REAL,
            total_share REAL,
            float_share REAL,
            free_share REAL,
            total_mv REAL,
            circ_mv REAL,
            PRIMARY KEY (ts_code, trade_date)
        );
    """,
    # Price-adjustment factors used to compute adjusted prices.
    "adj_factor": """
        CREATE TABLE IF NOT EXISTS adj_factor (
            ts_code TEXT,
            trade_date TEXT,
            adj_factor REAL,
            PRIMARY KEY (ts_code, trade_date)
        );
    """,
    # Suspension / resumption records keyed by (ts_code, suspend_date).
    "suspend": """
        CREATE TABLE IF NOT EXISTS suspend (
            ts_code TEXT,
            suspend_date TEXT,
            resume_date TEXT,
            suspend_type TEXT,
            ann_date TEXT,
            suspend_timing TEXT,
            resume_timing TEXT,
            reason TEXT,
            PRIMARY KEY (ts_code, suspend_date)
        );
    """,
    # Exchange trading calendar; is_open distinguishes trading days.
    "trade_calendar": """
        CREATE TABLE IF NOT EXISTS trade_calendar (
            exchange TEXT,
            cal_date TEXT,
            is_open INTEGER,
            pretrade_date TEXT,
            PRIMARY KEY (exchange, cal_date)
        );
    """,
    # Daily up/down price limits per security.
    "stk_limit": """
        CREATE TABLE IF NOT EXISTS stk_limit (
            ts_code TEXT,
            trade_date TEXT,
            up_limit REAL,
            down_limit REAL,
            PRIMARY KEY (ts_code, trade_date)
        );
    """,
}
|
||
|
||
# Column order used when inserting into each table. Must stay in sync with
# _TABLE_SCHEMAS above: save_records() builds named placeholders from these
# lists, and _df_to_records() reindexes DataFrames to exactly these columns.
_TABLE_COLUMNS: Dict[str, List[str]] = {
    "stock_basic": [
        "ts_code",
        "symbol",
        "name",
        "area",
        "industry",
        "market",
        "exchange",
        "list_status",
        "list_date",
        "delist_date",
    ],
    "daily": [
        "ts_code",
        "trade_date",
        "open",
        "high",
        "low",
        "close",
        "pre_close",
        "change",
        "pct_chg",
        "vol",
        "amount",
    ],
    "daily_basic": [
        "ts_code",
        "trade_date",
        "close",
        "turnover_rate",
        "turnover_rate_f",
        "volume_ratio",
        "pe",
        "pe_ttm",
        "pb",
        "ps",
        "ps_ttm",
        "dv_ratio",
        "dv_ttm",
        "total_share",
        "float_share",
        "free_share",
        "total_mv",
        "circ_mv",
    ],
    "adj_factor": [
        "ts_code",
        "trade_date",
        "adj_factor",
    ],
    "suspend": [
        "ts_code",
        "suspend_date",
        "resume_date",
        "suspend_type",
        "ann_date",
        "suspend_timing",
        "resume_timing",
        "reason",
    ],
    "trade_calendar": [
        "exchange",
        "cal_date",
        "is_open",
        "pretrade_date",
    ],
    "stk_limit": [
        "ts_code",
        "trade_date",
        "up_limit",
        "down_limit",
    ],
}
|
||
|
||
|
||
def _ensure_client():
    """Return a lazily created, cached TuShare Pro API client.

    Raises:
        RuntimeError: when the tushare package is missing or no token is
            configured (neither config file nor TUSHARE_TOKEN env var).
    """
    if ts is None:
        raise RuntimeError("未安装 tushare,请先在环境中安装 tushare 包")
    token = get_config().tushare_token or os.getenv("TUSHARE_TOKEN")
    if not token:
        raise RuntimeError("未配置 TuShare Token,请在配置文件或环境变量 TUSHARE_TOKEN 中设置")
    # The client is memoized on the function object itself so repeated calls
    # reuse one pro_api instance.
    client = getattr(_ensure_client, "_client", None)
    if client is None:
        ts.set_token(token)
        client = ts.pro_api(token)
        _ensure_client._client = client  # type: ignore[attr-defined]
        LOGGER.info("完成 TuShare 客户端初始化")
    return client
|
||
|
||
|
||
def _format_date(value: date) -> str:
|
||
return value.strftime("%Y%m%d")
|
||
|
||
|
||
def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]:
|
||
if df is None or df.empty:
|
||
return []
|
||
reindexed = df.reindex(columns=allowed_cols)
|
||
return reindexed.where(pd.notnull(reindexed), None).to_dict("records")
|
||
|
||
|
||
def _range_stats(table: str, date_col: str, start_str: str, end_str: str) -> Dict[str, Optional[str]]:
    """Return coverage statistics for *table* within [start_str, end_str].

    Args:
        table: Table name. Must be an internal constant (a _TABLE_SCHEMAS key),
            never user input, because it is interpolated into the SQL text.
        date_col: Date column name; same trust requirement as *table*.
        start_str: Inclusive lower bound, YYYYMMDD (string compare is correct
            for this fixed-width format).
        end_str: Inclusive upper bound, YYYYMMDD.

    Returns:
        Dict with "min"/"max" (YYYYMMDD strings or None when no rows match)
        and "distinct" (count of distinct dates, 0 when empty).
    """
    # Identifiers are f-string interpolated deliberately: SQL placeholders
    # cannot bind table/column names. Only bound values use placeholders.
    sql = (
        f"SELECT MIN({date_col}) AS min_d, MAX({date_col}) AS max_d, "
        f"COUNT(DISTINCT {date_col}) AS distinct_days FROM {table} "
        f"WHERE {date_col} BETWEEN ? AND ?"
    )
    with db_session(read_only=True) as conn:
        row = conn.execute(sql, (start_str, end_str)).fetchone()
    # Fix: guard the missing-row case for ALL fields. The original only
    # guarded "distinct" and would crash on row["min_d"] when row was None.
    if row is None:
        return {"min": None, "max": None, "distinct": 0}
    return {
        "min": row["min_d"],
        "max": row["max_d"],
        "distinct": row["distinct_days"],
    }
|
||
|
||
|
||
def _range_needs_refresh(
    table: str,
    date_col: str,
    start_str: str,
    end_str: str,
    expected_days: int = 0,
) -> bool:
    """Return True when *table* lacks full coverage of [start_str, end_str].

    Coverage fails when the window is empty, when stored min/max dates do not
    reach both requested bounds, or when fewer than *expected_days* distinct
    dates exist (day-count check is skipped when expected_days is 0).
    """
    stats = _range_stats(table, date_col, start_str, end_str)
    lo, hi = stats["min"], stats["max"]
    # No rows at all inside the requested window.
    if lo is None or hi is None:
        return True
    # Stored range does not touch both requested bounds.
    # NOTE(review): when start/end fall on non-trading days, stored min/max can
    # never equal them, forcing a refresh on every call — confirm intended.
    if lo > start_str or hi < end_str:
        return True
    # Not enough distinct trading dates versus the calendar's expectation.
    return bool(expected_days) and (stats["distinct"] or 0) < expected_days
|
||
|
||
|
||
def _calendar_needs_refresh(exchange: str, start_str: str, end_str: str) -> bool:
    """Return True when trade_calendar misses part of [start_str, end_str]."""
    sql = """
        SELECT MIN(cal_date) AS min_d, MAX(cal_date) AS max_d, COUNT(*) AS cnt
        FROM trade_calendar
        WHERE exchange = ? AND cal_date BETWEEN ? AND ?
    """
    with db_session(read_only=True) as conn:
        row = conn.execute(sql, (exchange, start_str, end_str)).fetchone()
    if row is None or row["min_d"] is None:
        return True
    # The calendar is legitimately non-contiguous (holidays), so only the
    # bounds are compared — no row-count expectation here.
    return row["min_d"] > start_str or row["max_d"] < end_str
|
||
|
||
|
||
def _expected_trading_days(start_str: str, end_str: str, exchange: str = "SSE") -> int:
    """Count open trading days recorded in trade_calendar for the window.

    Returns 0 when the calendar has no rows for the range (e.g. not yet
    fetched), which callers treat as "no day-count expectation".
    """
    sql = """
        SELECT COUNT(*) AS cnt
        FROM trade_calendar
        WHERE exchange = ? AND cal_date BETWEEN ? AND ? AND is_open = 1
    """
    with db_session(read_only=True) as conn:
        row = conn.execute(sql, (exchange, start_str, end_str)).fetchone()
    if row is None or row["cnt"] is None:
        return 0
    return int(row["cnt"])
|
||
|
||
|
||
def fetch_stock_basic(exchange: Optional[str] = None, list_status: str = "L") -> Iterable[Dict]:
    """Fetch the stock master list from TuShare and convert it to row dicts."""
    api = _ensure_client()
    LOGGER.info("拉取股票基础信息(交易所:%s,状态:%s)", exchange or "全部", list_status)
    # Request exactly the columns we persist; the join reproduces the
    # canonical fields string for the stock_basic endpoint.
    wanted = _TABLE_COLUMNS["stock_basic"]
    frame = api.stock_basic(
        exchange=exchange,
        list_status=list_status,
        fields=",".join(wanted),
    )
    return _df_to_records(frame, wanted)
|
||
|
||
|
||
def fetch_daily_bars(job: FetchJob) -> Iterable[Dict]:
    """Fetch daily OHLCV bars for *job* and return them as row dicts.

    Fetches per code when job.ts_codes is non-empty, otherwise one whole-market
    request for the range.

    Raises:
        ValueError: if job.granularity is anything other than "daily".
    """
    # Fix: validate the granularity BEFORE initializing the network client or
    # formatting dates, so an unsupported job fails fast with no side effects.
    if job.granularity != "daily":
        raise ValueError(f"暂不支持的粒度:{job.granularity}")

    client = _ensure_client()
    start_date = _format_date(job.start)
    end_date = _format_date(job.end)

    frames: List[pd.DataFrame] = []
    if job.ts_codes:
        # One request per code keeps each payload small and logs progress.
        for code in job.ts_codes:
            LOGGER.info("拉取 %s 的日线行情(%s-%s)", code, start_date, end_date)
            frames.append(client.daily(ts_code=code, start_date=start_date, end_date=end_date))
    else:
        LOGGER.info("按全市场拉取日线行情(%s-%s)", start_date, end_date)
        frames.append(client.daily(start_date=start_date, end_date=end_date))

    # Both branches append at least one frame, so frames is never empty here
    # (the original's "if not frames" guard was unreachable and is removed).
    df = pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0]
    return _df_to_records(df, _TABLE_COLUMNS["daily"])
|
||
|
||
|
||
def fetch_daily_basic(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
    """Fetch daily fundamental indicators (PE, PB, turnover, ...) as row dicts."""
    api = _ensure_client()
    window = (_format_date(start), _format_date(end))
    LOGGER.info("拉取日线基础指标(%s-%s,股票:%s)", window[0], window[1], ts_code or "全部")
    frame = api.daily_basic(ts_code=ts_code, start_date=window[0], end_date=window[1])
    return _df_to_records(frame, _TABLE_COLUMNS["daily_basic"])
|
||
|
||
|
||
def fetch_adj_factor(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
    """Fetch price-adjustment factors for the range as row dicts."""
    api = _ensure_client()
    window = (_format_date(start), _format_date(end))
    LOGGER.info("拉取复权因子(%s-%s,股票:%s)", window[0], window[1], ts_code or "全部")
    frame = api.adj_factor(ts_code=ts_code, start_date=window[0], end_date=window[1])
    return _df_to_records(frame, _TABLE_COLUMNS["adj_factor"])
|
||
|
||
|
||
def fetch_suspensions(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
    """Fetch suspension/resumption records (suspend_d endpoint) as row dicts."""
    api = _ensure_client()
    window = (_format_date(start), _format_date(end))
    LOGGER.info("拉取停复牌信息(%s-%s)", window[0], window[1])
    frame = api.suspend_d(ts_code=ts_code, start_date=window[0], end_date=window[1])
    return _df_to_records(frame, _TABLE_COLUMNS["suspend"])
|
||
|
||
|
||
def fetch_trade_calendar(start: date, end: date, exchange: str = "SSE") -> Iterable[Dict]:
    """Fetch the exchange trading calendar for the range as row dicts."""
    api = _ensure_client()
    window = (_format_date(start), _format_date(end))
    LOGGER.info("拉取交易日历(交易所:%s,区间:%s-%s)", exchange, window[0], window[1])
    frame = api.trade_cal(exchange=exchange, start_date=window[0], end_date=window[1])
    return _df_to_records(frame, _TABLE_COLUMNS["trade_calendar"])
|
||
|
||
|
||
def fetch_stk_limit(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
    """Fetch daily up/down price limits for the range as row dicts."""
    api = _ensure_client()
    window = (_format_date(start), _format_date(end))
    LOGGER.info("拉取涨跌停价格(%s-%s)", window[0], window[1])
    frame = api.stk_limit(ts_code=ts_code, start_date=window[0], end_date=window[1])
    return _df_to_records(frame, _TABLE_COLUMNS["stk_limit"])
|
||
|
||
|
||
def save_records(table: str, rows: Iterable[Dict]) -> None:
    """Upsert *rows* into *table*, creating the table first if needed.

    Raises:
        ValueError: when *table* is not declared in _TABLE_SCHEMAS/_TABLE_COLUMNS.
    """
    items = list(rows)
    if not items:
        LOGGER.info("表 %s 没有新增记录,跳过写入", table)
        return

    schema = _TABLE_SCHEMAS.get(table)
    columns = _TABLE_COLUMNS.get(table)
    if not schema or not columns:
        raise ValueError(f"不支持写入的表:{table}")

    # Named placeholders so each row dict binds by column name; table/column
    # identifiers come from the internal whitelist above, never user input.
    col_clause = ",".join(columns)
    placeholders = ",".join(f":{name}" for name in columns)
    upsert_sql = f"INSERT OR REPLACE INTO {table} ({col_clause}) VALUES ({placeholders})"

    LOGGER.info("表 %s 写入 %d 条记录", table, len(items))
    with db_session() as conn:
        # NOTE(review): sqlite3 executescript implicitly COMMITs any pending
        # transaction before running — confirm db_session tolerates that.
        conn.executescript(schema)
        conn.executemany(upsert_sql, items)
|
||
|
||
|
||
def ensure_stock_basic(list_status: str = "L") -> None:
    """Populate stock_basic from TuShare unless SSE/SZSE rows already exist."""
    exchanges = ("SSE", "SZSE")
    with db_session(read_only=True) as conn:
        row = conn.execute(
            "SELECT COUNT(*) AS cnt FROM stock_basic WHERE exchange IN (?, ?) AND list_status = ?",
            (*exchanges, list_status),
        ).fetchone()
    existing = row["cnt"] if row else 0
    if existing:
        LOGGER.info("股票基础信息已存在 %d 条记录,跳过拉取", existing)
        return

    # Nothing cached yet: fetch and persist each exchange separately.
    for exch in exchanges:
        save_records("stock_basic", fetch_stock_basic(exchange=exch, list_status=list_status))
|
||
|
||
|
||
def ensure_trade_calendar(start: date, end: date, exchanges: Sequence[str] = ("SSE", "SZSE")) -> None:
    """Refresh trade_calendar for each exchange whose coverage is incomplete."""
    start_str = _format_date(start)
    end_str = _format_date(end)
    for exch in exchanges:
        if not _calendar_needs_refresh(exch, start_str, end_str):
            continue
        save_records("trade_calendar", fetch_trade_calendar(start, end, exchange=exch))
|
||
|
||
|
||
def ensure_data_coverage(
    start: date,
    end: date,
    ts_codes: Optional[Sequence[str]] = None,
    include_limits: bool = True,
    force: bool = False,
) -> None:
    """Make sure every dataset covers [start, end], fetching whatever is missing.

    Args:
        start: Inclusive start of the date range to guarantee coverage for.
        end: Inclusive end of the range.
        ts_codes: Optional subset of stock codes; None/empty means whole market.
        include_limits: Also refresh the stk_limit (price-limit) table.
        force: Re-fetch every dataset even when the coverage checks pass.
    """
    # Create the schema first so the read-only coverage checks below cannot
    # fail on a missing table/database.
    initialize_database()
    start_str = _format_date(start)
    end_str = _format_date(end)

    # Reference data first: the stock master list and the trading calendar,
    # which the expected-day count below depends on.
    ensure_stock_basic()
    ensure_trade_calendar(start, end)

    # De-duplicate the requested codes while preserving caller order.
    codes = tuple(dict.fromkeys(ts_codes)) if ts_codes else tuple()
    expected_days = _expected_trading_days(start_str, end_str)
    job = FetchJob("daily_autofill", start=start, end=end, ts_codes=codes)

    if force or _range_needs_refresh("daily", "trade_date", start_str, end_str, expected_days):
        save_records("daily", fetch_daily_bars(job))

    def _save_with_codes(table: str, fetch_fn) -> None:
        # Fetch per code when a subset was requested, otherwise whole market.
        if codes:
            for code in codes:
                save_records(table, fetch_fn(start, end, ts_code=code))
        else:
            save_records(table, fetch_fn(start, end))

    if force or _range_needs_refresh("daily_basic", "trade_date", start_str, end_str, expected_days):
        _save_with_codes("daily_basic", fetch_daily_basic)

    if force or _range_needs_refresh("adj_factor", "trade_date", start_str, end_str, expected_days):
        _save_with_codes("adj_factor", fetch_adj_factor)

    if include_limits and (force or _range_needs_refresh("stk_limit", "trade_date", start_str, end_str, expected_days)):
        _save_with_codes("stk_limit", fetch_stk_limit)

    # Suspensions are sparse, so no expected-day threshold is applied here.
    if force or _range_needs_refresh("suspend", "suspend_date", start_str, end_str):
        _save_with_codes("suspend", fetch_suspensions)
|
||
|
||
|
||
def collect_data_coverage(start: date, end: date) -> Dict[str, Dict[str, object]]:
    """Summarize, without fetching anything, how well the DB covers [start, end].

    Returns:
        A mapping with a "period" entry (bounds + expected trading days), one
        entry per time-series table (min/max date, distinct day count, and a
        "meets_expectation" verdict), and a "stock_basic" entry with row counts.
    """
    start_str = _format_date(start)
    end_str = _format_date(end)
    expected_days = _expected_trading_days(start_str, end_str)

    coverage: Dict[str, Dict[str, object]] = {
        "period": {
            "start": start_str,
            "end": end_str,
            "expected_trading_days": expected_days,
        }
    }

    def add_table(name: str, date_col: str, require_days: bool = True) -> None:
        # Record range stats plus a boolean verdict for one table.
        stats = _range_stats(name, date_col, start_str, end_str)
        coverage[name] = {
            "min": stats["min"],
            "max": stats["max"],
            "distinct_days": stats["distinct"],
            # Bounds must enclose the request; the day-count check is optional
            # because some tables (e.g. suspensions) are legitimately sparse.
            "meets_expectation": (
                stats["min"] is not None
                and stats["max"] is not None
                and stats["min"] <= start_str
                and stats["max"] >= end_str
                and ((not require_days) or (stats["distinct"] or 0) >= expected_days)
            ),
        }

    add_table("daily", "trade_date")
    add_table("daily_basic", "trade_date")
    add_table("adj_factor", "trade_date")
    add_table("stk_limit", "trade_date")
    add_table("suspend", "suspend_date", require_days=False)

    # Row counts of the stock master: total plus listed counts per exchange.
    with db_session(read_only=True) as conn:
        stock_tot = conn.execute("SELECT COUNT(*) AS cnt FROM stock_basic").fetchone()
        stock_sse = conn.execute(
            "SELECT COUNT(*) AS cnt FROM stock_basic WHERE exchange = 'SSE' AND list_status = 'L'"
        ).fetchone()
        stock_szse = conn.execute(
            "SELECT COUNT(*) AS cnt FROM stock_basic WHERE exchange = 'SZSE' AND list_status = 'L'"
        ).fetchone()
    coverage["stock_basic"] = {
        "total": stock_tot["cnt"] if stock_tot else 0,
        "sse_listed": stock_sse["cnt"] if stock_sse else 0,
        "szse_listed": stock_szse["cnt"] if stock_szse else 0,
    }

    return coverage
|
||
|
||
|
||
def run_ingestion(job: FetchJob, include_limits: bool = True) -> None:
    """Run a forced full refresh of every dataset for the job's date range."""
    LOGGER.info("启动 TuShare 拉取任务:%s", job.name)
    ensure_data_coverage(
        job.start,
        job.end,
        ts_codes=job.ts_codes,
        include_limits=include_limits,
        force=True,
    )
    LOGGER.info("任务 %s 完成", job.name)
|