# llm-quant/app/features/evaluation.py
#
# 342 lines
# 11 KiB
# Python

"""Factor performance evaluation utilities."""
from datetime import date
from typing import Dict, List, Optional, Sequence, Tuple
import numpy as np
from scipy import stats
from app.features.factors import (
DEFAULT_FACTORS,
FactorSpec,
lookup_factor_spec,
)
from app.utils.db import db_session
from app.utils.logging import get_logger
# Module-level logger; LOG_EXTRA is passed as ``extra`` on every log call in
# this module so records are tagged with the evaluation stage.
LOGGER = get_logger(__name__)
LOG_EXTRA = dict(stage="factor_evaluation")
class FactorPerformance:
    """Evaluation results for a single factor.

    The per-date series are appended to by the evaluator; the summary
    statistics exposed as properties are derived from them on demand.
    """

    def __init__(self, factor_name: str) -> None:
        self.factor_name = factor_name
        # Per-date Pearson IC values.
        self.ic_series: List[float] = []
        # Per-date Spearman rank IC values.
        self.rank_ic_series: List[float] = []
        # Per-date top-minus-bottom return spreads.
        self.return_spreads: List[float] = []
        self.sharpe_ratio: Optional[float] = None
        self.turnover_rate: Optional[float] = None
        self.sample_size: int = 0

    @property
    def ic_mean(self) -> float:
        """Mean of ``ic_series``; 0.0 when the series is empty."""
        if not self.ic_series:
            return 0.0
        return np.mean(self.ic_series)

    @property
    def ic_std(self) -> float:
        """Population standard deviation of ``ic_series``; 0.0 when empty."""
        if not self.ic_series:
            return 0.0
        return np.std(self.ic_series)

    @property
    def ic_ir(self) -> float:
        """IC information ratio ``ic_mean / ic_std``; 0.0 unless ``ic_std`` > 0."""
        deviation = self.ic_std
        if deviation > 0:
            return self.ic_mean / deviation
        return 0.0

    @property
    def rank_ic_mean(self) -> float:
        """Mean of ``rank_ic_series``; 0.0 when the series is empty."""
        if not self.rank_ic_series:
            return 0.0
        return np.mean(self.rank_ic_series)

    def to_dict(self) -> Dict[str, float]:
        """Flatten the summary statistics into a str -> float mapping.

        Unset (or falsy) Sharpe ratio and turnover are reported as 0.0.
        """
        return dict(
            ic_mean=self.ic_mean,
            ic_std=self.ic_std,
            ic_ir=self.ic_ir,
            rank_ic_mean=self.rank_ic_mean,
            sharpe_ratio=self.sharpe_ratio if self.sharpe_ratio else 0.0,
            turnover_rate=self.turnover_rate if self.turnover_rate else 0.0,
            sample_size=float(self.sample_size),
        )
