update
This commit is contained in:
parent
774b68de99
commit
6c9c8e3140
@ -1,15 +1,16 @@
|
||||
"""数据覆盖开机检查器。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import date, timedelta
|
||||
from typing import Callable, Dict
|
||||
|
||||
from app.data.schema import initialize_database
|
||||
from app.ingest.tushare import collect_data_coverage, ensure_data_coverage
|
||||
from app.utils.config import get_config
|
||||
from app.utils.logging import get_logger
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
LOGGER = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -40,6 +41,7 @@ def run_boot_check(
|
||||
days: int = 365,
|
||||
auto_fetch: bool = True,
|
||||
progress_hook: Callable[[str, float], None] | None = None,
|
||||
force_refresh: bool | None = None,
|
||||
) -> CoverageReport:
|
||||
"""执行开机自检,必要时自动补数据。"""
|
||||
|
||||
@ -47,8 +49,17 @@ def run_boot_check(
|
||||
start, end = _default_window(days)
|
||||
LOGGER.info("开机检查覆盖窗口:%s 至 %s", start, end)
|
||||
|
||||
refresh = force_refresh
|
||||
if refresh is None:
|
||||
refresh = get_config().force_refresh
|
||||
|
||||
if auto_fetch:
|
||||
ensure_data_coverage(start, end, progress_hook=progress_hook)
|
||||
ensure_data_coverage(
|
||||
start,
|
||||
end,
|
||||
force=refresh,
|
||||
progress_hook=progress_hook,
|
||||
)
|
||||
|
||||
coverage = collect_data_coverage(start, end)
|
||||
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
"""TuShare 数据拉取与数据覆盖检查工具。"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from datetime import date
|
||||
@ -16,9 +15,21 @@ except ImportError: # pragma: no cover - 运行时提示
|
||||
|
||||
from app.utils.config import get_config
|
||||
from app.utils.db import db_session
|
||||
from app.data.schema import initialize_database
|
||||
from app.utils.logging import get_logger
|
||||
|
||||
|
||||
def _existing_date_range(table: str, date_col: str, ts_code: str | None = None) -> Tuple[str | None, str | None]:
|
||||
LOGGER = get_logger(__name__)
|
||||
|
||||
API_DEFAULT_LIMIT = 5000
|
||||
LOG_EXTRA = {"stage": "data_ingest"}
|
||||
|
||||
|
||||
def _existing_date_range(
|
||||
table: str,
|
||||
date_col: str,
|
||||
ts_code: str | None = None,
|
||||
) -> Tuple[str | None, str | None]:
|
||||
query = f"SELECT MIN({date_col}) AS min_d, MAX({date_col}) AS max_d FROM {table}"
|
||||
params: Tuple = ()
|
||||
if ts_code:
|
||||
@ -31,13 +42,11 @@ def _existing_date_range(table: str, date_col: str, ts_code: str | None = None)
|
||||
return row["min_d"], row["max_d"]
|
||||
|
||||
|
||||
|
||||
from app.data.schema import initialize_database
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
API_DEFAULT_LIMIT = 5000
|
||||
LOG_EXTRA = {"stage": "data_ingest"}
|
||||
def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]:
|
||||
if df is None or df.empty:
|
||||
return []
|
||||
reindexed = df.reindex(columns=allowed_cols)
|
||||
return reindexed.where(pd.notnull(reindexed), None).to_dict("records")
|
||||
|
||||
|
||||
def _fetch_paginated(endpoint: str, params: Dict[str, object], limit: int | None = None) -> pd.DataFrame:
|
||||
@ -294,13 +303,6 @@ def _format_date(value: date) -> str:
|
||||
return value.strftime("%Y%m%d")
|
||||
|
||||
|
||||
def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]:
|
||||
if df is None or df.empty:
|
||||
return []
|
||||
reindexed = df.reindex(columns=allowed_cols)
|
||||
return reindexed.where(pd.notnull(reindexed), None).to_dict("records")
|
||||
|
||||
|
||||
def _load_trade_dates(start: date, end: date, exchange: str = "SSE") -> List[str]:
|
||||
start_str = _format_date(start)
|
||||
end_str = _format_date(end)
|
||||
@ -405,34 +407,66 @@ def fetch_stock_basic(exchange: Optional[str] = None, list_status: str = "L") ->
|
||||
return _df_to_records(df, _TABLE_COLUMNS["stock_basic"])
|
||||
|
||||
|
||||
def fetch_daily_bars(job: FetchJob) -> Iterable[Dict]:
|
||||
def fetch_daily_bars(job: FetchJob, skip_existing: bool = True) -> Iterable[Dict]:
|
||||
client = _ensure_client()
|
||||
start_date = _format_date(job.start)
|
||||
end_date = _format_date(job.end)
|
||||
frames: List[pd.DataFrame] = []
|
||||
|
||||
if job.granularity != "daily":
|
||||
raise ValueError(f"暂不支持的粒度:{job.granularity}")
|
||||
|
||||
params = {
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
}
|
||||
trade_dates = _load_trade_dates(job.start, job.end)
|
||||
if not trade_dates:
|
||||
LOGGER.info("本地交易日历缺失,尝试补全后再拉取日线行情", extra=LOG_EXTRA)
|
||||
ensure_trade_calendar(job.start, job.end)
|
||||
trade_dates = _load_trade_dates(job.start, job.end)
|
||||
|
||||
if job.ts_codes:
|
||||
for code in job.ts_codes:
|
||||
LOGGER.info("拉取 %s 的日线行情(%s-%s)", code, start_date, end_date)
|
||||
df = _fetch_paginated("daily", {**params, "ts_code": code})
|
||||
for trade_date in trade_dates:
|
||||
if skip_existing and _record_exists("daily", "trade_date", trade_date, code):
|
||||
LOGGER.debug(
|
||||
"日线数据已存在,跳过 %s %s",
|
||||
code,
|
||||
trade_date,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
continue
|
||||
LOGGER.debug(
|
||||
"按交易日拉取日线行情:code=%s trade_date=%s",
|
||||
code,
|
||||
trade_date,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
LOGGER.info(
|
||||
"交易日拉取请求:endpoint=daily code=%s trade_date=%s",
|
||||
code,
|
||||
trade_date,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
df = _fetch_paginated(
|
||||
"daily",
|
||||
{
|
||||
"trade_date": trade_date,
|
||||
"ts_code": code,
|
||||
},
|
||||
)
|
||||
if not df.empty:
|
||||
frames.append(df)
|
||||
else:
|
||||
trade_dates = _load_trade_dates(job.start, job.end)
|
||||
if not trade_dates:
|
||||
LOGGER.info("本地交易日历缺失,尝试补全后再拉取日线行情")
|
||||
ensure_trade_calendar(job.start, job.end)
|
||||
trade_dates = _load_trade_dates(job.start, job.end)
|
||||
for trade_date in trade_dates:
|
||||
LOGGER.debug("按交易日拉取日线行情:%s", trade_date)
|
||||
if skip_existing and _record_exists("daily", "trade_date", trade_date):
|
||||
LOGGER.debug(
|
||||
"日线数据已存在,跳过交易日 %s",
|
||||
trade_date,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
continue
|
||||
LOGGER.debug("按交易日拉取日线行情:%s", trade_date, extra=LOG_EXTRA)
|
||||
LOGGER.info(
|
||||
"交易日拉取请求:endpoint=daily trade_date=%s",
|
||||
trade_date,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
df = _fetch_paginated("daily", {"trade_date": trade_date})
|
||||
if not df.empty:
|
||||
frames.append(df)
|
||||
@ -460,33 +494,25 @@ def fetch_daily_basic(
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
|
||||
if ts_code:
|
||||
df = _fetch_paginated(
|
||||
"daily_basic",
|
||||
{
|
||||
"ts_code": ts_code,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
},
|
||||
)
|
||||
return _df_to_records(df, _TABLE_COLUMNS["daily_basic"])
|
||||
|
||||
trade_dates = _load_trade_dates(start, end)
|
||||
frames: List[pd.DataFrame] = []
|
||||
for trade_date in trade_dates:
|
||||
if skip_existing and _record_exists("daily_basic", "trade_date", trade_date):
|
||||
if skip_existing and _record_exists("daily_basic", "trade_date", trade_date, ts_code):
|
||||
LOGGER.info(
|
||||
"日线基础指标已存在,跳过交易日 %s",
|
||||
trade_date,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
continue
|
||||
LOGGER.debug(
|
||||
"按交易日拉取日线基础指标:%s",
|
||||
trade_date,
|
||||
params = {"trade_date": trade_date}
|
||||
if ts_code:
|
||||
params["ts_code"] = ts_code
|
||||
LOGGER.info(
|
||||
"交易日拉取请求:endpoint=daily_basic params=%s",
|
||||
params,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
df = _fetch_paginated("daily_basic", {"trade_date": trade_date})
|
||||
df = _fetch_paginated("daily_basic", params)
|
||||
if not df.empty:
|
||||
frames.append(df)
|
||||
|
||||
@ -528,7 +554,7 @@ def fetch_adj_factor(
|
||||
params = {"trade_date": trade_date}
|
||||
if ts_code:
|
||||
params["ts_code"] = ts_code
|
||||
LOGGER.debug("按交易日拉取复权因子:%s", params, extra=LOG_EXTRA)
|
||||
LOGGER.info("交易日拉取请求:endpoint=adj_factor params=%s", params, extra=LOG_EXTRA)
|
||||
df = _fetch_paginated("adj_factor", params)
|
||||
if not df.empty:
|
||||
frames.append(df)
|
||||
@ -540,17 +566,40 @@ def fetch_adj_factor(
|
||||
return _df_to_records(merged, _TABLE_COLUMNS["adj_factor"])
|
||||
|
||||
|
||||
def fetch_suspensions(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
|
||||
def fetch_suspensions(
|
||||
start: date,
|
||||
end: date,
|
||||
ts_code: Optional[str] = None,
|
||||
skip_existing: bool = True,
|
||||
) -> Iterable[Dict]:
|
||||
client = _ensure_client()
|
||||
start_date = _format_date(start)
|
||||
end_date = _format_date(end)
|
||||
LOGGER.info("拉取停复牌信息(%s-%s)", start_date, end_date)
|
||||
df = _fetch_paginated("suspend_d", {
|
||||
"ts_code": ts_code,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
}, limit=2000)
|
||||
return _df_to_records(df, _TABLE_COLUMNS["suspend"])
|
||||
LOGGER.info("拉取停复牌信息(%s-%s)", start_date, end_date, extra=LOG_EXTRA)
|
||||
trade_dates = _load_trade_dates(start, end)
|
||||
frames: List[pd.DataFrame] = []
|
||||
for trade_date in trade_dates:
|
||||
if skip_existing and _record_exists("suspend", "suspend_date", trade_date, ts_code):
|
||||
LOGGER.debug(
|
||||
"停复牌信息已存在,跳过 %s %s",
|
||||
ts_code or "ALL",
|
||||
trade_date,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
continue
|
||||
params = {"trade_date": trade_date}
|
||||
if ts_code:
|
||||
params["ts_code"] = ts_code
|
||||
LOGGER.info("交易日拉取请求:endpoint=suspend_d params=%s", params, extra=LOG_EXTRA)
|
||||
df = _fetch_paginated("suspend_d", params, limit=2000)
|
||||
if not df.empty:
|
||||
frames.append(df)
|
||||
|
||||
if not frames:
|
||||
return []
|
||||
|
||||
merged = pd.concat(frames, ignore_index=True)
|
||||
return _df_to_records(merged, _TABLE_COLUMNS["suspend"])
|
||||
|
||||
|
||||
def fetch_trade_calendar(start: date, end: date, exchange: str = "SSE") -> Iterable[Dict]:
|
||||
@ -562,17 +611,40 @@ def fetch_trade_calendar(start: date, end: date, exchange: str = "SSE") -> Itera
|
||||
return _df_to_records(df, _TABLE_COLUMNS["trade_calendar"])
|
||||
|
||||
|
||||
def fetch_stk_limit(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
|
||||
def fetch_stk_limit(
|
||||
start: date,
|
||||
end: date,
|
||||
ts_code: Optional[str] = None,
|
||||
skip_existing: bool = True,
|
||||
) -> Iterable[Dict]:
|
||||
client = _ensure_client()
|
||||
start_date = _format_date(start)
|
||||
end_date = _format_date(end)
|
||||
LOGGER.info("拉取涨跌停价格(%s-%s)", start_date, end_date)
|
||||
df = _fetch_paginated("stk_limit", {
|
||||
"ts_code": ts_code,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
})
|
||||
return _df_to_records(df, _TABLE_COLUMNS["stk_limit"])
|
||||
LOGGER.info("拉取涨跌停价格(%s-%s)", start_date, end_date, extra=LOG_EXTRA)
|
||||
trade_dates = _load_trade_dates(start, end)
|
||||
frames: List[pd.DataFrame] = []
|
||||
for trade_date in trade_dates:
|
||||
if skip_existing and _record_exists("stk_limit", "trade_date", trade_date, ts_code):
|
||||
LOGGER.debug(
|
||||
"涨跌停数据已存在,跳过 %s %s",
|
||||
ts_code or "ALL",
|
||||
trade_date,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
continue
|
||||
params = {"trade_date": trade_date}
|
||||
if ts_code:
|
||||
params["ts_code"] = ts_code
|
||||
LOGGER.info("交易日拉取请求:endpoint=stk_limit params=%s", params, extra=LOG_EXTRA)
|
||||
df = _fetch_paginated("stk_limit", params)
|
||||
if not df.empty:
|
||||
frames.append(df)
|
||||
|
||||
if not frames:
|
||||
return []
|
||||
|
||||
merged = pd.concat(frames, ignore_index=True)
|
||||
return _df_to_records(merged, _TABLE_COLUMNS["stk_limit"])
|
||||
|
||||
|
||||
def save_records(table: str, rows: Iterable[Dict]) -> None:
|
||||
@ -662,7 +734,7 @@ def ensure_data_coverage(
|
||||
if pending_codes:
|
||||
job = FetchJob("daily_autofill", start=start, end=end, ts_codes=tuple(pending_codes))
|
||||
LOGGER.info("开始拉取日线行情:%s-%s(待补股票 %d 支)", start_str, end_str, len(pending_codes))
|
||||
save_records("daily", fetch_daily_bars(job))
|
||||
save_records("daily", fetch_daily_bars(job, skip_existing=not force))
|
||||
else:
|
||||
needs_daily = force or _range_needs_refresh("daily", "trade_date", start_str, end_str, expected_days)
|
||||
if not needs_daily:
|
||||
@ -670,7 +742,7 @@ def ensure_data_coverage(
|
||||
else:
|
||||
job = FetchJob("daily_autofill", start=start, end=end)
|
||||
LOGGER.info("开始拉取日线行情:%s-%s", start_str, end_str)
|
||||
save_records("daily", fetch_daily_bars(job))
|
||||
save_records("daily", fetch_daily_bars(job, skip_existing=not force))
|
||||
|
||||
date_cols = {
|
||||
"daily_basic": "trade_date",
|
||||
@ -689,7 +761,7 @@ def ensure_data_coverage(
|
||||
LOGGER.info("拉取 %s 表数据(股票:%s)%s-%s", table, code, start_str, end_str)
|
||||
try:
|
||||
kwargs = {"ts_code": code}
|
||||
if fetch_fn in (fetch_daily_basic, fetch_adj_factor):
|
||||
if fetch_fn in (fetch_daily_basic, fetch_adj_factor, fetch_suspensions, fetch_stk_limit):
|
||||
kwargs["skip_existing"] = not force
|
||||
rows = fetch_fn(start, end, **kwargs)
|
||||
except Exception:
|
||||
@ -707,7 +779,7 @@ def ensure_data_coverage(
|
||||
LOGGER.info("拉取 %s 表数据(全市场)%s-%s", table, start_str, end_str)
|
||||
try:
|
||||
kwargs = {}
|
||||
if fetch_fn in (fetch_daily_basic, fetch_adj_factor):
|
||||
if fetch_fn in (fetch_daily_basic, fetch_adj_factor, fetch_suspensions, fetch_stk_limit):
|
||||
kwargs["skip_existing"] = not force
|
||||
rows = fetch_fn(start, end, **kwargs)
|
||||
except Exception:
|
||||
|
||||
@ -114,9 +114,10 @@ def render_settings() -> None:
|
||||
st.header("数据与设置")
|
||||
cfg = get_config()
|
||||
token = st.text_input("TuShare Token", value=cfg.tushare_token or "", type="password")
|
||||
if st.button("保存 Token"):
|
||||
|
||||
if st.button("保存设置"):
|
||||
cfg.tushare_token = token.strip() or None
|
||||
st.success("TuShare Token 已更新,仅保存在当前会话。")
|
||||
st.success("设置已保存,仅在当前会话生效。")
|
||||
|
||||
st.write("新闻源开关与数据库备份将在此配置。")
|
||||
|
||||
@ -155,6 +156,15 @@ def render_tests() -> None:
|
||||
|
||||
st.divider()
|
||||
days = int(st.number_input("检查窗口(天数)", min_value=30, max_value=1095, value=365, step=30))
|
||||
cfg = get_config()
|
||||
force_refresh = st.checkbox(
|
||||
"强制刷新数据(关闭增量跳过)",
|
||||
value=cfg.force_refresh,
|
||||
help="勾选后将重新拉取所选区间全部数据",
|
||||
)
|
||||
if force_refresh != cfg.force_refresh:
|
||||
cfg.force_refresh = force_refresh
|
||||
|
||||
if st.button("执行开机检查"):
|
||||
progress_bar = st.progress(0.0)
|
||||
status_placeholder = st.empty()
|
||||
@ -168,7 +178,11 @@ def render_tests() -> None:
|
||||
|
||||
with st.spinner("正在执行开机检查..."):
|
||||
try:
|
||||
report = run_boot_check(days=days, progress_hook=hook)
|
||||
report = run_boot_check(
|
||||
days=days,
|
||||
progress_hook=hook,
|
||||
force_refresh=force_refresh,
|
||||
)
|
||||
st.success("开机检查完成,以下为数据覆盖摘要。")
|
||||
st.json(report.to_dict())
|
||||
if messages:
|
||||
|
||||
@ -54,6 +54,7 @@ class AppConfig:
|
||||
decision_method: str = "nash"
|
||||
data_paths: DataPaths = field(default_factory=DataPaths)
|
||||
agent_weights: AgentWeights = field(default_factory=AgentWeights)
|
||||
force_refresh: bool = False
|
||||
|
||||
|
||||
CONFIG = AppConfig()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user