update

parent 36322f66db
commit 5ef90b8de0
@@ -37,6 +37,7 @@ from app.data.schema import initialize_database
 LOGGER = logging.getLogger(__name__)
 
 API_DEFAULT_LIMIT = 5000
+LOG_EXTRA = {"stage": "data_ingest"}
 
 
 def _fetch_paginated(endpoint: str, params: Dict[str, object], limit: int | None = None) -> pd.DataFrame:
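Every logging call this commit touches now passes extra=LOG_EXTRA. The extra mapping is stock logging behavior: its keys are copied onto the LogRecord as attributes, which is how the DatabaseLogHandler added later in this commit reads the value back via getattr(record, "stage", None). A minimal self-contained sketch of that mechanism (the StagePrinter handler is illustrative only):

import logging

LOG_EXTRA = {"stage": "data_ingest"}

class StagePrinter(logging.Handler):
    # Keys passed via `extra` surface as attributes on the record.
    def emit(self, record: logging.LogRecord) -> None:
        print(getattr(record, "stage", None), record.getMessage())

logger = logging.getLogger("demo")
logger.addHandler(StagePrinter())
logger.setLevel(logging.INFO)
logger.info("hello", extra=LOG_EXTRA)  # -> data_ingest hello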
@@ -45,18 +46,41 @@ def _fetch_paginated(endpoint: str, params: Dict[str, object], limit: int | None = None) -> pd.DataFrame:
     frames: List[pd.DataFrame] = []
     offset = 0
     clean_params = {k: v for k, v in params.items() if v is not None}
-    LOGGER.info("Calling TuShare endpoint %s, params=%s, limit=%s", endpoint, clean_params, limit)
+    LOGGER.info(
+        "Calling TuShare endpoint %s, params=%s, limit=%s",
+        endpoint,
+        clean_params,
+        limit,
+        extra=LOG_EXTRA,
+    )
     while True:
         call = getattr(client, endpoint)
         try:
             df = call(limit=limit, offset=offset, **clean_params)
         except Exception:  # noqa: BLE001
-            LOGGER.exception("TuShare API call failed: endpoint=%s offset=%s params=%s", endpoint, offset, clean_params)
+            LOGGER.exception(
+                "TuShare API call failed: endpoint=%s offset=%s params=%s",
+                endpoint,
+                offset,
+                clean_params,
+                extra=LOG_EXTRA,
+            )
             raise
         if df is None or df.empty:
-            LOGGER.info("TuShare returned empty data: endpoint=%s offset=%s", endpoint, offset)
+            LOGGER.debug(
+                "TuShare returned empty data: endpoint=%s offset=%s",
+                endpoint,
+                offset,
+                extra=LOG_EXTRA,
+            )
             break
-        LOGGER.info("TuShare returned %s rows: endpoint=%s offset=%s", len(df), endpoint, offset)
+        LOGGER.debug(
+            "TuShare returned %s rows: endpoint=%s offset=%s",
+            len(df),
+            endpoint,
+            offset,
+            extra=LOG_EXTRA,
+        )
         frames.append(df)
         if len(df) < limit:
             break
@@ -64,7 +88,12 @@ def _fetch_paginated(endpoint: str, params: Dict[str, object], limit: int | None = None) -> pd.DataFrame:
     if not frames:
         return pd.DataFrame()
     merged = pd.concat(frames, ignore_index=True)
-    LOGGER.info("TuShare call finished: endpoint=%s total_rows=%s", endpoint, len(merged))
+    LOGGER.info(
+        "TuShare call finished: endpoint=%s total_rows=%s",
+        endpoint,
+        len(merged),
+        extra=LOG_EXTRA,
+    )
     return merged
 
 
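For reference, _fetch_paginated is plain offset pagination over the TuShare client. Below is a standalone sketch of the same control flow with a stub page source; the `limit or API_DEFAULT_LIMIT` guard and the `offset += limit` step are assumptions (the increment sits on the one line the diff elides between the two hunks, and the real loop's `len(df) < limit` comparison presumes a concrete limit):

from typing import Callable, List

import pandas as pd

API_DEFAULT_LIMIT = 5000

def fetch_all(fetch_page: Callable[..., pd.DataFrame], limit: int | None = None) -> pd.DataFrame:
    limit = limit or API_DEFAULT_LIMIT  # the comparison below needs an int
    frames: List[pd.DataFrame] = []
    offset = 0
    while True:
        df = fetch_page(limit=limit, offset=offset)
        if df is None or df.empty:
            break
        frames.append(df)
        if len(df) < limit:  # a short page means the source is exhausted
            break
        offset += limit
    return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()

def fake_page(limit: int, offset: int) -> pd.DataFrame:
    # Stub standing in for a TuShare endpoint: serves 12 rows in pages.
    return pd.DataFrame({"n": list(range(12))[offset:offset + limit]})

assert len(fetch_all(fake_page, limit=5)) == 12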
@@ -265,6 +294,13 @@ def _format_date(value: date) -> str:
     return value.strftime("%Y%m%d")
 
 
+def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]:
+    if df is None or df.empty:
+        return []
+    reindexed = df.reindex(columns=allowed_cols)
+    return reindexed.where(pd.notnull(reindexed), None).to_dict("records")
+
+
 def _load_trade_dates(start: date, end: date, exchange: str = "SSE") -> List[str]:
     start_str = _format_date(start)
     end_str = _format_date(end)
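_df_to_records (moved up next to _format_date by this commit) trims a frame to the table's column list and swaps pandas NaN for None so the SQLite driver stores NULL. A quick self-contained illustration:

import pandas as pd

df = pd.DataFrame({"ts_code": ["000001.SZ"], "close": [float("nan")]})
reindexed = df.reindex(columns=["ts_code", "close", "pe"])  # missing "pe" becomes NaN
records = reindexed.where(pd.notnull(reindexed), None).to_dict("records")
print(records)  # [{'ts_code': '000001.SZ', 'close': None, 'pe': None}]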
@@ -277,14 +313,15 @@ def _load_trade_dates(start: date, end: date, exchange: str = "SSE") -> List[str]:
     return [row["cal_date"] for row in rows]
 
 
-def _should_skip_range(table: str, date_col: str, start: date, end: date, ts_code: str | None = None) -> bool:
-    min_d, max_d = _existing_date_range(table, date_col, ts_code)
-    if min_d is None or max_d is None:
-        return False
-    start_str = _format_date(start)
-    end_str = _format_date(end)
-    return min_d <= start_str and max_d >= end_str
+def _daily_basic_exists(trade_date: str, ts_code: Optional[str] = None) -> bool:
+    query = "SELECT 1 FROM daily_basic WHERE trade_date = ?"
+    params: Tuple = (trade_date,)
+    if ts_code:
+        query += " AND ts_code = ?"
+        params = (trade_date, ts_code)
+    with db_session(read_only=True) as conn:
+        row = conn.execute(query, params).fetchone()
+    return row is not None
 
 
 def _should_skip_range(table: str, date_col: str, start: date, end: date, ts_code: str | None = None) -> bool:
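_daily_basic_exists is the idempotency probe behind the new skip_existing flag: a parameterized point lookup that lets the loader short-circuit a whole trade date. An equivalent standalone sketch against plain sqlite3 (the in-memory table stands in for the project's db_session):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE daily_basic (ts_code TEXT, trade_date TEXT)")
conn.execute("INSERT INTO daily_basic VALUES ('000001.SZ', '20240102')")

def daily_basic_exists(trade_date: str, ts_code: str | None = None) -> bool:
    query = "SELECT 1 FROM daily_basic WHERE trade_date = ?"
    params: tuple = (trade_date,)
    if ts_code:
        query += " AND ts_code = ?"
        params = (trade_date, ts_code)
    return conn.execute(query, params).fetchone() is not None

assert daily_basic_exists("20240102")
assert not daily_basic_exists("20240103")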
@@ -296,13 +333,6 @@ def _should_skip_range(table: str, date_col: str, start: date, end: date, ts_code: str | None = None) -> bool:
     return min_d <= start_str and max_d >= end_str
 
 
-def _df_to_records(df: pd.DataFrame, allowed_cols: List[str]) -> List[Dict]:
-    if df is None or df.empty:
-        return []
-    reindexed = df.reindex(columns=allowed_cols)
-    return reindexed.where(pd.notnull(reindexed), None).to_dict("records")
-
-
 def _range_stats(table: str, date_col: str, start_str: str, end_str: str) -> Dict[str, Optional[str]]:
     sql = (
         f"SELECT MIN({date_col}) AS min_d, MAX({date_col}) AS max_d, "
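A side note on _should_skip_range: it compares YYYYMMDD strings with <=/>=, which is sound because fixed-width zero-padded date strings order lexicographically exactly as the dates do:

from datetime import date

def fmt(d: date) -> str:
    return d.strftime("%Y%m%d")

a, b = date(2024, 1, 2), date(2024, 11, 5)
assert (fmt(a) <= fmt(b)) == (a <= b)  # lexicographic order == chronological order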
@@ -408,18 +438,59 @@ def fetch_daily_bars(job: FetchJob) -> Iterable[Dict]:
     return _df_to_records(df, _TABLE_COLUMNS["daily"])
 
 
-def fetch_daily_basic(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
+def fetch_daily_basic(
+    start: date,
+    end: date,
+    ts_code: Optional[str] = None,
+    skip_existing: bool = True,
+) -> Iterable[Dict]:
     client = _ensure_client()
     start_date = _format_date(start)
     end_date = _format_date(end)
-    LOGGER.info("Fetching daily basic indicators (%s-%s, stock: %s)", start_date, end_date, ts_code or "all")
-    df = _fetch_paginated("daily_basic", {
-        "ts_code": ts_code,
-        "start_date": start_date,
-        "end_date": end_date,
-    })
-    return _df_to_records(df, _TABLE_COLUMNS["daily_basic"])
+    LOGGER.info(
+        "Fetching daily basic indicators (%s-%s, stock: %s)",
+        start_date,
+        end_date,
+        ts_code or "all",
+        extra=LOG_EXTRA,
+    )
+
+    if ts_code:
+        df = _fetch_paginated(
+            "daily_basic",
+            {
+                "ts_code": ts_code,
+                "start_date": start_date,
+                "end_date": end_date,
+            },
+        )
+        return _df_to_records(df, _TABLE_COLUMNS["daily_basic"])
+
+    trade_dates = _load_trade_dates(start, end)
+    frames: List[pd.DataFrame] = []
+    for trade_date in trade_dates:
+        if skip_existing and _daily_basic_exists(trade_date):
+            LOGGER.info(
+                "Daily basic data already present, skipping trade date %s",
+                trade_date,
+                extra=LOG_EXTRA,
+            )
+            continue
+        LOGGER.debug(
+            "Fetching daily basic indicators for trade date %s",
+            trade_date,
+            extra=LOG_EXTRA,
+        )
+        df = _fetch_paginated("daily_basic", {"trade_date": trade_date})
+        if not df.empty:
+            frames.append(df)
+
+    if not frames:
+        return []
+
+    merged = pd.concat(frames, ignore_index=True)
+    return _df_to_records(merged, _TABLE_COLUMNS["daily_basic"])
 
 
 def fetch_adj_factor(start: date, end: date, ts_code: Optional[str] = None) -> Iterable[Dict]:
     client = _ensure_client()
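Usage sketch for the new signature (dates are illustrative). With ts_code set, the function takes the ranged endpoint and skip_existing never comes into play; without it, each trade date already present in daily_basic is skipped unless the caller opts out:

from datetime import date

# Per-stock pull: one ranged call, skip_existing is irrelevant here.
rows = fetch_daily_basic(date(2024, 1, 1), date(2024, 1, 31), ts_code="000001.SZ")

# Full-market backfill: walks the trade calendar, skipping ingested dates.
rows = fetch_daily_basic(date(2024, 1, 1), date(2024, 1, 31))

# Forced re-pull of every trade date in the window.
rows = fetch_daily_basic(date(2024, 1, 1), date(2024, 1, 31), skip_existing=False)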
@@ -582,7 +653,10 @@ def ensure_data_coverage(
             continue
         LOGGER.info("Fetching data for table %s (stock: %s) %s-%s", table, code, start_str, end_str)
         try:
-            rows = fetch_fn(start, end, ts_code=code)
+            kwargs = {"ts_code": code}
+            if fetch_fn is fetch_daily_basic:
+                kwargs["skip_existing"] = not force
+            rows = fetch_fn(start, end, **kwargs)
         except Exception:
             LOGGER.exception("TuShare fetch failed: table=%s code=%s", table, code)
             raise
@@ -597,7 +671,10 @@ def ensure_data_coverage(
         return
     LOGGER.info("Fetching data for table %s (all stocks) %s-%s", table, start_str, end_str)
     try:
-        rows = fetch_fn(start, end)
+        kwargs = {}
+        if fetch_fn is fetch_daily_basic:
+            kwargs["skip_existing"] = not force
+        rows = fetch_fn(start, end, **kwargs)
     except Exception:
         LOGGER.exception("TuShare fetch failed: table=%s code=all", table)
         raise

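Both call sites gate the new argument on function identity (fetch_fn is fetch_daily_basic) because the other fetchers do not accept skip_existing. A more general alternative (a sketch, not what this commit does) would inspect the callee's signature:

import inspect

def call_fetch(fetch_fn, start, end, force=False, **kwargs):
    # Forward skip_existing only to fetchers whose signature declares it.
    if "skip_existing" in inspect.signature(fetch_fn).parameters:
        kwargs["skip_existing"] = not force
    return fetch_fn(start, end, **kwargs)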
@@ -1,28 +1,112 @@
-"""Centralized logging configuration."""
+"""Project-wide logging configuration.
+
+Provides a single logging-initialization entry point that can write
+simultaneously to the terminal, to a file, and to the `run_log` table
+in SQLite. The database sink makes run traces easy to inspect from the
+UI or during offline review.
+"""
 from __future__ import annotations
 
 import logging
+import os
+import sqlite3
+import sys
+from datetime import datetime
+from logging import Handler, LogRecord
 from pathlib import Path
+from typing import Optional
 
 from .config import get_config
+from .db import db_session
 
+_LOGGER_NAME = "app.logging"
+_IS_CONFIGURED = False
 
 
-def configure_logging(level: int = logging.INFO) -> None:
-    """Setup root logger with file and console handlers."""
+class DatabaseLogHandler(Handler):
+    """Custom handler that writes log records into the SQLite `run_log` table."""
+
+    def emit(self, record: LogRecord) -> None:  # noqa: D401 - standard logging interface
+        try:
+            message = self.format(record)
+            stage = getattr(record, "stage", None)
+            ts = datetime.utcnow().isoformat(timespec="microseconds") + "Z"
+            with db_session() as conn:
+                conn.execute(
+                    "INSERT INTO run_log (ts, stage, level, msg) VALUES (?, ?, ?, ?)",
+                    (ts, stage, record.levelname, message),
+                )
+        except sqlite3.OperationalError as exc:
+            # Skip when the table does not exist yet, so first-time
+            # initialization does not error out.
+            if "no such table" not in str(exc).lower():
+                self.handleError(record)
+        except Exception:
+            self.handleError(record)
+
+
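A self-contained sketch of the handler's pattern against an in-memory database (the run_log schema here is inferred from the INSERT statement above; the project's real table lives behind db_session):

import logging
import sqlite3
from datetime import datetime

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE run_log (ts TEXT, stage TEXT, level TEXT, msg TEXT)")

class DbHandler(logging.Handler):
    def emit(self, record: logging.LogRecord) -> None:
        try:
            ts = datetime.utcnow().isoformat(timespec="microseconds") + "Z"
            conn.execute(
                "INSERT INTO run_log (ts, stage, level, msg) VALUES (?, ?, ?, ?)",
                (ts, getattr(record, "stage", None), record.levelname, self.format(record)),
            )
        except Exception:
            self.handleError(record)

log = logging.getLogger("db-demo")
log.addHandler(DbHandler())
log.warning("disk almost full", extra={"stage": "ingest"})
print(conn.execute("SELECT stage, level, msg FROM run_log").fetchall())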
+def _build_formatter() -> logging.Formatter:
+    return logging.Formatter("%(asctime)s %(levelname)s %(name)s - %(message)s")
+
+
+def setup_logging(
+    *,
+    level: int = logging.INFO,
+    console_level: Optional[int] = None,
+    file_level: Optional[int] = None,
+    db_level: Optional[int] = None,
+) -> logging.Logger:
+    """Configure the root logger; repeated calls reuse the existing configuration."""
+
+    global _IS_CONFIGURED
+    if _IS_CONFIGURED:
+        return logging.getLogger()
+
+    env_level = os.getenv("LLM_QUANT_LOG_LEVEL")
+    if env_level is None:
+        level = logging.DEBUG
+    else:
+        try:
+            level = getattr(logging, env_level.upper())
+        except AttributeError:
+            logging.getLogger(_LOGGER_NAME).warning(
+                "Invalid log level %s, falling back to DEBUG", env_level
+            )
+            level = logging.DEBUG
+
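Note that the environment variable always wins here: with LLM_QUANT_LOG_LEVEL unset, the level argument is replaced by DEBUG. The lookup itself is plain attribute access on the logging module; a slightly stricter standalone resolver (an alternative sketch, not this commit's code) that also rejects non-level attributes:

import logging
import os

def resolve_level(default: int = logging.DEBUG) -> int:
    name = os.getenv("LLM_QUANT_LOG_LEVEL")
    if name is None:
        return default
    value = getattr(logging, name.upper(), None)  # e.g. "info" -> logging.INFO
    return value if isinstance(value, int) else default

os.environ["LLM_QUANT_LOG_LEVEL"] = "warning"
assert resolve_level() == logging.WARNING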
     cfg = get_config()
-    log_dir = cfg.data_paths.root / "logs"
+    log_dir: Path = cfg.data_paths.root / "logs"
     log_dir.mkdir(parents=True, exist_ok=True)
     logfile = log_dir / "app.log"
 
-    logging.basicConfig(
-        level=level,
-        format="%(asctime)s %(levelname)s %(name)s - %(message)s",
-        handlers=[
-            logging.FileHandler(logfile, encoding="utf-8"),
-            logging.StreamHandler(),
-        ],
-    )
+    root = logging.getLogger()
+    root.setLevel(level)
+    root.handlers.clear()
+
+    formatter = _build_formatter()
+
+    console_handler = logging.StreamHandler(stream=sys.stdout)
+    console_handler.setLevel(console_level or level)
+    console_handler.setFormatter(formatter)
+    root.addHandler(console_handler)
+
+    file_handler = logging.FileHandler(logfile, encoding="utf-8")
+    file_handler.setLevel(file_level or level)
+    file_handler.setFormatter(formatter)
+    root.addHandler(file_handler)
+
+    db_handler = DatabaseLogHandler(level=db_level or level)
+    db_handler.setFormatter(formatter)
+    root.addHandler(db_handler)
+
+    _IS_CONFIGURED = True
+    return root
 
 
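A usage sketch for the per-handler levels (keeping in mind the env override above may replace the base level):

import logging

root = setup_logging(
    level=logging.DEBUG,            # default for root and all handlers
    console_level=logging.WARNING,  # keep the terminal terse
    file_level=logging.DEBUG,       # full detail in app.log
    db_level=logging.INFO,          # run_log keeps INFO and above
)
root.info("lands in app.log and run_log, but not on the console")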
-configure_logging()
+def get_logger(name: Optional[str] = None) -> logging.Logger:
+    """Return the logger with the given name, ensuring global configuration has run."""
+
+    setup_logging()
+    return logging.getLogger(name)
+
+
+# Configure at module import time by default, to match existing call sites.
+setup_logging()
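Call-site sketch: modules keep the familiar one-liner, and get_logger guarantees the handlers exist before the first record is emitted:

LOGGER = get_logger(__name__)
LOGGER.info("ingest started", extra={"stage": "data_ingest"})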