llm-quant/app/features/factor_audit.py

233 lines
7.3 KiB
Python

"""Utilities for auditing persisted factor values against live formulas."""
from __future__ import annotations
import math
from dataclasses import dataclass, field
from datetime import date
from typing import Dict, List, Mapping, Optional, Sequence
from app.features.factors import (
DEFAULT_FACTORS,
FactorResult,
FactorSpec,
compute_factors,
lookup_factor_spec,
)
from app.utils.db import db_session
from app.utils.logging import get_logger
# Module-level logger. LOG_EXTRA is passed as ``extra`` on every log call in
# this module so that each record is tagged with the "factor_audit" stage.
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "factor_audit"}
@dataclass
class FactorAuditIssue:
    """One disagreement between a persisted factor value and its recomputation.

    Attributes:
        ts_code: Security code the disagreement was found for.
        factor: Name of the factor whose values disagree.
        stored: Value taken from the persisted record, if any.
        recomputed: Value produced by the fresh computation, if any.
        difference: Absolute gap between the two values; ``None`` when no
            numeric gap could be reported.
    """

    # Identification of the offending cell.
    ts_code: str
    factor: str

    # The two sides of the comparison and their gap.
    stored: Optional[float]
    recomputed: Optional[float]
    difference: Optional[float]
@dataclass
class FactorAuditSummary:
    """Outcome of a single factor audit run.

    Attributes:
        trade_date: Trading day that was audited.
        tolerance: Absolute difference threshold used for the comparison.
        factor_names: Names of the factors included in the audit.
        total_persisted: Number of securities with a persisted record.
        total_recomputed: Number of securities with a recomputed result.
        evaluated: Number of securities compared on both sides.
        mismatched: Number of factor values found to disagree.
        missing_persisted: Recomputed securities lacking a persisted record.
        missing_recomputed: Persisted securities lacking a recomputed result.
        missing_columns: Audited factor names with no column in storage.
        issues: Detailed mismatch records (may be fewer than ``mismatched``).
    """

    # Audit parameters.
    trade_date: date
    tolerance: float
    factor_names: List[str]

    # Counters.
    total_persisted: int
    total_recomputed: int
    evaluated: int
    mismatched: int
    missing_persisted: int
    missing_recomputed: int

    # Details; each instance gets its own fresh list by default.
    missing_columns: List[str] = field(default_factory=list)
    issues: List[FactorAuditIssue] = field(default_factory=list)
def audit_factors(
    trade_date: date,
    *,
    factors: Optional[Sequence[str | FactorSpec]] = None,
    tolerance: float = 1e-6,
    max_issues: int = 50,
) -> FactorAuditSummary:
    """Recompute factor values and compare them with persisted records.

    Factors whose column does not exist in the ``factors`` table are reported
    through ``missing_columns`` only; they are not compared per security, so
    they neither inflate ``mismatched`` nor crowd out real mismatches in
    ``issues``.

    Args:
        trade_date: Trading day to audit.
        factors: Factor names or ``FactorSpec`` objects to audit; the default
            factor set is used when omitted.
        tolerance: Largest absolute difference still treated as a match.
        max_issues: Upper bound on the number of detailed issues returned.
            Mismatches beyond the bound are still counted.

    Returns:
        FactorAuditSummary: Summary of the audit run.
    """
    specs = _resolve_factor_specs(factors)
    factor_names = [spec.name for spec in specs]
    trade_date_str = trade_date.strftime("%Y%m%d")

    persisted_map, missing_columns = _load_persisted_factors(trade_date_str, factor_names)
    recomputed_results = compute_factors(
        trade_date,
        specs,
        persist=False,
    )
    recomputed_map = {result.ts_code: result.values for result in recomputed_results}

    # Only columns that exist in storage can be compared meaningfully; the
    # absent ones are already surfaced via ``missing_columns``.
    absent_columns = set(missing_columns)
    comparable = [name for name in factor_names if name not in absent_columns]

    mismatched = 0
    evaluated = 0
    missing_persisted = 0
    issues: List[FactorAuditIssue] = []

    for ts_code, values in recomputed_map.items():
        stored = persisted_map.get(ts_code)
        if not stored:
            missing_persisted += 1
            LOGGER.debug(
                "审计未找到持久化记录 ts_code=%s trade_date=%s",
                ts_code,
                trade_date_str,
                extra=LOG_EXTRA,
            )
            continue

        evaluated += 1
        for factor in comparable:
            issue = _compare_factor_value(
                ts_code,
                factor,
                stored.get(factor),
                values.get(factor),
                tolerance,
            )
            if issue is None:
                continue
            mismatched += 1
            # Keep counting past the cap, but stop collecting details.
            if len(issues) < max_issues:
                issues.append(issue)

    missing_recomputed = sum(1 for code in persisted_map if code not in recomputed_map)

    summary = FactorAuditSummary(
        trade_date=trade_date,
        tolerance=tolerance,
        factor_names=factor_names,
        total_persisted=len(persisted_map),
        total_recomputed=len(recomputed_map),
        evaluated=evaluated,
        mismatched=mismatched,
        missing_persisted=missing_persisted,
        missing_recomputed=missing_recomputed,
        missing_columns=missing_columns,
        issues=issues,
    )
    LOGGER.info(
        "因子审计完成 trade_date=%s evaluated=%s mismatched=%s missing=%s/%s",
        trade_date_str,
        evaluated,
        mismatched,
        missing_persisted,
        missing_recomputed,
        extra=LOG_EXTRA,
    )
    if missing_columns:
        LOGGER.warning(
            "因子审计缺少字段 columns=%s trade_date=%s",
            missing_columns,
            trade_date_str,
            extra=LOG_EXTRA,
        )
    return summary