def evaluate_factor(
    factor_name: str,
    start_date: date,
    end_date: date,
    universe: Optional[List[str]] = None,
) -> FactorPerformance:
    """Evaluate how well a single factor predicts next-trade-date returns.

    Each date in the window that has rows in ``factors`` is processed as
    follows. The factor cross-section is paired with each stock's
    close-to-close return from that date to the next ``trade_date`` present
    in ``daily``.

    A date is used only if it yields at least 20 pairs and neither side is
    near-constant. Each usable date contributes:

    * a Pearson IC,
    * a Spearman rank IC,
    * a top-minus-bottom quintile return spread.

    The spreads are then summarised as a Sharpe ratio annualised with
    sqrt(252), and a top-bucket turnover is estimated.

    Args:
        factor_name: Factor name; it is also the column read from ``factors``.
        start_date: First date of the window (inclusive).
        end_date: Last date of the window (inclusive).
        universe: Optional ts_code list restricting every cross-section.

    Returns:
        The populated FactorPerformance. It keeps its default (empty) values
        when the column is missing or the window has no factor data.
    """
    performance = FactorPerformance(factor_name)
    factor_column = factor_name
    if lookup_factor_spec(factor_name) is None:
        LOGGER.warning("未找到因子定义,仍尝试从数据库读取 factor=%s", factor_name, extra=LOG_EXTRA)
    normalized_universe = _normalize_universe(universe)
    start_str = start_date.strftime("%Y%m%d")
    end_str = end_date.strftime("%Y%m%d")

    usable_trade_dates: List[str] = []
    # One read-only session serves the whole scan instead of one per date.
    with db_session(read_only=True) as conn:
        # The helpers interpolate the column name into SQL, so it must be
        # checked against the real table schema before any query uses it.
        if not _has_factor_column(conn, factor_column):
            LOGGER.warning("factors 表缺少列 %s,跳过评估", factor_column, extra=LOG_EXTRA)
            return performance
        trade_dates = _list_factor_dates(conn, start_str, end_str, normalized_universe)
        if not trade_dates:
            LOGGER.info("指定区间内未找到可用因子数据 factor=%s", factor_name, extra=LOG_EXTRA)
            return performance
        for trade_date_str in trade_dates:
            samples = _collect_forward_returns(conn, factor_column, trade_date_str, normalized_universe)
            if samples is None:
                continue
            metrics = _cross_section_metrics(trade_date_str, samples[0], samples[1])
            if metrics is None:
                continue
            ic, rank_ic, spread = metrics
            performance.ic_series.append(ic)
            performance.rank_ic_series.append(rank_ic)
            usable_trade_dates.append(trade_date_str)
            if spread is not None:
                performance.return_spreads.append(spread)

    if performance.return_spreads:
        returns_mean = float(np.mean(performance.return_spreads))
        returns_std = float(np.std(performance.return_spreads))
        if returns_std > 0:
            # Spreads are one-period returns; annualise assuming 252 periods/year.
            performance.sharpe_ratio = returns_mean / returns_std * np.sqrt(252.0)
    performance.sample_size = len(usable_trade_dates)
    performance.turnover_rate = _estimate_turnover_rate(
        factor_column,
        usable_trade_dates,
        normalized_universe,
    )
    return performance


def _collect_forward_returns(
    conn,
    column: str,
    trade_date: str,
    universe: Optional[Tuple[str, ...]],
) -> Optional[Tuple[List[float], List[float]]]:
    """Pair factor values on ``trade_date`` with returns to the next trade date.

    Returns ``(factor_values, returns)`` aligned by stock. Returns None when
    the date has no factor data or no later trade date exists in ``daily``.
    """
    factor_map = _fetch_factor_cross_section(conn, column, trade_date, universe)
    if not factor_map:
        return None
    next_trade = _next_trade_date(conn, trade_date)
    if not next_trade:
        return None
    curr_close = _fetch_close_map(conn, trade_date, factor_map.keys())
    next_close = _fetch_close_map(conn, next_trade, factor_map.keys())
    factor_values: List[float] = []
    returns: List[float] = []
    for ts_code, value in factor_map.items():
        curr = curr_close.get(ts_code)
        nxt = next_close.get(ts_code)
        # A return needs both closes and a strictly positive base price.
        # NOTE(review): uses raw ``close``; if it is not split/dividend
        # adjusted, corporate actions distort these returns — confirm.
        if curr is None or nxt is None or curr <= 0:
            continue
        factor_values.append(value)
        returns.append((nxt - curr) / curr)
    return factor_values, returns


