update
This commit is contained in:
parent
774b68de99
commit
6c9c8e3140
@ -1,15 +1,16 @@
|
|||||||
"""数据覆盖开机检查器。"""
|
"""数据覆盖开机检查器。"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import date, timedelta
|
from datetime import date, timedelta
|
||||||
from typing import Callable, Dict
|
from typing import Callable, Dict
|
||||||
|
|
||||||
from app.data.schema import initialize_database
|
from app.data.schema import initialize_database
|
||||||
from app.ingest.tushare import collect_data_coverage, ensure_data_coverage
|
from app.ingest.tushare import collect_data_coverage, ensure_data_coverage
|
||||||
|
from app.utils.config import get_config
|
||||||
|
from app.utils.logging import get_logger
|
||||||
|
|
||||||
LOGGER = logging.getLogger(__name__)
|
LOGGER = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -40,6 +41,7 @@ def run_boot_check(
|
|||||||
days: int = 365,
|
days: int = 365,
|
||||||
auto_fetch: bool = True,
|
auto_fetch: bool = True,
|
||||||
progress_hook: Callable[[str, float], None] | None = None,
|
progress_hook: Callable[[str, float], None] | None = None,
|
||||||
|
force_refresh: bool | None = None,
|
||||||
) -> CoverageReport:
|
) -> CoverageReport:
|
||||||
"""执行开机自检,必要时自动补数据。"""
|
"""执行开机自检,必要时自动补数据。"""
|
||||||
|
|
||||||
@ -47,8 +49,17 @@ def run_boot_check(
|
|||||||
start, end = _default_window(days)
|
start, end = _default_window(days)
|
||||||
LOGGER.info("开机检查覆盖窗口:%s 至 %s", start, end)
|
LOGGER.info("开机检查覆盖窗口:%s 至 %s", start, end)
|
||||||
|
|
||||||
|
refresh = force_refresh
|
||||||
|
if refresh is None:
|
||||||
|
refresh = get_config().force_refresh
|
||||||
|
|
||||||
if auto_fetch:
|
if auto_fetch:
|
||||||
ensure_data_coverage(start, end, progress_hook=progress_hook)
|
ensure_data_coverage(
|
||||||
|
start,
|
||||||
|
end,
|
||||||
|
force=refresh,
|
||||||
|
progress_hook=progress_hook,
|
||||||
|
)
|
||||||
|
|
||||||
coverage = collect_data_coverage(start, end)
|
coverage = collect_data_coverage(start, end)
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
"""TuShare 数据拉取与数据覆盖检查工具。"""
|
"""TuShare 数据拉取与数据覆盖检查工具。"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import date
|
from datetime import date
|
||||||
@ -16,9 +15,21 @@ except ImportError: # pragma: no cover - 运行时提示
|
|||||||
|
|
||||||
from app.utils.config import get_config
|
from app.utils.config import get_config
|
||||||
from app.utils.db import db_session
|
from app.utils.db import db_session
|
||||||
|
from app.data.schema import initialize_database
|
||||||
|
from app.utils.logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
def _existing_date_range(table: str, date_col: str, ts_code: str | None = None) -> Tuple[str | None, str | None]:
|
LOGGER = get_logger(__name__)
|
||||||
|
|
||||||
|
API_DEFAULT_LIMIT = 5000
|
||||||
|
LOG_EXTRA = {"stage": "data_ingest"}
|
||||||
|
|
||||||
|
|
||||||
|
def _existing_date_range(
|
||||||
|
table: str,
|
||||||
|
date_col: str,
|
||||||
|
ts_code: str | None = None,
|
||||||
|
) -> Tuple[str | None, str | None]:
|
||||||
query = f"SELECT MIN({date_col}) AS min_d, MAX({date_col}) AS max_d FROM {table}"
|
query = f"SELECT MIN({date_col}) AS min_d, MAX({date_col}) AS max_d FROM {table}"
|
||||||
params: Tuple = ()
|
params: Tuple = ()
|
||||||
if ts_code:
|
if ts_code:
|
||||||
@ -31,13 +42,11 @@ def _existing_date_range(table: str, date_col: str, ts_code: str | None = None)
|
|||||||
return row["min_d"], row["max_d"]
|
return row["min_d"], row["max_d"]
|
||||||
|
|
||||||
|
|
||||||
|
def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]:
|
||||||
from app.data.schema import initialize_database
|
if df is None or df.empty:
|
||||||
|
return []
|
||||||
LOGGER = logging.getLogger(__name__)
|
reindexed = df.reindex(columns=allowed_cols)
|
||||||
|
return reindexed.where(pd.notnull(reindexed), None).to_dict("records")
|
||||||
API_DEFAULT_LIMIT = 5000
|
|
||||||
LOG_EXTRA = {"stage": "data_ingest"}
|
|
||||||
|
|
||||||
|
|
||||||
def _fetch_paginated(endpoint: str, params: Dict[str, object], limit: int | None = None) -> pd.DataFrame:
|
def _fetch_paginated(endpoint: str, params: Dict[str, object], limit: int | None = None) -> pd.DataFrame:
|
||||||
@ -294,13 +303,6 @@ def _format_date(value: date) -> str:
|
|||||||
return value.strftime("%Y%m%d")
|
return value.strftime("%Y%m%d")
|
||||||
|
|
||||||
|
|
||||||
def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]:
|
|
||||||
if df is None or df.empty:
|
|
||||||
return []
|
|
||||||
reindexed = df.reindex(columns=allowed_cols)
|
|
||||||
return reindexed.where(pd.notnull(reindexed), None).to_dict("records")
|
|
||||||
|
|
||||||
|
|
||||||
def _load_trade_dates(start: date, end: date, exchange: str = "SSE") -> List[str]:
|
def _load_trade_dates(start: date, end: date, exchange: str = "SSE") -> List[str]:
|
||||||
start_str = _format_date(start)
|
start_str = _format_date(start)
|
||||||
end_str = _format_date(end)
|
end_str = _format_date(end)
|
||||||
@ -405,34 +407,66 @@ def fetch_stock_basic(exchange: Optional[str] = None, list_status: str = "L") ->
|
|||||||
return _df_to_records(df, _TABLE_COLUMNS["stock_basic"])
|
return _df_to_records(df, _TABLE_COLUMNS["stock_basic"])
|
||||||
|
|
||||||
|
|
||||||
def fetch_daily_bars(job: FetchJob) -> Iterable[Dict]:
|
def fetch_daily_bars(job: FetchJob, skip_existing: bool = True) -> Iterable[Dict]:
|
||||||
client = _ensure_client()
|
client = _ensure_client()
|
||||||
start_date = _format_date(job.start)
|
|
||||||
end_date = _format_date(job.end)
|
|
||||||
frames: List[pd.DataFrame] = []
|
frames: List[pd.DataFrame] = []
|
||||||
|
|
||||||
if job.granularity != "daily":
|
if job.granularity != "daily":
|
||||||
raise ValueError(f"暂不支持的粒度:{job.granularity}")
|
raise ValueError(f"暂不支持的粒度:{job.granularity}")
|
||||||
|
|
||||||
params = {
|
trade_dates = _load_trade_dates(job.start, job.end)
|
||||||
"start_date": start_date,
|
if not trade_dates:
|
||||||
"end_date": end_date,
|
LOGGER.info("本地交易日历缺失,尝试补全后再拉取日线行情", extra=LOG_EXTRA)
|
||||||
}
|
ensure_trade_calendar(job.start, job.end)
|
||||||
|
trade_dates = _load_trade_dates(job.start, job.end)
|
||||||
|
|
||||||
if job.ts_codes:
|
if job.ts_codes:
|
||||||
for code in job.ts_codes:
|
for code in job.ts_codes:
|
||||||
LOGGER.info("拉取 %s 的日线行情(%s-%s)", code, start_date, end_date)
|
for trade_date in trade_dates:
|
||||||
df = _fetch_paginated("daily", {**params, "ts_code": code})
|
if skip_existing and _record_exists("daily", "trade_date", trade_date, code):
|
||||||
if not df.empty:
|
LOGGER.debug(
|
||||||
frames.append(df)
|
"日线数据已存在,跳过 %s %s",
|
||||||
|
code,
|
||||||
|
trade_date,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
LOGGER.debug(
|
||||||
|
"按交易日拉取日线行情:code=%s trade_date=%s",
|
||||||
|
code,
|
||||||
|
trade_date,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
LOGGER.info(
|
||||||
|
"交易日拉取请求:endpoint=daily code=%s trade_date=%s",
|
||||||
|
code,
|
||||||
|
trade_date,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
df = _fetch_paginated(
|
||||||
|
"daily",
|
||||||
|
{
|
||||||
|
"trade_date": trade_date,
|
||||||
|
"ts_code": code,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if not df.empty:
|
||||||
|
frames.append(df)
|
||||||
else:
|
else:
|
||||||
trade_dates = _load_trade_dates(job.start, job.end)
|
|
||||||
if not trade_dates:
|
|
||||||
LOGGER.info("本地交易日历缺失,尝试补全后再拉取日线行情")
|
|
||||||
ensure_trade_calendar(job.start, job.end)
|
|
||||||
trade_dates = _load_trade_dates(job.start, job.end)
|
|
||||||
for trade_date in trade_dates:
|
for trade_date in trade_dates:
|
||||||
LOGGER.debug("按交易日拉取日线行情:%s", trade_date)
|
if skip_existing and _record_exists("daily", "trade_date", trade_date):
|
||||||
|
LOGGER.debug(
|
||||||
|
"日线数据已存在,跳过交易日 %s",
|
||||||
|
trade_date,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
LOGGER.debug("按交易日拉取日线行情:%s", trade_date, extra=LOG_EXTRA)
|
||||||
|
LOGGER.info(
|
||||||
|
"交易日拉取请求:endpoint=daily trade_date=%s",
|
||||||
|
trade_date,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
df = _fetch_paginated("daily", {"trade_date": trade_date})
|
df = _fetch_paginated("daily", {"trade_date": trade_date})
|
||||||
if not df.empty:
|
if not df.empty:
|
||||||
frames.append(df)
|
frames.append(df)
|
||||||
@ -460,33 +494,25 @@ def fetch_daily_basic(
|
|||||||
extra=LOG_EXTRA,
|
extra=LOG_EXTRA,
|
||||||
)
|
)
|
||||||
|
|
||||||
if ts_code:
|
|
||||||
df = _fetch_paginated(
|
|
||||||
"daily_basic",
|
|
||||||
{
|
|
||||||
"ts_code": ts_code,
|
|
||||||
"start_date": start_date,
|
|
||||||
"end_date": end_date,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return _df_to_records(df, _TABLE_COLUMNS["daily_basic"])
|
|
||||||
|
|
||||||
trade_dates = _load_trade_dates(start, end)
|
trade_dates = _load_trade_dates(start, end)
|
||||||
frames: List[pd.DataFrame] = []
|
frames: List[pd.DataFrame] = []
|
||||||
for trade_date in trade_dates:
|
for trade_date in trade_dates:
|
||||||
if skip_existing and _record_exists("daily_basic", "trade_date", trade_date):
|
if skip_existing and _record_exists("daily_basic", "trade_date", trade_date, ts_code):
|
||||||
LOGGER.info(
|
LOGGER.info(
|
||||||
"日线基础指标已存在,跳过交易日 %s",
|
"日线基础指标已存在,跳过交易日 %s",
|
||||||
trade_date,
|
trade_date,
|
||||||
extra=LOG_EXTRA,
|
extra=LOG_EXTRA,
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
LOGGER.debug(
|
params = {"trade_date": trade_date}
|
||||||
"按交易日拉取日线基础指标:%s",
|
if ts_code:
|
||||||
trade_date,
|
params["ts_code"] = ts_code
|
||||||
|
LOGGER.info(
|
||||||
|
"交易日拉取请求:endpoint=daily_basic params=%s",
|
||||||
|
params,
|
||||||
extra=LOG_EXTRA,
|
extra=LOG_EXTRA,
|
||||||
)
|
)
|
||||||
df = _fetch_paginated("daily_basic", {"trade_date": trade_date})
|
df = _fetch_paginated("daily_basic", params)
|
||||||
if not df.empty:
|
if not df.empty:
|
||||||
frames.append(df)
|
frames.append(df)
|
||||||
|
|
||||||
@ -528,7 +554,7 @@ def fetch_adj_factor(
|
|||||||
params = {"trade_date": trade_date}
|
params = {"trade_date": trade_date}
|
||||||
if ts_code:
|
if ts_code:
|
||||||
params["ts_code"] = ts_code
|
params["ts_code"] = ts_code
|
||||||
LOGGER.debug("按交易日拉取复权因子:%s", params, extra=LOG_EXTRA)
|
LOGGER.info("交易日拉取请求:endpoint=adj_factor params=%s", params, extra=LOG_EXTRA)
|
||||||
df = _fetch_paginated("adj_factor", params)
|
df = _fetch_paginated("adj_factor", params)
|
||||||
if not df.empty:
|
if not df.empty:
|
||||||
frames.append(df)
|
frames.append(df)
|
||||||
@ -540,17 +566,40 @@ def fetch_adj_factor(
|
|||||||
return _df_to_records(merged, _TABLE_COLUMNS["adj_factor"])
|
return _df_to_records(merged, _TABLE_COLUMNS["adj_factor"])
|
||||||
|
|
||||||
|
|
||||||
def fetch_suspensions(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
|
def fetch_suspensions(
|
||||||
|
start: date,
|
||||||
|
end: date,
|
||||||
|
ts_code: Optional[str] = None,
|
||||||
|
skip_existing: bool = True,
|
||||||
|
) -> Iterable[Dict]:
|
||||||
client = _ensure_client()
|
client = _ensure_client()
|
||||||
start_date = _format_date(start)
|
start_date = _format_date(start)
|
||||||
end_date = _format_date(end)
|
end_date = _format_date(end)
|
||||||
LOGGER.info("拉取停复牌信息(%s-%s)", start_date, end_date)
|
LOGGER.info("拉取停复牌信息(%s-%s)", start_date, end_date, extra=LOG_EXTRA)
|
||||||
df = _fetch_paginated("suspend_d", {
|
trade_dates = _load_trade_dates(start, end)
|
||||||
"ts_code": ts_code,
|
frames: List[pd.DataFrame] = []
|
||||||
"start_date": start_date,
|
for trade_date in trade_dates:
|
||||||
"end_date": end_date,
|
if skip_existing and _record_exists("suspend", "suspend_date", trade_date, ts_code):
|
||||||
}, limit=2000)
|
LOGGER.debug(
|
||||||
return _df_to_records(df, _TABLE_COLUMNS["suspend"])
|
"停复牌信息已存在,跳过 %s %s",
|
||||||
|
ts_code or "ALL",
|
||||||
|
trade_date,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
params = {"trade_date": trade_date}
|
||||||
|
if ts_code:
|
||||||
|
params["ts_code"] = ts_code
|
||||||
|
LOGGER.info("交易日拉取请求:endpoint=suspend_d params=%s", params, extra=LOG_EXTRA)
|
||||||
|
df = _fetch_paginated("suspend_d", params, limit=2000)
|
||||||
|
if not df.empty:
|
||||||
|
frames.append(df)
|
||||||
|
|
||||||
|
if not frames:
|
||||||
|
return []
|
||||||
|
|
||||||
|
merged = pd.concat(frames, ignore_index=True)
|
||||||
|
return _df_to_records(merged, _TABLE_COLUMNS["suspend"])
|
||||||
|
|
||||||
|
|
||||||
def fetch_trade_calendar(start: date, end: date, exchange: str = "SSE") -> Iterable[Dict]:
|
def fetch_trade_calendar(start: date, end: date, exchange: str = "SSE") -> Iterable[Dict]:
|
||||||
@ -562,17 +611,40 @@ def fetch_trade_calendar(start: date, end: date, exchange: str = "SSE") -> Itera
|
|||||||
return _df_to_records(df, _TABLE_COLUMNS["trade_calendar"])
|
return _df_to_records(df, _TABLE_COLUMNS["trade_calendar"])
|
||||||
|
|
||||||
|
|
||||||
def fetch_stk_limit(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
|
def fetch_stk_limit(
|
||||||
|
start: date,
|
||||||
|
end: date,
|
||||||
|
ts_code: Optional[str] = None,
|
||||||
|
skip_existing: bool = True,
|
||||||
|
) -> Iterable[Dict]:
|
||||||
client = _ensure_client()
|
client = _ensure_client()
|
||||||
start_date = _format_date(start)
|
start_date = _format_date(start)
|
||||||
end_date = _format_date(end)
|
end_date = _format_date(end)
|
||||||
LOGGER.info("拉取涨跌停价格(%s-%s)", start_date, end_date)
|
LOGGER.info("拉取涨跌停价格(%s-%s)", start_date, end_date, extra=LOG_EXTRA)
|
||||||
df = _fetch_paginated("stk_limit", {
|
trade_dates = _load_trade_dates(start, end)
|
||||||
"ts_code": ts_code,
|
frames: List[pd.DataFrame] = []
|
||||||
"start_date": start_date,
|
for trade_date in trade_dates:
|
||||||
"end_date": end_date,
|
if skip_existing and _record_exists("stk_limit", "trade_date", trade_date, ts_code):
|
||||||
})
|
LOGGER.debug(
|
||||||
return _df_to_records(df, _TABLE_COLUMNS["stk_limit"])
|
"涨跌停数据已存在,跳过 %s %s",
|
||||||
|
ts_code or "ALL",
|
||||||
|
trade_date,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
params = {"trade_date": trade_date}
|
||||||
|
if ts_code:
|
||||||
|
params["ts_code"] = ts_code
|
||||||
|
LOGGER.info("交易日拉取请求:endpoint=stk_limit params=%s", params, extra=LOG_EXTRA)
|
||||||
|
df = _fetch_paginated("stk_limit", params)
|
||||||
|
if not df.empty:
|
||||||
|
frames.append(df)
|
||||||
|
|
||||||
|
if not frames:
|
||||||
|
return []
|
||||||
|
|
||||||
|
merged = pd.concat(frames, ignore_index=True)
|
||||||
|
return _df_to_records(merged, _TABLE_COLUMNS["stk_limit"])
|
||||||
|
|
||||||
|
|
||||||
def save_records(table: str, rows: Iterable[Dict]) -> None:
|
def save_records(table: str, rows: Iterable[Dict]) -> None:
|
||||||
@ -662,7 +734,7 @@ def ensure_data_coverage(
|
|||||||
if pending_codes:
|
if pending_codes:
|
||||||
job = FetchJob("daily_autofill", start=start, end=end, ts_codes=tuple(pending_codes))
|
job = FetchJob("daily_autofill", start=start, end=end, ts_codes=tuple(pending_codes))
|
||||||
LOGGER.info("开始拉取日线行情:%s-%s(待补股票 %d 支)", start_str, end_str, len(pending_codes))
|
LOGGER.info("开始拉取日线行情:%s-%s(待补股票 %d 支)", start_str, end_str, len(pending_codes))
|
||||||
save_records("daily", fetch_daily_bars(job))
|
save_records("daily", fetch_daily_bars(job, skip_existing=not force))
|
||||||
else:
|
else:
|
||||||
needs_daily = force or _range_needs_refresh("daily", "trade_date", start_str, end_str, expected_days)
|
needs_daily = force or _range_needs_refresh("daily", "trade_date", start_str, end_str, expected_days)
|
||||||
if not needs_daily:
|
if not needs_daily:
|
||||||
@ -670,7 +742,7 @@ def ensure_data_coverage(
|
|||||||
else:
|
else:
|
||||||
job = FetchJob("daily_autofill", start=start, end=end)
|
job = FetchJob("daily_autofill", start=start, end=end)
|
||||||
LOGGER.info("开始拉取日线行情:%s-%s", start_str, end_str)
|
LOGGER.info("开始拉取日线行情:%s-%s", start_str, end_str)
|
||||||
save_records("daily", fetch_daily_bars(job))
|
save_records("daily", fetch_daily_bars(job, skip_existing=not force))
|
||||||
|
|
||||||
date_cols = {
|
date_cols = {
|
||||||
"daily_basic": "trade_date",
|
"daily_basic": "trade_date",
|
||||||
@ -689,7 +761,7 @@ def ensure_data_coverage(
|
|||||||
LOGGER.info("拉取 %s 表数据(股票:%s)%s-%s", table, code, start_str, end_str)
|
LOGGER.info("拉取 %s 表数据(股票:%s)%s-%s", table, code, start_str, end_str)
|
||||||
try:
|
try:
|
||||||
kwargs = {"ts_code": code}
|
kwargs = {"ts_code": code}
|
||||||
if fetch_fn in (fetch_daily_basic, fetch_adj_factor):
|
if fetch_fn in (fetch_daily_basic, fetch_adj_factor, fetch_suspensions, fetch_stk_limit):
|
||||||
kwargs["skip_existing"] = not force
|
kwargs["skip_existing"] = not force
|
||||||
rows = fetch_fn(start, end, **kwargs)
|
rows = fetch_fn(start, end, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -707,7 +779,7 @@ def ensure_data_coverage(
|
|||||||
LOGGER.info("拉取 %s 表数据(全市场)%s-%s", table, start_str, end_str)
|
LOGGER.info("拉取 %s 表数据(全市场)%s-%s", table, start_str, end_str)
|
||||||
try:
|
try:
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if fetch_fn in (fetch_daily_basic, fetch_adj_factor):
|
if fetch_fn in (fetch_daily_basic, fetch_adj_factor, fetch_suspensions, fetch_stk_limit):
|
||||||
kwargs["skip_existing"] = not force
|
kwargs["skip_existing"] = not force
|
||||||
rows = fetch_fn(start, end, **kwargs)
|
rows = fetch_fn(start, end, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@ -114,9 +114,10 @@ def render_settings() -> None:
|
|||||||
st.header("数据与设置")
|
st.header("数据与设置")
|
||||||
cfg = get_config()
|
cfg = get_config()
|
||||||
token = st.text_input("TuShare Token", value=cfg.tushare_token or "", type="password")
|
token = st.text_input("TuShare Token", value=cfg.tushare_token or "", type="password")
|
||||||
if st.button("保存 Token"):
|
|
||||||
|
if st.button("保存设置"):
|
||||||
cfg.tushare_token = token.strip() or None
|
cfg.tushare_token = token.strip() or None
|
||||||
st.success("TuShare Token 已更新,仅保存在当前会话。")
|
st.success("设置已保存,仅在当前会话生效。")
|
||||||
|
|
||||||
st.write("新闻源开关与数据库备份将在此配置。")
|
st.write("新闻源开关与数据库备份将在此配置。")
|
||||||
|
|
||||||
@ -155,6 +156,15 @@ def render_tests() -> None:
|
|||||||
|
|
||||||
st.divider()
|
st.divider()
|
||||||
days = int(st.number_input("检查窗口(天数)", min_value=30, max_value=1095, value=365, step=30))
|
days = int(st.number_input("检查窗口(天数)", min_value=30, max_value=1095, value=365, step=30))
|
||||||
|
cfg = get_config()
|
||||||
|
force_refresh = st.checkbox(
|
||||||
|
"强制刷新数据(关闭增量跳过)",
|
||||||
|
value=cfg.force_refresh,
|
||||||
|
help="勾选后将重新拉取所选区间全部数据",
|
||||||
|
)
|
||||||
|
if force_refresh != cfg.force_refresh:
|
||||||
|
cfg.force_refresh = force_refresh
|
||||||
|
|
||||||
if st.button("执行开机检查"):
|
if st.button("执行开机检查"):
|
||||||
progress_bar = st.progress(0.0)
|
progress_bar = st.progress(0.0)
|
||||||
status_placeholder = st.empty()
|
status_placeholder = st.empty()
|
||||||
@ -168,7 +178,11 @@ def render_tests() -> None:
|
|||||||
|
|
||||||
with st.spinner("正在执行开机检查..."):
|
with st.spinner("正在执行开机检查..."):
|
||||||
try:
|
try:
|
||||||
report = run_boot_check(days=days, progress_hook=hook)
|
report = run_boot_check(
|
||||||
|
days=days,
|
||||||
|
progress_hook=hook,
|
||||||
|
force_refresh=force_refresh,
|
||||||
|
)
|
||||||
st.success("开机检查完成,以下为数据覆盖摘要。")
|
st.success("开机检查完成,以下为数据覆盖摘要。")
|
||||||
st.json(report.to_dict())
|
st.json(report.to_dict())
|
||||||
if messages:
|
if messages:
|
||||||
|
|||||||
@ -54,6 +54,7 @@ class AppConfig:
|
|||||||
decision_method: str = "nash"
|
decision_method: str = "nash"
|
||||||
data_paths: DataPaths = field(default_factory=DataPaths)
|
data_paths: DataPaths = field(default_factory=DataPaths)
|
||||||
agent_weights: AgentWeights = field(default_factory=AgentWeights)
|
agent_weights: AgentWeights = field(default_factory=AgentWeights)
|
||||||
|
force_refresh: bool = False
|
||||||
|
|
||||||
|
|
||||||
CONFIG = AppConfig()
|
CONFIG = AppConfig()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user