commit 36322f66db
parent 2dc2753827

    update
@@ -23,6 +23,7 @@ The Streamlit `自检测试` (self-test) tab provides:
 - a small-scope TuShare fetch test;
 - a one-click boot check (can auto-fill data and display a coverage summary);
 - stock quote visualization (auto-loads recent prices and volume and shows core indicators).
+- The boot check now reports progress and detailed logs, making TuShare fetch problems easier to diagnose.
 
 The `回测与复盘` (backtest & review) tab provides a quick backtest form; adjust the time range, stock pool, and parameters and view backtest output immediately.
 
@@ -4,7 +4,7 @@ from __future__ import annotations
 import logging
 from dataclasses import dataclass
 from datetime import date, timedelta
-from typing import Dict
+from typing import Callable, Dict
 
 from app.data.schema import initialize_database
 from app.ingest.tushare import collect_data_coverage, ensure_data_coverage
@@ -36,7 +36,11 @@ def _default_window(days: int = 365) -> tuple[date, date]:
     return start, end
 
 
-def run_boot_check(days: int = 365, auto_fetch: bool = True) -> CoverageReport:
+def run_boot_check(
+    days: int = 365,
+    auto_fetch: bool = True,
+    progress_hook: Callable[[str, float], None] | None = None,
+) -> CoverageReport:
     """Run the boot self-check, auto-filling data when necessary."""
 
     initialize_database()
@@ -44,7 +48,7 @@ def run_boot_check(days: int = 365, auto_fetch: bool = True) -> CoverageReport:
     LOGGER.info("Boot check coverage window: %s to %s", start, end)
 
     if auto_fetch:
-        ensure_data_coverage(start, end)
+        ensure_data_coverage(start, end, progress_hook=progress_hook)
 
     coverage = collect_data_coverage(start, end)
 
@@ -63,5 +67,7 @@ def run_boot_check(days: int = 365, auto_fetch: bool = True) -> CoverageReport:
         report.tables["daily"].get("distinct_days"),
         report.expected_trading_days,
     )
+    if progress_hook:
+        progress_hook("Data coverage check complete", 1.0)
 
     return report
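Taken together, the boot-check changes let any caller observe progress, not just the Streamlit page. A minimal sketch of driving the new hook from a plain script; the `app.boot` import path is an assumption (this diff never names the boot-check module), while the `run_boot_check` signature matches the hunks above:

    from app.boot import run_boot_check  # assumed module path, not shown in this diff

    def print_hook(message: str, value: float) -> None:
        # Same contract as the Streamlit hook further down:
        # a status label plus a completion fraction in [0, 1].
        print(f"[{value:6.1%}] {message}")

    report = run_boot_check(days=180, progress_hook=print_hook)
    print(report.to_dict())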
@@ -5,7 +5,7 @@ import logging
 import os
 from dataclasses import dataclass
 from datetime import date
-from typing import Dict, Iterable, List, Optional, Sequence
+from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple
 
 import pandas as pd
 
@@ -16,10 +16,57 @@ except ImportError:  # pragma: no cover - runtime hint
 
 from app.utils.config import get_config
 from app.utils.db import db_session
 
+
+def _existing_date_range(table: str, date_col: str, ts_code: str | None = None) -> Tuple[str | None, str | None]:
+    query = f"SELECT MIN({date_col}) AS min_d, MAX({date_col}) AS max_d FROM {table}"
+    params: Tuple = ()
+    if ts_code:
+        query += " WHERE ts_code = ?"
+        params = (ts_code,)
+    with db_session(read_only=True) as conn:
+        row = conn.execute(query, params).fetchone()
+    if row is None:
+        return None, None
+    return row["min_d"], row["max_d"]
+
+
 from app.data.schema import initialize_database
 
 LOGGER = logging.getLogger(__name__)
 
+API_DEFAULT_LIMIT = 5000
+
+
+def _fetch_paginated(endpoint: str, params: Dict[str, object], limit: int | None = None) -> pd.DataFrame:
+    client = _ensure_client()
+    limit = limit or API_DEFAULT_LIMIT
+    frames: List[pd.DataFrame] = []
+    offset = 0
+    clean_params = {k: v for k, v in params.items() if v is not None}
+    LOGGER.info("Calling TuShare endpoint %s, params=%s, limit=%s", endpoint, clean_params, limit)
+    while True:
+        call = getattr(client, endpoint)
+        try:
+            df = call(limit=limit, offset=offset, **clean_params)
+        except Exception:  # noqa: BLE001
+            LOGGER.exception("TuShare call raised: endpoint=%s offset=%s params=%s", endpoint, offset, clean_params)
+            raise
+        if df is None or df.empty:
+            LOGGER.info("TuShare returned no data: endpoint=%s offset=%s", endpoint, offset)
+            break
+        LOGGER.info("TuShare returned %s rows: endpoint=%s offset=%s", len(df), endpoint, offset)
+        frames.append(df)
+        if len(df) < limit:
+            break
+        offset += limit
+    if not frames:
+        return pd.DataFrame()
+    merged = pd.concat(frames, ignore_index=True)
+    LOGGER.info("TuShare call finished: endpoint=%s total rows=%s", endpoint, len(merged))
+    return merged
+
+
 @dataclass
 class FetchJob:
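`_fetch_paginated` relies on the TuShare-style offset/limit contract: keep requesting pages until one comes back shorter than `limit`. A self-contained sketch of that loop against a stand-in client (`FakeClient` and its row count are illustrative, not part of this commit):

    import pandas as pd

    class FakeClient:
        def daily(self, limit: int, offset: int, **params) -> pd.DataFrame:
            total = 12_345  # pretend the endpoint holds this many rows
            n = max(0, min(limit, total - offset))
            return pd.DataFrame({"row_id": range(offset, offset + n)})

    client, frames, offset, limit = FakeClient(), [], 0, 5000
    while True:
        page = client.daily(limit=limit, offset=offset)
        if page.empty:
            break
        frames.append(page)
        if len(page) < limit:  # a short page is the last page
            break
        offset += limit

    merged = pd.concat(frames, ignore_index=True)
    assert len(merged) == 12_345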
@@ -218,6 +265,27 @@ def _format_date(value: date) -> str:
     return value.strftime("%Y%m%d")
 
 
+def _load_trade_dates(start: date, end: date, exchange: str = "SSE") -> List[str]:
+    start_str = _format_date(start)
+    end_str = _format_date(end)
+    query = (
+        "SELECT cal_date FROM trade_calendar "
+        "WHERE exchange = ? AND cal_date BETWEEN ? AND ? AND is_open = 1 ORDER BY cal_date"
+    )
+    with db_session(read_only=True) as conn:
+        rows = conn.execute(query, (exchange, start_str, end_str)).fetchall()
+    return [row["cal_date"] for row in rows]
+
+
+def _should_skip_range(table: str, date_col: str, start: date, end: date, ts_code: str | None = None) -> bool:
+    min_d, max_d = _existing_date_range(table, date_col, ts_code)
+    if min_d is None or max_d is None:
+        return False
+    start_str = _format_date(start)
+    end_str = _format_date(end)
+    return min_d <= start_str and max_d >= end_str
+
+
 def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]:
     if df is None or df.empty:
         return []
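`_should_skip_range` compares the stored MIN/MAX dates against the requested window as plain strings, which is sound because zero-padded `YYYYMMDD` strings sort exactly like the dates they encode. A self-contained illustration (the sample ranges are made up):

    requested = ("20240101", "20240201")
    for min_d, max_d in [("20231229", "20240301"), ("20240110", "20240301")]:
        covered = min_d <= requested[0] and max_d >= requested[1]
        print(min_d, max_d, "-> skip fetch" if covered else "-> needs fetch")

Note the check only inspects the endpoints, so a gap in the middle of an otherwise covered range is not detected here; presumably that is why the separate `_range_needs_refresh` check also takes an expected trading-day count.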
@@ -301,17 +379,32 @@ def fetch_daily_bars(job: FetchJob) -> Iterable[Dict]:
     if job.granularity != "daily":
         raise ValueError(f"Unsupported granularity: {job.granularity}")
 
+    params = {
+        "start_date": start_date,
+        "end_date": end_date,
+    }
+
     if job.ts_codes:
         for code in job.ts_codes:
             LOGGER.info("Fetching daily bars for %s (%s-%s)", code, start_date, end_date)
-            frames.append(client.daily(ts_code=code, start_date=start_date, end_date=end_date))
+            df = _fetch_paginated("daily", {**params, "ts_code": code})
+            if not df.empty:
+                frames.append(df)
     else:
         LOGGER.info("Fetching daily bars market-wide (%s-%s)", start_date, end_date)
-        frames.append(client.daily(start_date=start_date, end_date=end_date))
+        trade_dates = _load_trade_dates(job.start, job.end)
+        if not trade_dates:
+            LOGGER.info("Local trade calendar missing; backfilling before fetching daily bars")
+            ensure_trade_calendar(job.start, job.end)
+            trade_dates = _load_trade_dates(job.start, job.end)
+        for trade_date in trade_dates:
+            LOGGER.debug("Fetching daily bars by trade date: %s", trade_date)
+            df = _fetch_paginated("daily", {"trade_date": trade_date})
+            if not df.empty:
+                frames.append(df)
 
     if not frames:
         return []
-    df = pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0]
+    df = pd.concat(frames, ignore_index=True)
     return _df_to_records(df, _TABLE_COLUMNS["daily"])
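With pagination in place, a per-stock backfill reduces to building a `FetchJob` and persisting the records. A sketch using names from this module; the stock code and dates are illustrative, and it assumes `FetchJob`'s remaining defaults (e.g. a "daily" granularity) and that `save_records` is exposed by `app.ingest.tushare`:

    from datetime import date

    from app.ingest.tushare import FetchJob, fetch_daily_bars, save_records

    job = FetchJob(
        "manual_backfill",        # free-form job name
        start=date(2024, 1, 1),
        end=date(2024, 3, 31),
        ts_codes=("600519.SH",),  # illustrative code
    )
    save_records("daily", fetch_daily_bars(job))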
@@ -320,7 +413,11 @@ def fetch_daily_basic(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
     start_date = _format_date(start)
     end_date = _format_date(end)
     LOGGER.info("Fetching daily basic indicators (%s-%s, stock: %s)", start_date, end_date, ts_code or "all")
-    df = client.daily_basic(ts_code=ts_code, start_date=start_date, end_date=end_date)
+    df = _fetch_paginated("daily_basic", {
+        "ts_code": ts_code,
+        "start_date": start_date,
+        "end_date": end_date,
+    })
     return _df_to_records(df, _TABLE_COLUMNS["daily_basic"])
@@ -329,7 +426,11 @@ def fetch_adj_factor(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
     start_date = _format_date(start)
     end_date = _format_date(end)
     LOGGER.info("Fetching adjustment factors (%s-%s, stock: %s)", start_date, end_date, ts_code or "all")
-    df = client.adj_factor(ts_code=ts_code, start_date=start_date, end_date=end_date)
+    df = _fetch_paginated("adj_factor", {
+        "ts_code": ts_code,
+        "start_date": start_date,
+        "end_date": end_date,
+    })
     return _df_to_records(df, _TABLE_COLUMNS["adj_factor"])
@@ -338,7 +439,11 @@ def fetch_suspensions(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
     start_date = _format_date(start)
     end_date = _format_date(end)
     LOGGER.info("Fetching suspension/resumption info (%s-%s)", start_date, end_date)
-    df = client.suspend_d(ts_code=ts_code, start_date=start_date, end_date=end_date)
+    df = _fetch_paginated("suspend_d", {
+        "ts_code": ts_code,
+        "start_date": start_date,
+        "end_date": end_date,
+    }, limit=2000)
     return _df_to_records(df, _TABLE_COLUMNS["suspend"])
@@ -356,7 +461,11 @@ def fetch_stk_limit(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
     start_date = _format_date(start)
     end_date = _format_date(end)
     LOGGER.info("Fetching limit-up/limit-down prices (%s-%s)", start_date, end_date)
-    df = client.stk_limit(ts_code=ts_code, start_date=start_date, end_date=end_date)
+    df = _fetch_paginated("stk_limit", {
+        "ts_code": ts_code,
+        "start_date": start_date,
+        "end_date": end_date,
+    })
    return _df_to_records(df, _TABLE_COLUMNS["stk_limit"])
@@ -412,41 +521,103 @@ def ensure_data_coverage(
     ts_codes: Optional[Sequence[str]] = None,
     include_limits: bool = True,
     force: bool = False,
+    progress_hook: Callable[[str, float], None] | None = None,
 ) -> None:
     initialize_database()
     start_str = _format_date(start)
     end_str = _format_date(end)
 
+    total_steps = 5 + (1 if include_limits else 0)
+    current_step = 0
+
+    def advance(message: str) -> None:
+        nonlocal current_step
+        current_step += 1
+        progress = min(current_step / total_steps, 1.0)
+        if progress_hook:
+            progress_hook(message, progress)
+        LOGGER.info(message)
+
+    advance("Preparing stock basics and trade calendar")
     ensure_stock_basic()
     ensure_trade_calendar(start, end)
 
     codes = tuple(dict.fromkeys(ts_codes)) if ts_codes else tuple()
     expected_days = _expected_trading_days(start_str, end_str)
-    job = FetchJob("daily_autofill", start=start, end=end, ts_codes=codes)
 
-    if force or _range_needs_refresh("daily", "trade_date", start_str, end_str, expected_days):
+    advance("Processing daily bar data")
+    if codes:
+        pending_codes: List[str] = []
+        for code in codes:
+            if not force and _should_skip_range("daily", "trade_date", start, end, code):
+                LOGGER.info("Daily bars for %s already cover %s-%s; skipping", code, start_str, end_str)
+                continue
+            pending_codes.append(code)
+        if pending_codes:
+            job = FetchJob("daily_autofill", start=start, end=end, ts_codes=tuple(pending_codes))
+            LOGGER.info("Fetching daily bars: %s-%s (%d stocks to backfill)", start_str, end_str, len(pending_codes))
+            save_records("daily", fetch_daily_bars(job))
+    else:
+        needs_daily = force or _range_needs_refresh("daily", "trade_date", start_str, end_str, expected_days)
+        if not needs_daily:
+            LOGGER.info("Daily data already covers %s-%s; skipping fetch", start_str, end_str)
+        else:
+            job = FetchJob("daily_autofill", start=start, end=end)
+            LOGGER.info("Fetching daily bars: %s-%s", start_str, end_str)
+            save_records("daily", fetch_daily_bars(job))
+
+    date_cols = {
+        "daily_basic": "trade_date",
+        "adj_factor": "trade_date",
+        "stk_limit": "trade_date",
+        "suspend": "suspend_date",
+    }
 
     def _save_with_codes(table: str, fetch_fn) -> None:
+        date_col = date_cols.get(table, "trade_date")
         if codes:
             for code in codes:
-                save_records(table, fetch_fn(start, end, ts_code=code))
+                if not force and _should_skip_range(table, date_col, start, end, code):
+                    LOGGER.info("Table %s stock %s already covered for %s-%s; skipping", table, code, start_str, end_str)
+                    continue
+                LOGGER.info("Fetching table %s (stock: %s) %s-%s", table, code, start_str, end_str)
+                try:
+                    rows = fetch_fn(start, end, ts_code=code)
+                except Exception:
+                    LOGGER.exception("TuShare fetch failed: table=%s code=%s", table, code)
+                    raise
+                save_records(table, rows)
         else:
-            save_records(table, fetch_fn(start, end))
+            needs_refresh = force
+            if not force:
+                expected = expected_days if table in {"daily_basic", "adj_factor", "stk_limit"} else 0
+                needs_refresh = _range_needs_refresh(table, date_col, start_str, end_str, expected)
+            if not needs_refresh:
+                LOGGER.info("Table %s already covered for %s-%s; skipping", table, start_str, end_str)
+                return
+            LOGGER.info("Fetching table %s (market-wide) %s-%s", table, start_str, end_str)
+            try:
+                rows = fetch_fn(start, end)
+            except Exception:
+                LOGGER.exception("TuShare fetch failed: table=%s code=all", table)
+                raise
+            save_records(table, rows)
 
-    if force or _range_needs_refresh("daily_basic", "trade_date", start_str, end_str, expected_days):
+    advance("Processing daily basic indicators")
     _save_with_codes("daily_basic", fetch_daily_basic)
 
-    if force or _range_needs_refresh("adj_factor", "trade_date", start_str, end_str, expected_days):
+    advance("Processing adjustment factors")
     _save_with_codes("adj_factor", fetch_adj_factor)
 
-    if include_limits and (force or _range_needs_refresh("stk_limit", "trade_date", start_str, end_str, expected_days)):
+    if include_limits:
+        advance("Processing limit price data")
         _save_with_codes("stk_limit", fetch_stk_limit)
 
-    if force or _range_needs_refresh("suspend", "suspend_date", start_str, end_str):
+    advance("Processing suspension info")
     _save_with_codes("suspend", fetch_suspensions)
 
+    if progress_hook:
+        progress_hook("Data coverage check complete", 1.0)
+
+
 def collect_data_coverage(start: date, end: date) -> Dict[str, Dict[str, object]]:
     start_str = _format_date(start)
     end_str = _format_date(end)
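The `advance` closure spreads the pipeline stages evenly across the progress bar: six steps with `include_limits=True`, five without, with a final unconditional report of 1.0. A quick self-contained check of the fractions a hook would observe:

    total_steps = 5 + 1  # include_limits=True
    stages = ["basics+calendar", "daily", "daily_basic", "adj_factor", "stk_limit", "suspend"]
    for step, stage in enumerate(stages, start=1):
        print(f"{stage:16s} -> {min(step / total_steps, 1.0):.2f}")
    # the trailing progress_hook call then reports 1.0 regardless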
@@ -500,5 +671,11 @@ def collect_data_coverage(start: date, end: date) -> Dict[str, Dict[str, object]]:
 
 def run_ingestion(job: FetchJob, include_limits: bool = True) -> None:
     LOGGER.info("Starting TuShare ingestion job: %s", job.name)
-    ensure_data_coverage(job.start, job.end, ts_codes=job.ts_codes, include_limits=include_limits, force=True)
+    ensure_data_coverage(
+        job.start,
+        job.end,
+        ts_codes=job.ts_codes,
+        include_limits=include_limits,
+        force=True,
+    )
     LOGGER.info("Job %s finished", job.name)
@@ -156,13 +156,29 @@ def render_tests() -> None:
     st.divider()
     days = int(st.number_input("Check window (days)", min_value=30, max_value=1095, value=365, step=30))
     if st.button("Run boot check"):
+        progress_bar = st.progress(0.0)
+        status_placeholder = st.empty()
+        log_placeholder = st.empty()
+        messages: list[str] = []
+
+        def hook(message: str, value: float) -> None:
+            progress_bar.progress(min(max(value, 0.0), 1.0))
+            status_placeholder.write(message)
+            messages.append(message)
+
         with st.spinner("Running boot check..."):
             try:
-                report = run_boot_check(days=days)
+                report = run_boot_check(days=days, progress_hook=hook)
                 st.success("Boot check finished; data coverage summary below.")
                 st.json(report.to_dict())
+                if messages:
+                    log_placeholder.markdown("\n".join(f"- {msg}" for msg in messages))
             except Exception as exc:  # noqa: BLE001
                 st.error(f"Boot check failed: {exc}")
+                if messages:
+                    log_placeholder.markdown("\n".join(f"- {msg}" for msg in messages))
+            finally:
+                progress_bar.progress(1.0)
 
     st.divider()
     st.subheader("Stock quote visualization")