diff --git a/app/ingest/tushare.py b/app/ingest/tushare.py
index 1f9985f..2c897d5 100644
--- a/app/ingest/tushare.py
+++ b/app/ingest/tushare.py
@@ -37,6 +37,7 @@ from app.data.schema import initialize_database
 
 LOGGER = logging.getLogger(__name__)
 API_DEFAULT_LIMIT = 5000
+LOG_EXTRA = {"stage": "data_ingest"}
 
 
 def _fetch_paginated(endpoint: str, params: Dict[str, object], limit: int | None = None) -> pd.DataFrame:
@@ -45,18 +46,41 @@ def _fetch_paginated(endpoint: str, params: Dict[str, object], limit: int | None
     frames: List[pd.DataFrame] = []
     offset = 0
     clean_params = {k: v for k, v in params.items() if v is not None}
-    LOGGER.info("开始调用 TuShare 接口:%s,参数=%s,limit=%s", endpoint, clean_params, limit)
+    LOGGER.info(
+        "开始调用 TuShare 接口:%s,参数=%s,limit=%s",
+        endpoint,
+        clean_params,
+        limit,
+        extra=LOG_EXTRA,
+    )
     while True:
         call = getattr(client, endpoint)
         try:
             df = call(limit=limit, offset=offset, **clean_params)
         except Exception:  # noqa: BLE001
-            LOGGER.exception("TuShare 接口调用异常:endpoint=%s offset=%s params=%s", endpoint, offset, clean_params)
+            LOGGER.exception(
+                "TuShare 接口调用异常:endpoint=%s offset=%s params=%s",
+                endpoint,
+                offset,
+                clean_params,
+                extra=LOG_EXTRA,
+            )
             raise
         if df is None or df.empty:
-            LOGGER.info("TuShare 返回空数据:endpoint=%s offset=%s", endpoint, offset)
+            LOGGER.debug(
+                "TuShare 返回空数据:endpoint=%s offset=%s",
+                endpoint,
+                offset,
+                extra=LOG_EXTRA,
+            )
             break
-        LOGGER.info("TuShare 返回 %s 行:endpoint=%s offset=%s", len(df), endpoint, offset)
+        LOGGER.debug(
+            "TuShare 返回 %s 行:endpoint=%s offset=%s",
+            len(df),
+            endpoint,
+            offset,
+            extra=LOG_EXTRA,
+        )
         frames.append(df)
         if len(df) < limit:
             break
@@ -64,7 +88,12 @@ def _fetch_paginated(endpoint: str, params: Dict[str, object], limit: int | None
     if not frames:
         return pd.DataFrame()
     merged = pd.concat(frames, ignore_index=True)
-    LOGGER.info("TuShare 调用完成:endpoint=%s 总行数=%s", endpoint, len(merged))
+    LOGGER.info(
+        "TuShare 调用完成:endpoint=%s 总行数=%s",
+        endpoint,
+        len(merged),
+        extra=LOG_EXTRA,
+    )
     return merged
 
 
@@ -265,6 +294,13 @@ def _format_date(value: date) -> str:
     return value.strftime("%Y%m%d")
 
 
+def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]:
+    if df is None or df.empty:
+        return []
+    reindexed = df.reindex(columns=allowed_cols)
+    return reindexed.where(pd.notnull(reindexed), None).to_dict("records")
+
+
 def _load_trade_dates(start: date, end: date, exchange: str = "SSE") -> List[str]:
     start_str = _format_date(start)
     end_str = _format_date(end)
@@ -277,14 +313,15 @@ def _load_trade_dates(start: date, end: date, exchange: str = "SSE") -> List[str
     return [row["cal_date"] for row in rows]
 
 
-
-def _should_skip_range(table: str, date_col: str, start: date, end: date, ts_code: str | None = None) -> bool:
-    min_d, max_d = _existing_date_range(table, date_col, ts_code)
-    if min_d is None or max_d is None:
-        return False
-    start_str = _format_date(start)
-    end_str = _format_date(end)
-    return min_d <= start_str and max_d >= end_str
+def _daily_basic_exists(trade_date: str, ts_code: Optional[str] = None) -> bool:
+    query = "SELECT 1 FROM daily_basic WHERE trade_date = ?"
+    params: Tuple = (trade_date,)
+    if ts_code:
+        query += " AND ts_code = ?"
+        params = (trade_date, ts_code)
+    with db_session(read_only=True) as conn:
+        row = conn.execute(query, params).fetchone()
+    return row is not None
 
 
 def _should_skip_range(table: str, date_col: str, start: date, end: date, ts_code: str | None = None) -> bool:
@@ -296,13 +333,6 @@ def _should_skip_range(table: str, date_col: str, start: date, end: date, ts_cod
     return min_d <= start_str and max_d >= end_str
 
 
-def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]:
-    if df is None or df.empty:
-        return []
-    reindexed = df.reindex(columns=allowed_cols)
-    return reindexed.where(pd.notnull(reindexed), None).to_dict("records")
-
-
 def _range_stats(table: str, date_col: str, start_str: str, end_str: str) -> Dict[str, Optional[str]]:
     sql = (
         f"SELECT MIN({date_col}) AS min_d, MAX({date_col}) AS max_d, "
@@ -408,17 +438,58 @@ def fetch_daily_bars(job: FetchJob) -> Iterable[Dict]:
     return _df_to_records(df, _TABLE_COLUMNS["daily"])
 
 
-def fetch_daily_basic(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
+def fetch_daily_basic(
+    start: date,
+    end: date,
+    ts_code: Optional[str] = None,
+    skip_existing: bool = True,
+) -> Iterable[Dict]:
     client = _ensure_client()
     start_date = _format_date(start)
     end_date = _format_date(end)
-    LOGGER.info("拉取日线基础指标(%s-%s,股票:%s)", start_date, end_date, ts_code or "全部")
-    df = _fetch_paginated("daily_basic", {
-        "ts_code": ts_code,
-        "start_date": start_date,
-        "end_date": end_date,
-    })
-    return _df_to_records(df, _TABLE_COLUMNS["daily_basic"])
+    LOGGER.info(
+        "拉取日线基础指标(%s-%s,股票:%s)",
+        start_date,
+        end_date,
+        ts_code or "全部",
+        extra=LOG_EXTRA,
+    )
+
+    if ts_code:
+        df = _fetch_paginated(
+            "daily_basic",
+            {
+                "ts_code": ts_code,
+                "start_date": start_date,
+                "end_date": end_date,
+            },
+        )
+        return _df_to_records(df, _TABLE_COLUMNS["daily_basic"])
+
+    trade_dates = _load_trade_dates(start, end)
+    frames: List[pd.DataFrame] = []
+    for trade_date in trade_dates:
+        if skip_existing and _daily_basic_exists(trade_date):
+            LOGGER.info(
+                "日线基础指标已存在,跳过交易日 %s",
+                trade_date,
+                extra=LOG_EXTRA,
+            )
+            continue
+        LOGGER.debug(
+            "按交易日拉取日线基础指标:%s",
+            trade_date,
+            extra=LOG_EXTRA,
+        )
+        df = _fetch_paginated("daily_basic", {"trade_date": trade_date})
+        if not df.empty:
+            frames.append(df)
+
+    if not frames:
+        return []
+
+    merged = pd.concat(frames, ignore_index=True)
+    return _df_to_records(merged, _TABLE_COLUMNS["daily_basic"])
 
 
 def fetch_adj_factor(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
@@ -582,7 +653,10 @@ def ensure_data_coverage(
             continue
         LOGGER.info("拉取 %s 表数据(股票:%s)%s-%s", table, code, start_str, end_str)
        try:
-            rows = fetch_fn(start, end, ts_code=code)
+            kwargs = {"ts_code": code}
+            if fetch_fn is fetch_daily_basic:
+                kwargs["skip_existing"] = not force
+            rows = fetch_fn(start, end, **kwargs)
         except Exception:
             LOGGER.exception("TuShare 拉取失败:table=%s code=%s", table, code)
             raise
@@ -597,7 +671,10 @@ def ensure_data_coverage(
         return
     LOGGER.info("拉取 %s 表数据(全市场)%s-%s", table, start_str, end_str)
     try:
-        rows = fetch_fn(start, end)
+        kwargs = {}
+        if fetch_fn is fetch_daily_basic:
+            kwargs["skip_existing"] = not force
+        rows = fetch_fn(start, end, **kwargs)
     except Exception:
         LOGGER.exception("TuShare 拉取失败:table=%s code=全部", table)
         raise
diff --git a/app/utils/logging.py b/app/utils/logging.py
index 687a850..783697b 100644
--- a/app/utils/logging.py
+++ b/app/utils/logging.py
@@ -1,28 +1,112 @@
-"""Centralized logging configuration."""
+"""项目级日志配置模块。
+
+提供统一的日志初始化入口,支持同时输出到终端、文件以及 SQLite
+数据库中的 `run_log` 表。数据库写入便于在 UI 中或离线复盘时查看
+运行轨迹。
+"""
 
 from __future__ import annotations
 
 import logging
+import os
+import sqlite3
+import sys
+from datetime import datetime
+from logging import Handler, LogRecord
 from pathlib import Path
+from typing import Optional
 
 from .config import get_config
+from .db import db_session
+
+_LOGGER_NAME = "app.logging"
+_IS_CONFIGURED = False
 
 
-def configure_logging(level: int = logging.INFO) -> None:
-    """Setup root logger with file and console handlers."""
+class DatabaseLogHandler(Handler):
+    """将日志写入 SQLite `run_log` 表的自定义 Handler。"""
+
+    def emit(self, record: LogRecord) -> None:  # noqa: D401 - 标准 logging 接口
+        try:
+            message = self.format(record)
+            stage = getattr(record, "stage", None)
+            ts = datetime.utcnow().isoformat(timespec="microseconds") + "Z"
+            with db_session() as conn:
+                conn.execute(
+                    "INSERT INTO run_log (ts, stage, level, msg) VALUES (?, ?, ?, ?)",
+                    (ts, stage, record.levelname, message),
+                )
+        except sqlite3.OperationalError as exc:
+            # 表不存在时直接跳过,避免首次初始化阶段报错
+            if "no such table" not in str(exc).lower():
+                self.handleError(record)
+        except Exception:
+            self.handleError(record)
+
+
+def _build_formatter() -> logging.Formatter:
+    return logging.Formatter("%(asctime)s %(levelname)s %(name)s - %(message)s")
+
+
+def setup_logging(
+    *,
+    level: int = logging.INFO,
+    console_level: Optional[int] = None,
+    file_level: Optional[int] = None,
+    db_level: Optional[int] = None,
+) -> logging.Logger:
+    """配置根 logger。重复调用时将复用已存在的配置。"""
+
+    global _IS_CONFIGURED
+    if _IS_CONFIGURED:
+        return logging.getLogger()
+
+    env_level = os.getenv("LLM_QUANT_LOG_LEVEL")
+    if env_level is None:
+        level = logging.DEBUG
+    else:
+        try:
+            level = getattr(logging, env_level.upper())
+        except AttributeError:
+            logging.getLogger(_LOGGER_NAME).warning(
+                "非法的日志级别 %s,回退到 DEBUG", env_level
+            )
+            level = logging.DEBUG
 
     cfg = get_config()
-    log_dir = cfg.data_paths.root / "logs"
+    log_dir: Path = cfg.data_paths.root / "logs"
     log_dir.mkdir(parents=True, exist_ok=True)
     logfile = log_dir / "app.log"
-    logging.basicConfig(
-        level=level,
-        format="%(asctime)s %(levelname)s %(name)s - %(message)s",
-        handlers=[
-            logging.FileHandler(logfile, encoding="utf-8"),
-            logging.StreamHandler(),
-        ],
-    )
+    root = logging.getLogger()
+    root.setLevel(level)
+    root.handlers.clear()
+
+    formatter = _build_formatter()
+
+    console_handler = logging.StreamHandler(stream=sys.stdout)
+    console_handler.setLevel(console_level or level)
+    console_handler.setFormatter(formatter)
+    root.addHandler(console_handler)
+
+    file_handler = logging.FileHandler(logfile, encoding="utf-8")
+    file_handler.setLevel(file_level or level)
+    file_handler.setFormatter(formatter)
+    root.addHandler(file_handler)
+
+    db_handler = DatabaseLogHandler(level=db_level or level)
+    db_handler.setFormatter(formatter)
+    root.addHandler(db_handler)
+
+    _IS_CONFIGURED = True
+    return root
 
 
-configure_logging()
+def get_logger(name: Optional[str] = None) -> logging.Logger:
+    """返回指定名称的 logger,确保全局配置已就绪。"""
+
+    setup_logging()
+    return logging.getLogger(name)
+
+
+# 默认在模块导入时完成配置,适配现有调用方式。
+setup_logging()
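For context, a minimal usage sketch of how caller code is expected to combine the two halves of this patch, assuming the `run_log` table created by `app/data/schema` exists (the handler itself tolerates its absence). A module obtains a logger through `get_logger` and tags records with a `stage` via `extra`, mirroring `LOG_EXTRA` in `app/ingest/tushare.py`; `DatabaseLogHandler` then persists that stage alongside the formatted message. The literal messages below are illustrative, not taken from the patch.

    from app.utils.logging import get_logger

    LOGGER = get_logger(__name__)          # configures console/file/DB handlers on first call
    LOG_EXTRA = {"stage": "data_ingest"}   # becomes record.stage, stored in run_log.stage

    # INFO goes to stdout, to app.log under the configured logs directory,
    # and to the run_log table with stage='data_ingest'.
    LOGGER.info("ingest started", extra=LOG_EXTRA)

    # DEBUG records (e.g. per-page pagination progress) are only emitted and
    # persisted when the effective level, e.g. via LLM_QUANT_LOG_LEVEL, allows DEBUG.
    LOGGER.debug("page fetched, offset=%s", 5000, extra=LOG_EXTRA)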