def _cross_section_metrics(
    trade_date: str,
    factor_values: Sequence[float],
    returns: Sequence[float],
    min_samples: int = 20,
) -> Optional[Tuple[float, float, Optional[float]]]:
    """Compute (IC, rank IC, quintile spread) for one date, or None if unusable.

    The spread is the mean return of the top fifth (by factor value) minus
    that of the bottom fifth. It is None when the sample is too small to
    form a fifth.
    """
    if len(factor_values) < min_samples:
        return None
    values_array = np.array(factor_values, dtype=float)
    returns_array = np.array(returns, dtype=float)
    # Near-constant inputs make the correlations unstable or undefined.
    if np.ptp(values_array) <= 1e-9 or np.ptp(returns_array) <= 1e-9:
        LOGGER.debug(
            "因子/收益序列波动不足,跳过 date=%s span_factor=%.6f span_return=%.6f",
            trade_date,
            float(np.ptp(values_array)),
            float(np.ptp(returns_array)),
            extra=LOG_EXTRA,
        )
        return None
    try:
        ic, _ = stats.pearsonr(values_array, returns_array)
        rank_ic, _ = stats.spearmanr(values_array, returns_array)
    except Exception as exc:  # noqa: BLE001 - one bad date is skipped, not fatal
        LOGGER.debug("IC 计算失败 date=%s err=%s", trade_date, exc, extra=LOG_EXTRA)
        return None
    if not (np.isfinite(ic) and np.isfinite(rank_ic)):
        LOGGER.debug(
            "相关系数结果无效 date=%s ic=%s rank_ic=%s",
            trade_date,
            ic,
            rank_ic,
            extra=LOG_EXTRA,
        )
        return None
    spread: Optional[float] = None
    quantile = len(values_array) // 5
    if quantile > 0:
        # Stable sort so stocks with tied factor values keep their order.
        ordered_returns = returns_array[np.argsort(values_array, kind="stable")]
        spread = float(np.mean(ordered_returns[-quantile:]) - np.mean(ordered_returns[:quantile]))
    return ic, rank_ic, spread
def combine_factors(
    factor_names: Sequence[str],
    weights: Optional[Sequence[float]] = None
) -> FactorSpec:
    """Build the spec of a composite factor from ``factor_names``.

    The composite is named ``"combined_" + "_".join(factor_names)``. Its
    window is the largest ``window`` among the named factors found in
    ``DEFAULT_FACTORS``; names not found there are ignored.

    Args:
        factor_names: Factors to combine. Must not be empty.
        weights: Optional per-factor weights. When given (non-empty), they
            must have one entry per factor. They are only validated here;
            see the note in the body.

    Returns:
        ``FactorSpec(name, window)`` for the composite factor.

    Raises:
        ValueError: ``factor_names`` is empty, ``weights`` has the wrong
            length, or none of the names is defined in ``DEFAULT_FACTORS``.
    """
    names = list(factor_names)
    if not names:
        raise ValueError("combine_factors requires at least one factor name")
    if weights and len(weights) != len(names):
        raise ValueError(
            f"weights has {len(weights)} entries but {len(names)} factors were given"
        )
    # NOTE(review): FactorSpec is built from (name, window) only, so weights
    # are validated but not carried into the returned spec. TODO: confirm
    # whether FactorSpec should hold them.
    requested = set(names)
    windows = [spec.window for spec in DEFAULT_FACTORS if spec.name in requested]
    if not windows:
        raise ValueError(f"none of the factors {names} is defined in DEFAULT_FACTORS")
    return FactorSpec("combined_" + "_".join(names), max(windows))