def _compare_factor_value(
    ts_code: str,
    factor: str,
    stored_value: object,
    recomputed_value: object,
    tolerance: float,
) -> Optional[FactorAuditIssue]:
    """Return an issue when the two values disagree, otherwise ``None``."""
    numeric_recomputed = _coerce_float(recomputed_value)
    numeric_stored = _coerce_float(stored_value)

    # Neither side holds a usable number: nothing to compare.
    if numeric_recomputed is None and numeric_stored is None:
        return None

    # Exactly one side is unusable. Report the raw values: the coerced ones
    # would show ``None`` and hide whatever made that side unusable.
    if numeric_recomputed is None or numeric_stored is None:
        return FactorAuditIssue(
            ts_code=ts_code,
            factor=factor,
            stored=stored_value,
            recomputed=recomputed_value,
            difference=None,
        )

    diff = abs(numeric_recomputed - numeric_stored)
    # The NaN check is purely defensive; _coerce_float only yields finite
    # numbers, so the subtraction cannot produce NaN in practice.
    diff_is_nan = math.isnan(diff)
    if not diff_is_nan and diff <= tolerance:
        return None
    return FactorAuditIssue(
        ts_code=ts_code,
        factor=factor,
        stored=numeric_stored,
        recomputed=numeric_recomputed,
        difference=None if diff_is_nan else diff,
    )
def _resolve_factor_specs(
    factors: Optional[Sequence[str | FactorSpec]],
) -> List[FactorSpec]:
    """Turn a mix of factor names and specs into a list of ``FactorSpec``.

    ``FactorSpec`` entries are copied (name and window); any other entry is
    converted to ``str`` and looked up, with unknown names logged and skipped.
    Entries sharing a name collapse to the last one seen. When nothing usable
    remains, or nothing was requested, the default factor set is returned.
    """
    by_name: Dict[str, FactorSpec] = {}
    for candidate in factors or ():
        if isinstance(candidate, FactorSpec):
            by_name[candidate.name] = FactorSpec(
                name=candidate.name,
                window=candidate.window,
            )
        else:
            found = lookup_factor_spec(str(candidate))
            if found is not None:
                by_name[found.name] = found
            else:
                LOGGER.debug("忽略未知因子,无法审计 factor=%s", candidate, extra=LOG_EXTRA)

    if by_name:
        return list(by_name.values())
    return list(DEFAULT_FACTORS)
def _load_persisted_factors(
trade_date: str,
factor_names: Sequence[str],
) -> tuple[Dict[str, Dict[str, Optional[float]]], List[str]]:
if not factor_names:
return {}, []
with db_session(read_only=True) as conn:
table_info = conn.execute("PRAGMA table_info(factors)").fetchall()
available_columns: set[str] = set()
for row in table_info:
if isinstance(row, Mapping):
available_columns.add(str(row.get("name")))
else:
available_columns.add(str(row[1]))
selected = [name for name in factor_names if name in available_columns]
missing_columns = [name for name in factor_names if name not in available_columns]
if not selected:
return {}, missing_columns
column_clause = ", ".join(["ts_code", *selected])
query = f"SELECT {column_clause} FROM factors WHERE trade_date = ?"
rows = conn.execute(query, (trade_date,)).fetchall()
persisted: Dict[str, Dict[str, Optional[float]]] = {}
for row in rows:
ts_code = row["ts_code"]
persisted[ts_code] = {name: row[name] for name in selected}
return persisted, missing_columns
def _coerce_float(value: object) -> Optional[float]:
if value is None:
return None
try:
numeric = float(value)
except (TypeError, ValueError):
return None
if math.isnan(numeric) or not math.isfinite(numeric):
return None
return numeric
# Names exported by ``from ... import *``; the underscore-prefixed helpers
# defined above are not listed.
__all__ = [
"FactorAuditIssue",
"FactorAuditSummary",
"audit_factors",
]