# llm-quant/app/ingest/tushare.py
"""TuShare 数据拉取管线实现。"""
from __future__ import annotations
import logging
import os
from dataclasses import dataclass
from datetime import date
from typing import Dict, Iterable, List, Optional, Sequence
import pandas as pd
try:
import tushare as ts
except ImportError as exc: # pragma: no cover - dependency error surfaced at runtime
ts = None # type: ignore[assignment]
from app.utils.config import get_config
from app.utils.db import db_session
LOGGER = logging.getLogger(__name__)


@dataclass
class FetchJob:
    """Parameters for a single ingestion run."""

    name: str
    start: date
    end: date
    granularity: str = "daily"
    ts_codes: Optional[Sequence[str]] = None


_TABLE_SCHEMAS: Dict[str, str] = {
    "daily": """
        CREATE TABLE IF NOT EXISTS daily (
            ts_code TEXT,
            trade_date TEXT,
            open REAL,
            high REAL,
            low REAL,
            close REAL,
            pre_close REAL,
            change REAL,
            pct_chg REAL,
            vol REAL,
            amount REAL,
            PRIMARY KEY (ts_code, trade_date)
        );
    """,
    "suspend": """
        CREATE TABLE IF NOT EXISTS suspend (
            ts_code TEXT,
            suspend_date TEXT,
            resume_date TEXT,
            suspend_type TEXT,
            ann_date TEXT,
            suspend_timing TEXT,
            resume_timing TEXT,
            reason TEXT,
            PRIMARY KEY (ts_code, suspend_date)
        );
    """,
    "trade_calendar": """
        CREATE TABLE IF NOT EXISTS trade_calendar (
            exchange TEXT,
            cal_date TEXT PRIMARY KEY,
            is_open INTEGER,
            pretrade_date TEXT
        );
    """,
    "stk_limit": """
        CREATE TABLE IF NOT EXISTS stk_limit (
            ts_code TEXT,
            trade_date TEXT,
            up_limit REAL,
            down_limit REAL,
            PRIMARY KEY (ts_code, trade_date)
        );
    """,
}


_TABLE_COLUMNS: Dict[str, List[str]] = {
    "daily": [
        "ts_code",
        "trade_date",
        "open",
        "high",
        "low",
        "close",
        "pre_close",
        "change",
        "pct_chg",
        "vol",
        "amount",
    ],
    "suspend": [
        "ts_code",
        "suspend_date",
        "resume_date",
        "suspend_type",
        "ann_date",
        "suspend_timing",
        "resume_timing",
        "reason",
    ],
    "trade_calendar": [
        "exchange",
        "cal_date",
        "is_open",
        "pretrade_date",
    ],
    "stk_limit": [
        "ts_code",
        "trade_date",
        "up_limit",
        "down_limit",
    ],
}


def _ensure_client():
    """Return a cached TuShare pro_api client, initialising it on first use."""
    if ts is None:
        raise RuntimeError("tushare is not installed; install the tushare package first")
    token = get_config().tushare_token or os.getenv("TUSHARE_TOKEN")
    if not token:
        raise RuntimeError(
            "TuShare token is not configured; set it in the config file or the TUSHARE_TOKEN environment variable"
        )
    if not hasattr(_ensure_client, "_client") or _ensure_client._client is None:  # type: ignore[attr-defined]
        ts.set_token(token)
        _ensure_client._client = ts.pro_api(token)  # type: ignore[attr-defined]
        LOGGER.info("TuShare client initialised")
    return _ensure_client._client  # type: ignore[attr-defined]


def _format_date(value: date) -> str:
    return value.strftime("%Y%m%d")


def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]:
    if df is None or df.empty:
        return []
    # Backfill any missing columns so every insert has the full set of bind parameters.
    reindexed = df.reindex(columns=allowed_cols)
    return reindexed.where(pd.notnull(reindexed), None).to_dict("records")
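
# Illustration of the column backfill (assumption: plain pandas behaviour, no TuShare call involved):
#   _df_to_records(pd.DataFrame({"ts_code": ["000001.SZ"], "close": [10.5]}),
#                  ["ts_code", "trade_date", "close"])
#   -> [{"ts_code": "000001.SZ", "trade_date": None, "close": 10.5}]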


def fetch_daily_bars(job: FetchJob) -> Iterable[Dict]:
    """Fetch daily OHLCV bars for the job's date range."""
    client = _ensure_client()
    start_date = _format_date(job.start)
    end_date = _format_date(job.end)
    if job.granularity != "daily":
        raise ValueError(f"Unsupported granularity: {job.granularity}")
    frames: List[pd.DataFrame] = []
    if job.ts_codes:
        for code in job.ts_codes:
            LOGGER.info("Fetching daily bars for %s (%s-%s)", code, start_date, end_date)
            frames.append(client.daily(ts_code=code, start_date=start_date, end_date=end_date))
    else:
        LOGGER.info("Fetching market-wide daily bars (%s-%s)", start_date, end_date)
        frames.append(client.daily(start_date=start_date, end_date=end_date))
    if not frames:
        return []
    df = pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0]
    return _df_to_records(df, _TABLE_COLUMNS["daily"])


def fetch_suspensions(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
    """Fetch suspension/resumption records."""
    client = _ensure_client()
    start_date = _format_date(start)
    end_date = _format_date(end)
    LOGGER.info("Fetching suspension records (%s-%s)", start_date, end_date)
    df = client.suspend_d(ts_code=ts_code, start_date=start_date, end_date=end_date)
    return _df_to_records(df, _TABLE_COLUMNS["suspend"])


def fetch_trade_calendar(start: date, end: date, exchange: str = "SSE") -> Iterable[Dict]:
    """Fetch the trading calendar for the given exchange."""
    client = _ensure_client()
    start_date = _format_date(start)
    end_date = _format_date(end)
    LOGGER.info("Fetching trade calendar (exchange: %s, range: %s-%s)", exchange, start_date, end_date)
    df = client.trade_cal(exchange=exchange, start_date=start_date, end_date=end_date)
    return _df_to_records(df, _TABLE_COLUMNS["trade_calendar"])


def fetch_stk_limit(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
    """Fetch daily limit-up/limit-down prices."""
    client = _ensure_client()
    start_date = _format_date(start)
    end_date = _format_date(end)
    LOGGER.info("Fetching limit-up/limit-down prices (%s-%s)", start_date, end_date)
    df = client.stk_limit(ts_code=ts_code, start_date=start_date, end_date=end_date)
    return _df_to_records(df, _TABLE_COLUMNS["stk_limit"])


def save_records(table: str, rows: Iterable[Dict]) -> None:
    """Persist fetched records into SQLite."""
    items = list(rows)
    if not items:
        LOGGER.info("%s has no new records, skipping write", table)
        return
    schema = _TABLE_SCHEMAS.get(table)
    columns = _TABLE_COLUMNS.get(table)
    if not schema or not columns:
        raise ValueError(f"Unsupported table for writes: {table}")
    placeholders = ",".join(f":{col}" for col in columns)
    col_clause = ",".join(columns)
    LOGGER.info("%s: writing %d records", table, len(items))
    with db_session() as conn:
        conn.executescript(schema)
        conn.executemany(
            f"INSERT OR REPLACE INTO {table} ({col_clause}) VALUES ({placeholders})",
            items,
        )


def run_ingestion(job: FetchJob, include_limits: bool = True) -> None:
    """Fetch TuShare data according to the job configuration."""
    LOGGER.info("Starting TuShare ingestion job: %s", job.name)
    daily_rows = fetch_daily_bars(job)
    save_records("daily", daily_rows)
    suspend_rows = fetch_suspensions(job.start, job.end)
    save_records("suspend", suspend_rows)
    calendar_rows = fetch_trade_calendar(job.start, job.end)
    save_records("trade_calendar", calendar_rows)
    if include_limits:
        limit_rows = fetch_stk_limit(job.start, job.end)
        save_records("stk_limit", limit_rows)
    LOGGER.info("Job %s finished", job.name)
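

# Minimal usage sketch, not part of the original module's public surface.
# Assumptions: app.utils.config / app.utils.db are already configured, a valid
# TUSHARE_TOKEN is available, and the CLI flags below are purely illustrative.
if __name__ == "__main__":
    import argparse
    from datetime import datetime

    parser = argparse.ArgumentParser(description="Run a TuShare ingestion job")
    parser.add_argument("--start", required=True, help="Start date, YYYY-MM-DD")
    parser.add_argument("--end", required=True, help="End date, YYYY-MM-DD")
    parser.add_argument("--codes", nargs="*", default=None, help="Optional ts_code list, e.g. 000001.SZ")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    job = FetchJob(
        name="manual",
        start=datetime.strptime(args.start, "%Y-%m-%d").date(),
        end=datetime.strptime(args.end, "%Y-%m-%d").date(),
        ts_codes=args.codes or None,
    )
    run_ingestion(job)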