def _normalize_universe(universe: Optional[Sequence[str]]) -> Optional[Tuple[str, ...]]:
if not universe:
return None
unique: Dict[str, None] = {}
for code in universe:
value = (code or "").strip().upper()
if value:
unique.setdefault(value, None)
return tuple(unique.keys()) if unique else None
def _has_factor_column(conn, column: str) -> bool:
rows = conn.execute("PRAGMA table_info(factors)").fetchall()
available = {row["name"] for row in rows}
return column in available
def _list_factor_dates(conn, start: str, end: str, universe: Optional[Tuple[str, ...]]) -> List[str]:
params: List[str] = [start, end]
query = (
"SELECT DISTINCT trade_date FROM factors "
"WHERE trade_date BETWEEN ? AND ?"
)
if universe:
placeholders = ",".join("?" for _ in universe)
query += f" AND ts_code IN ({placeholders})"
params.extend(universe)
query += " ORDER BY trade_date"
rows = conn.execute(query, params).fetchall()
return [row["trade_date"] for row in rows if row and row["trade_date"]]
def _fetch_factor_cross_section(
conn,
column: str,
trade_date: str,
universe: Optional[Tuple[str, ...]],
) -> Dict[str, float]:
params: List[str] = [trade_date]
query = f"SELECT ts_code, {column} AS value FROM factors WHERE trade_date = ? AND {column} IS NOT NULL"
if universe:
placeholders = ",".join("?" for _ in universe)
query += f" AND ts_code IN ({placeholders})"
params.extend(universe)
rows = conn.execute(query, params).fetchall()
result: Dict[str, float] = {}
for row in rows:
ts_code = row["ts_code"]
value = row["value"]
if ts_code is None or value is None:
continue
try:
numeric = float(value)
except (TypeError, ValueError):
continue
if not np.isfinite(numeric):
continue
result[ts_code] = numeric
return result
def _next_trade_date(conn, trade_date: str) -> Optional[str]:
row = conn.execute(
"SELECT MIN(trade_date) AS next_date FROM daily WHERE trade_date > ?",
(trade_date,),
).fetchone()
next_date = row["next_date"] if row else None
return next_date
def _fetch_close_map(conn, trade_date: str, codes: Sequence[str]) -> Dict[str, float]:
if not codes:
return {}
placeholders = ",".join("?" for _ in codes)
params = [trade_date, *codes]
rows = conn.execute(
f"""
SELECT ts_code, close
FROM daily
WHERE trade_date = ?
AND ts_code IN ({placeholders})
AND close IS NOT NULL
""",
params,
).fetchall()
result: Dict[str, float] = {}
for row in rows:
ts_code = row["ts_code"]
value = row["close"]
if ts_code is None or value is None:
continue
try:
result[ts_code] = float(value)
except (TypeError, ValueError):
continue
return result
def _estimate_turnover_rate(
    factor_name: str,
    trade_dates: Sequence[str],
    universe: Optional[Tuple[str, ...]],
) -> Optional[float]:
    """Estimate the average day-to-day turnover of the factor's top bucket.

    On each date the top bucket is every stock whose factor value is at or
    above that date's 80th percentile. For each consecutive pair of dates,
    the turnover is ``|prev ^ curr| / |prev | curr|``: symmetric difference
    over union, i.e. one minus the Jaccard similarity. The mean over all
    pairs is returned.

    Args:
        factor_name: Column of ``factors`` to read. It is interpolated into
            SQL, so callers must have validated it.
        trade_dates: Dates to compare, in the order given.
        universe: Optional normalized ts_code tuple restricting each date.

    Returns:
        The mean turnover, or None when no consecutive pair of dates both
        have factor data (including when fewer than two dates are given).
    """
    if len(trade_dates) < 2:
        return None
    turnovers: List[float] = []
    # Top bucket of the previous date; None when that date had no data.
    previous_top = None
    # One session for the whole scan. Each date's cross-section is fetched
    # once and its top bucket is reused as the "previous" side of the next pair.
    with db_session(read_only=True) as conn:
        for trade_date in trade_dates:
            curr_map = _fetch_factor_cross_section(conn, factor_name, trade_date, universe)
            if curr_map:
                threshold = np.percentile(list(curr_map.values()), 80)
                curr_top = {code for code, value in curr_map.items() if value >= threshold}
            else:
                curr_top = None
            if previous_top is not None and curr_top is not None:
                union = previous_top | curr_top
                if union:
                    turnovers.append(len(previous_top ^ curr_top) / len(union))
            previous_top = curr_top
    if turnovers:
        return float(np.mean(turnovers))
    return None