This commit is contained in:
sam 2025-09-27 08:34:11 +08:00
parent 36322f66db
commit 5ef90b8de0
2 changed files with 204 additions and 43 deletions

View File

@ -37,6 +37,7 @@ from app.data.schema import initialize_database
LOGGER = logging.getLogger(__name__)
API_DEFAULT_LIMIT = 5000
LOG_EXTRA = {"stage": "data_ingest"}
def _fetch_paginated(endpoint: str, params: Dict[str, object], limit: int | None = None) -> pd.DataFrame:
@ -45,18 +46,41 @@ def _fetch_paginated(endpoint: str, params: Dict[str, object], limit: int | None
frames: List[pd.DataFrame] = []
offset = 0
clean_params = {k: v for k, v in params.items() if v is not None}
LOGGER.info("开始调用 TuShare 接口:%s,参数=%slimit=%s", endpoint, clean_params, limit)
LOGGER.info(
"开始调用 TuShare 接口:%s,参数=%slimit=%s",
endpoint,
clean_params,
limit,
extra=LOG_EXTRA,
)
while True:
call = getattr(client, endpoint)
try:
df = call(limit=limit, offset=offset, **clean_params)
except Exception: # noqa: BLE001
LOGGER.exception("TuShare 接口调用异常endpoint=%s offset=%s params=%s", endpoint, offset, clean_params)
LOGGER.exception(
"TuShare 接口调用异常endpoint=%s offset=%s params=%s",
endpoint,
offset,
clean_params,
extra=LOG_EXTRA,
)
raise
if df is None or df.empty:
LOGGER.info("TuShare 返回空数据endpoint=%s offset=%s", endpoint, offset)
LOGGER.debug(
"TuShare 返回空数据endpoint=%s offset=%s",
endpoint,
offset,
extra=LOG_EXTRA,
)
break
LOGGER.info("TuShare 返回 %sendpoint=%s offset=%s", len(df), endpoint, offset)
LOGGER.debug(
"TuShare 返回 %sendpoint=%s offset=%s",
len(df),
endpoint,
offset,
extra=LOG_EXTRA,
)
frames.append(df)
if len(df) < limit:
break
@ -64,7 +88,12 @@ def _fetch_paginated(endpoint: str, params: Dict[str, object], limit: int | None
if not frames:
return pd.DataFrame()
merged = pd.concat(frames, ignore_index=True)
LOGGER.info("TuShare 调用完成endpoint=%s 总行数=%s", endpoint, len(merged))
LOGGER.info(
"TuShare 调用完成endpoint=%s 总行数=%s",
endpoint,
len(merged),
extra=LOG_EXTRA,
)
return merged
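# The loop above implements offset-based pagination: requests are repeated with an
# increasing offset until an empty page, or one shorter than `limit`, signals the end.
# Below is a minimal, hedged sketch of the same pattern, independent of TuShare;
# `paginate` and `fetch_page` are illustrative names (not part of this module), and the
# sketch reuses the module's pandas / typing imports.
def paginate(fetch_page, page_size: int = 5000) -> pd.DataFrame:
    frames: List[pd.DataFrame] = []
    offset = 0
    while True:
        page = fetch_page(offset, page_size)
        if page is None or page.empty:
            break
        frames.append(page)
        if len(page) < page_size:  # a short page means this was the last one
            break
        offset += page_size
    return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()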
@ -265,6 +294,13 @@ def _format_date(value: date) -> str:
return value.strftime("%Y%m%d")
def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]:
if df is None or df.empty:
return []
reindexed = df.reindex(columns=allowed_cols)
return reindexed.where(pd.notnull(reindexed), None).to_dict("records")
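# Worked example (illustrative values): columns outside allowed_cols are dropped,
# missing columns are added, and NaN/missing values become None so the rows can be
# bound directly as SQL parameters:
#     _df_to_records(
#         pd.DataFrame({"ts_code": ["000001.SZ"], "turnover_rate": [float("nan")], "extra": [1]}),
#         ["ts_code", "close", "turnover_rate"],
#     )
#     -> [{"ts_code": "000001.SZ", "close": None, "turnover_rate": None}]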
def _load_trade_dates(start: date, end: date, exchange: str = "SSE") -> List[str]:
start_str = _format_date(start)
end_str = _format_date(end)
@ -277,14 +313,15 @@ def _load_trade_dates(start: date, end: date, exchange: str = "SSE") -> List[str
return [row["cal_date"] for row in rows]
def _should_skip_range(table: str, date_col: str, start: date, end: date, ts_code: str | None = None) -> bool:
min_d, max_d = _existing_date_range(table, date_col, ts_code)
if min_d is None or max_d is None:
return False
start_str = _format_date(start)
end_str = _format_date(end)
return min_d <= start_str and max_d >= end_str
def _daily_basic_exists(trade_date: str, ts_code: Optional[str] = None) -> bool:
query = "SELECT 1 FROM daily_basic WHERE trade_date = ?"
params: Tuple = (trade_date,)
if ts_code:
query += " AND ts_code = ?"
params = (trade_date, ts_code)
with db_session(read_only=True) as conn:
row = conn.execute(query, params).fetchone()
return row is not None
def _should_skip_range(table: str, date_col: str, start: date, end: date, ts_code: str | None = None) -> bool:
@ -296,13 +333,6 @@ def _should_skip_range(table: str, date_col: str, start: date, end: date, ts_cod
return min_d <= start_str and max_d >= end_str
def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]:
if df is None or df.empty:
return []
reindexed = df.reindex(columns=allowed_cols)
return reindexed.where(pd.notnull(reindexed), None).to_dict("records")
def _range_stats(table: str, date_col: str, start_str: str, end_str: str) -> Dict[str, Optional[str]]:
sql = (
f"SELECT MIN({date_col}) AS min_d, MAX({date_col}) AS max_d, "
@ -408,17 +438,58 @@ def fetch_daily_bars(job: FetchJob) -> Iterable[Dict]:
return _df_to_records(df, _TABLE_COLUMNS["daily"])
def fetch_daily_basic(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
def fetch_daily_basic(
start: date,
end: date,
ts_code: Optional[str] = None,
skip_existing: bool = True,
) -> Iterable[Dict]:
client = _ensure_client()
start_date = _format_date(start)
end_date = _format_date(end)
LOGGER.info("拉取日线基础指标(%s-%s,股票:%s", start_date, end_date, ts_code or "全部")
df = _fetch_paginated("daily_basic", {
"ts_code": ts_code,
"start_date": start_date,
"end_date": end_date,
})
return _df_to_records(df, _TABLE_COLUMNS["daily_basic"])
LOGGER.info(
"拉取日线基础指标(%s-%s,股票:%s",
start_date,
end_date,
ts_code or "全部",
extra=LOG_EXTRA,
)
if ts_code:
df = _fetch_paginated(
"daily_basic",
{
"ts_code": ts_code,
"start_date": start_date,
"end_date": end_date,
},
)
return _df_to_records(df, _TABLE_COLUMNS["daily_basic"])
trade_dates = _load_trade_dates(start, end)
frames: List[pd.DataFrame] = []
for trade_date in trade_dates:
if skip_existing and _daily_basic_exists(trade_date):
LOGGER.info(
"日线基础指标已存在,跳过交易日 %s",
trade_date,
extra=LOG_EXTRA,
)
continue
LOGGER.debug(
"按交易日拉取日线基础指标:%s",
trade_date,
extra=LOG_EXTRA,
)
df = _fetch_paginated("daily_basic", {"trade_date": trade_date})
if not df.empty:
frames.append(df)
if not frames:
return []
merged = pd.concat(frames, ignore_index=True)
return _df_to_records(merged, _TABLE_COLUMNS["daily_basic"])
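# Hedged usage sketch (dates are arbitrary examples; persistence is out of scope here):
# with skip_existing=True, trade dates already present in daily_basic are skipped, so
# re-running a backfill only pulls the missing days.
#     from datetime import date
#     rows = fetch_daily_basic(date(2024, 1, 2), date(2024, 1, 31), skip_existing=True)
#     # each row is a dict keyed by _TABLE_COLUMNS["daily_basic"], ready to be upserted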
def fetch_adj_factor(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
@ -582,7 +653,10 @@ def ensure_data_coverage(
continue
LOGGER.info("拉取 %s 表数据(股票:%s%s-%s", table, code, start_str, end_str)
try:
rows = fetch_fn(start, end, ts_code=code)
kwargs = {"ts_code": code}
if fetch_fn is fetch_daily_basic:
kwargs["skip_existing"] = not force
rows = fetch_fn(start, end, **kwargs)
except Exception:
LOGGER.exception("TuShare 拉取失败table=%s code=%s", table, code)
raise
@ -597,7 +671,10 @@ def ensure_data_coverage(
return
LOGGER.info("拉取 %s 表数据(全市场)%s-%s", table, start_str, end_str)
try:
rows = fetch_fn(start, end)
kwargs = {}
if fetch_fn is fetch_daily_basic:
kwargs["skip_existing"] = not force
rows = fetch_fn(start, end, **kwargs)
except Exception:
LOGGER.exception("TuShare 拉取失败table=%s code=全部", table)
raise
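# The skip_existing branch above is duplicated between the per-stock and the full-market
# paths of ensure_data_coverage. A hedged sketch of a shared helper (hypothetical, not
# part of this commit) that keeps the dispatch in one place:
def _call_fetch(fetch_fn, start: date, end: date, *, code: Optional[str] = None, force: bool = False):
    kwargs: Dict[str, object] = {}
    if code is not None:
        kwargs["ts_code"] = code
    if fetch_fn is fetch_daily_basic:
        kwargs["skip_existing"] = not force
    return fetch_fn(start, end, **kwargs)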

View File

@ -1,28 +1,112 @@
"""Centralized logging configuration."""
"""项目级日志配置模块。
提供统一的日志初始化入口支持同时输出到终端文件以及 SQLite
数据库中的 `run_log` 数据库写入便于在 UI 中或离线复盘时查看
运行轨迹
"""
from __future__ import annotations
import logging
import os
import sqlite3
import sys
from datetime import datetime
from logging import Handler, LogRecord
from pathlib import Path
from typing import Optional
from .config import get_config
from .db import db_session
_LOGGER_NAME = "app.logging"
_IS_CONFIGURED = False
def configure_logging(level: int = logging.INFO) -> None:
"""Setup root logger with file and console handlers."""
class DatabaseLogHandler(Handler):
"""将日志写入 SQLite `run_log` 表的自定义 Handler。"""
def emit(self, record: LogRecord) -> None:  # noqa: D401 - standard logging interface
try:
message = self.format(record)
stage = getattr(record, "stage", None)
ts = datetime.utcnow().isoformat(timespec="microseconds") + "Z"
with db_session() as conn:
conn.execute(
"INSERT INTO run_log (ts, stage, level, msg) VALUES (?, ?, ?, ?)",
(ts, stage, record.levelname, message),
)
except sqlite3.OperationalError as exc:
# Skip when the table does not exist yet, so first-time initialization does not fail
if "no such table" not in str(exc).lower():
self.handleError(record)
except Exception:
self.handleError(record)
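# Assumed shape of the target table, inferred from the INSERT above; the actual DDL
# is expected to live with the rest of the schema (e.g. app.data.schema), not here:
#     CREATE TABLE IF NOT EXISTS run_log (
#         ts    TEXT,
#         stage TEXT,
#         level TEXT,
#         msg   TEXT
#     );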
def _build_formatter() -> logging.Formatter:
return logging.Formatter("%(asctime)s %(levelname)s %(name)s - %(message)s")
def setup_logging(
*,
level: int = logging.INFO,
console_level: Optional[int] = None,
file_level: Optional[int] = None,
db_level: Optional[int] = None,
) -> logging.Logger:
"""配置根 logger。重复调用时将复用已存在的配置。"""
global _IS_CONFIGURED
if _IS_CONFIGURED:
return logging.getLogger()
env_level = os.getenv("LLM_QUANT_LOG_LEVEL")
if env_level is None:
level = logging.DEBUG
else:
try:
level = getattr(logging, env_level.upper())
except AttributeError:
logging.getLogger(_LOGGER_NAME).warning(
"非法的日志级别 %s,回退到 DEBUG", env_level
)
level = logging.DEBUG
cfg = get_config()
log_dir = cfg.data_paths.root / "logs"
log_dir: Path = cfg.data_paths.root / "logs"
log_dir.mkdir(parents=True, exist_ok=True)
logfile = log_dir / "app.log"
logging.basicConfig(
level=level,
format="%(asctime)s %(levelname)s %(name)s - %(message)s",
handlers=[
logging.FileHandler(logfile, encoding="utf-8"),
logging.StreamHandler(),
],
)
root = logging.getLogger()
root.setLevel(level)
root.handlers.clear()
formatter = _build_formatter()
console_handler = logging.StreamHandler(stream=sys.stdout)
console_handler.setLevel(console_level or level)
console_handler.setFormatter(formatter)
root.addHandler(console_handler)
file_handler = logging.FileHandler(logfile, encoding="utf-8")
file_handler.setLevel(file_level or level)
file_handler.setFormatter(formatter)
root.addHandler(file_handler)
db_handler = DatabaseLogHandler(level=db_level or level)
db_handler.setFormatter(formatter)
root.addHandler(db_handler)
_IS_CONFIGURED = True
return root
configure_logging()
def get_logger(name: Optional[str] = None) -> logging.Logger:
"""返回指定名称的 logger确保全局配置已就绪。"""
setup_logging()
return logging.getLogger(name)
# Configure at import time by default, to stay compatible with existing call sites.
setup_logging()
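# Hedged usage sketch (the import path is assumed from _LOGGER_NAME and the package
# layout; adjust to wherever this module actually lives):
#     from app.logging import get_logger
#
#     logger = get_logger(__name__)
#     logger.info("daily_basic backfill finished", extra={"stage": "data_ingest"})
#     # the `stage` extra is picked up by DatabaseLogHandler and stored in run_log.stage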