refactor factor evaluation and add sample size tracking

This commit is contained in:
Your Name 2025-10-10 21:47:00 +08:00
parent 43c70f3f7f
commit 8aa8efb651
6 changed files with 333 additions and 153 deletions

View File

@ -1,5 +1,5 @@
"""Factor performance evaluation utilities."""
from datetime import date, timedelta
from datetime import date
from typing import Dict, List, Optional, Sequence, Tuple
import numpy as np
@ -7,12 +7,10 @@ from scipy import stats
from app.features.factors import (
DEFAULT_FACTORS,
FactorResult,
FactorSpec,
compute_factor_range,
lookup_factor_spec,
)
from app.utils.data_access import DataBroker
from app.utils.db import db_session
from app.utils.logging import get_logger
LOGGER = get_logger(__name__)
@ -29,6 +27,7 @@ class FactorPerformance:
self.return_spreads: List[float] = []
self.sharpe_ratio: Optional[float] = None
self.turnover_rate: Optional[float] = None
self.sample_size: int = 0
@property
def ic_mean(self) -> float:
@ -58,7 +57,8 @@ class FactorPerformance:
"ic_ir": self.ic_ir,
"rank_ic_mean": self.rank_ic_mean,
"sharpe_ratio": self.sharpe_ratio or 0.0,
"turnover_rate": self.turnover_rate or 0.0
"turnover_rate": self.turnover_rate or 0.0,
"sample_size": float(self.sample_size),
}
@ -80,132 +80,105 @@ def evaluate_factor(
因子表现评估结果
"""
performance = FactorPerformance(factor_name)
spec = lookup_factor_spec(factor_name)
factor_column = factor_name
# 导入进度状态模块
from app.ui.progress_state import factor_progress
if spec is None:
LOGGER.warning("未找到因子定义,仍尝试从数据库读取 factor=%s", factor_name, extra=LOG_EXTRA)
# 开始因子计算进度在异步线程中不直接访问factor_progress
# factor_progress.start_calculation(
# total_securities=len(universe) if universe else 0,
# message=f"开始评估因子 {factor_name}"
# )
normalized_universe = _normalize_universe(universe)
start_str = start_date.strftime("%Y%m%d")
end_str = end_date.strftime("%Y%m%d")
try:
spec = lookup_factor_spec(factor_name) or FactorSpec(factor_name, 0)
with db_session(read_only=True) as conn:
if not _has_factor_column(conn, factor_column):
LOGGER.warning("factors 表缺少列 %s,跳过评估", factor_column, extra=LOG_EXTRA)
return performance
trade_dates = _list_factor_dates(conn, start_str, end_str, normalized_universe)
factor_results = compute_factor_range(
start_date,
end_date,
factors=[spec],
ts_codes=universe,
skip_existing=True,
)
if not trade_dates:
LOGGER.info("指定区间内未找到可用因子数据 factor=%s", factor_name, extra=LOG_EXTRA)
return performance
# 因子计算完成在异步线程中不直接访问factor_progress
# factor_progress.complete_calculation(
# message=f"因子 {factor_name} 评估完成"
# )
usable_trade_dates: List[str] = []
except Exception as e:
# 因子计算失败在异步线程中不直接访问factor_progress
# factor_progress.complete_calculation(
# message=f"因子 {factor_name} 评估失败: {str(e)}",
# success=False
# )
raise
# 按日期分组
date_groups: Dict[date, List[FactorResult]] = {}
for result in factor_results:
if result.trade_date not in date_groups:
date_groups[result.trade_date] = []
date_groups[result.trade_date].append(result)
# 计算每日IC值和RankIC值
broker = DataBroker()
for curr_date, results in sorted(date_groups.items()):
next_date = curr_date + timedelta(days=1)
# 获取因子值和次日收益率
factor_values = []
next_returns = []
for result in results:
factor_val = result.values.get(factor_name)
if factor_val is None:
for trade_date_str in trade_dates:
with db_session(read_only=True) as conn:
factor_map = _fetch_factor_cross_section(conn, factor_column, trade_date_str, normalized_universe)
if not factor_map:
continue
next_trade = _next_trade_date(conn, trade_date_str)
if not next_trade:
continue
curr_close = _fetch_close_map(conn, trade_date_str, factor_map.keys())
next_close = _fetch_close_map(conn, next_trade, factor_map.keys())
# 获取次日收益率
next_close = broker.fetch_latest(
result.ts_code,
next_date.strftime("%Y%m%d"),
["daily.close"]
).get("daily.close")
factor_values: List[float] = []
returns: List[float] = []
for ts_code, value in factor_map.items():
curr = curr_close.get(ts_code)
nxt = next_close.get(ts_code)
if curr is None or nxt is None or curr <= 0:
continue
factor_values.append(value)
returns.append((nxt - curr) / curr)
curr_close = broker.fetch_latest(
result.ts_code,
curr_date.strftime("%Y%m%d"),
["daily.close"]
).get("daily.close")
if len(factor_values) < 20:
continue
if next_close and curr_close and curr_close > 0:
ret = (next_close - curr_close) / curr_close
factor_values.append(factor_val)
next_returns.append(ret)
values_array = np.array(factor_values, dtype=float)
returns_array = np.array(returns, dtype=float)
if np.ptp(values_array) <= 1e-9 or np.ptp(returns_array) <= 1e-9:
LOGGER.debug(
"因子/收益序列波动不足,跳过 date=%s span_factor=%.6f span_return=%.6f",
trade_date_str,
float(np.ptp(values_array)),
float(np.ptp(returns_array)),
extra=LOG_EXTRA,
)
continue
if len(factor_values) >= 20: # 需要足够多的样本
# 计算IC
ic, _ = stats.pearsonr(factor_values, next_returns)
performance.ic_series.append(ic)
try:
ic, _ = stats.pearsonr(values_array, returns_array)
rank_ic, _ = stats.spearmanr(values_array, returns_array)
except Exception as exc: # noqa: BLE001
LOGGER.debug("IC 计算失败 date=%s err=%s", trade_date_str, exc, extra=LOG_EXTRA)
continue
# 计算RankIC
rank_ic, _ = stats.spearmanr(factor_values, next_returns)
performance.rank_ic_series.append(rank_ic)
if not (np.isfinite(ic) and np.isfinite(rank_ic)):
LOGGER.debug(
"相关系数结果无效 date=%s ic=%s rank_ic=%s",
trade_date_str,
ic,
rank_ic,
extra=LOG_EXTRA,
)
continue
# 计算多空组合收益
sorted_pairs = sorted(zip(factor_values, next_returns),
key=lambda x: x[0])
n = len(sorted_pairs) // 5 # 五分位
if n > 0:
top_returns = [r for _, r in sorted_pairs[-n:]]
bottom_returns = [r for _, r in sorted_pairs[:n]]
spread = np.mean(top_returns) - np.mean(bottom_returns)
performance.return_spreads.append(spread)
performance.ic_series.append(ic)
performance.rank_ic_series.append(rank_ic)
usable_trade_dates.append(trade_date_str)
sorted_pairs = sorted(zip(values_array.tolist(), returns_array.tolist()), key=lambda item: item[0])
quantile = len(sorted_pairs) // 5
if quantile > 0:
top_returns = [ret for _, ret in sorted_pairs[-quantile:]]
bottom_returns = [ret for _, ret in sorted_pairs[:quantile]]
spread = float(np.mean(top_returns) - np.mean(bottom_returns))
performance.return_spreads.append(spread)
# 计算Sharpe比率
if performance.return_spreads:
annual_factor = np.sqrt(252) # 交易日数
returns_mean = np.mean(performance.return_spreads)
returns_std = np.std(performance.return_spreads)
returns_mean = float(np.mean(performance.return_spreads))
returns_std = float(np.std(performance.return_spreads))
if returns_std > 0:
performance.sharpe_ratio = returns_mean / returns_std * annual_factor
# 估算换手率
if factor_results:
dates = sorted(date_groups.keys())
turnovers = []
for i in range(1, len(dates)):
prev_results = date_groups[dates[i-1]]
curr_results = date_groups[dates[i]]
# 计算组合变化
prev_top = {r.ts_code for r in prev_results
if r.values.get(factor_name, float('-inf')) > np.percentile(
[res.values.get(factor_name, float('-inf'))
for res in prev_results], 80)}
curr_top = {r.ts_code for r in curr_results
if r.values.get(factor_name, float('-inf')) > np.percentile(
[res.values.get(factor_name, float('-inf'))
for res in curr_results], 80)}
# 计算换手率
if prev_top and curr_top:
turnover = len(prev_top ^ curr_top) / len(prev_top | curr_top)
turnovers.append(turnover)
if turnovers:
performance.turnover_rate = np.mean(turnovers)
performance.sharpe_ratio = returns_mean / returns_std * np.sqrt(252.0)
performance.sample_size = len(usable_trade_dates)
performance.turnover_rate = _estimate_turnover_rate(
factor_column,
usable_trade_dates,
normalized_universe,
)
return performance
@ -233,3 +206,136 @@ def combine_factors(
)
return FactorSpec(name, window)
def _normalize_universe(universe: Optional[Sequence[str]]) -> Optional[Tuple[str, ...]]:
if not universe:
return None
unique: Dict[str, None] = {}
for code in universe:
value = (code or "").strip().upper()
if value:
unique.setdefault(value, None)
return tuple(unique.keys()) if unique else None
def _has_factor_column(conn, column: str) -> bool:
rows = conn.execute("PRAGMA table_info(factors)").fetchall()
available = {row["name"] for row in rows}
return column in available
def _list_factor_dates(conn, start: str, end: str, universe: Optional[Tuple[str, ...]]) -> List[str]:
params: List[str] = [start, end]
query = (
"SELECT DISTINCT trade_date FROM factors "
"WHERE trade_date BETWEEN ? AND ?"
)
if universe:
placeholders = ",".join("?" for _ in universe)
query += f" AND ts_code IN ({placeholders})"
params.extend(universe)
query += " ORDER BY trade_date"
rows = conn.execute(query, params).fetchall()
return [row["trade_date"] for row in rows if row and row["trade_date"]]
def _fetch_factor_cross_section(
conn,
column: str,
trade_date: str,
universe: Optional[Tuple[str, ...]],
) -> Dict[str, float]:
params: List[str] = [trade_date]
query = f"SELECT ts_code, {column} AS value FROM factors WHERE trade_date = ? AND {column} IS NOT NULL"
if universe:
placeholders = ",".join("?" for _ in universe)
query += f" AND ts_code IN ({placeholders})"
params.extend(universe)
rows = conn.execute(query, params).fetchall()
result: Dict[str, float] = {}
for row in rows:
ts_code = row["ts_code"]
value = row["value"]
if ts_code is None or value is None:
continue
try:
numeric = float(value)
except (TypeError, ValueError):
continue
if not np.isfinite(numeric):
continue
result[ts_code] = numeric
return result
def _next_trade_date(conn, trade_date: str) -> Optional[str]:
row = conn.execute(
"SELECT MIN(trade_date) AS next_date FROM daily WHERE trade_date > ?",
(trade_date,),
).fetchone()
next_date = row["next_date"] if row else None
return next_date
def _fetch_close_map(conn, trade_date: str, codes: Sequence[str]) -> Dict[str, float]:
if not codes:
return {}
placeholders = ",".join("?" for _ in codes)
params = [trade_date, *codes]
rows = conn.execute(
f"""
SELECT ts_code, close
FROM daily
WHERE trade_date = ?
AND ts_code IN ({placeholders})
AND close IS NOT NULL
""",
params,
).fetchall()
result: Dict[str, float] = {}
for row in rows:
ts_code = row["ts_code"]
value = row["close"]
if ts_code is None or value is None:
continue
try:
result[ts_code] = float(value)
except (TypeError, ValueError):
continue
return result
def _estimate_turnover_rate(
factor_name: str,
trade_dates: Sequence[str],
universe: Optional[Tuple[str, ...]],
) -> Optional[float]:
if not trade_dates:
return None
turnovers: List[float] = []
for idx in range(1, len(trade_dates)):
prev_date = trade_dates[idx - 1]
curr_date = trade_dates[idx]
with db_session(read_only=True) as conn:
prev_map = _fetch_factor_cross_section(conn, factor_name, prev_date, universe)
curr_map = _fetch_factor_cross_section(conn, factor_name, curr_date, universe)
if not prev_map or not curr_map:
continue
prev_threshold = np.percentile(list(prev_map.values()), 80)
curr_threshold = np.percentile(list(curr_map.values()), 80)
prev_top = {code for code, value in prev_map.items() if value >= prev_threshold}
curr_top = {code for code, value in curr_map.items() if value >= curr_threshold}
if not prev_top and not curr_top:
continue
union = prev_top | curr_top
if not union:
continue
turnover = len(prev_top ^ curr_top) / len(union)
turnovers.append(turnover)
if turnovers:
return float(np.mean(turnovers))
return None

View File

@ -706,7 +706,9 @@ def _compute_security_factors(
"daily_basic.pe",
"daily_basic.pb",
"daily_basic.ps",
"daily_basic.turnover_rate",
"daily_basic.volume_ratio",
"daily.close",
"daily.amount",
"daily.vol",
"daily_basic.dv_ratio", # 股息率用于扩展因子

View File

@ -37,12 +37,14 @@ class FactorProgressState:
total_batches: 总批次数
"""
now = time.time()
normalized_total = max(total_securities, 0)
normalized_batches = max(total_batches, 1) if total_batches else 1
st.session_state.factor_progress.update({
'current': 0,
'total': max(total_securities, 0),
'total': normalized_total,
'percentage': 0.0,
'current_batch': 0,
'total_batches': max(total_batches, 0),
'total_batches': normalized_batches,
'status': 'running',
'message': '开始因子计算...',
'start_time': now,
@ -74,8 +76,9 @@ class FactorProgressState:
elapsed = 0.0
# 更新状态
clamped_current = max(0, min(current_securities, total)) if total > 0 else max(0, current_securities)
progress.update({
'current': current_securities,
'current': clamped_current,
'current_batch': current_batch,
'percentage': percentage,
'message': message or f'处理批次 {current_batch}/{progress["total_batches"] or 1}',
@ -95,9 +98,10 @@ class FactorProgressState:
elapsed = max(0.0, time.time() - start_time)
else:
elapsed = progress.get('elapsed_time', 0.0) or 0.0
total = progress.get('total', 0)
progress.update({
'current': progress.get('total', 0),
'percentage': 100.0 if progress.get('total', 0) else progress.get('percentage', 0.0),
'current': total,
'percentage': 100.0 if total else progress.get('percentage', 0.0),
'status': 'completed',
'message': message,
'elapsed_time': elapsed,

View File

@ -230,6 +230,25 @@ def render_factor_calculation() -> None:
help="如果勾选,将跳过数据库中已存在的因子计算结果"
)
st.markdown("##### 数据维护")
maintenance_col1, maintenance_col2 = st.columns([1, 2])
with maintenance_col1:
clear_confirm = st.checkbox("确认清空因子表", key="factor_clear_confirm")
with maintenance_col2:
if st.button("清空因子表数据", disabled=not clear_confirm):
try:
with db_session() as conn:
conn.execute("DELETE FROM factors")
st.session_state.pop('factor_calculation_results', None)
st.session_state.pop('factor_calculation_error', None)
factor_progress.reset()
st.success("因子表数据已清空。")
except Exception as exc: # noqa: BLE001
LOGGER.exception("清空因子表失败", extra={**LOG_EXTRA, "error": str(exc)})
st.error(f"清空因子表失败:{exc}")
finally:
st.session_state['factor_clear_confirm'] = False
# 5. 开始计算按钮
if st.button("开始计算因子", disabled=not selected_factors):
# 重置状态

View File

@ -203,6 +203,7 @@ def render_stock_evaluation() -> None:
"IC信息比率": performance.ic_ir,
"夏普比率": performance.sharpe_ratio,
"换手率": performance.turnover_rate,
"有效样本数": performance.sample_size,
})
st.session_state.evaluation_results = results
@ -251,6 +252,8 @@ def render_stock_evaluation() -> None:
display_df["换手率"] = display_df["换手率"].map(
lambda v: "N/A" if v is None else f"{v * 100:.1f}%"
)
if "有效样本数" in display_df:
display_df["有效样本数"] = display_df["有效样本数"].astype(int)
st.dataframe(
display_df,
hide_index=True,
@ -260,31 +263,64 @@ def render_stock_evaluation() -> None:
st.info("未产生任何因子评估结果。")
# 绘制IC均值分布
ic_means = result_df["IC均值"].astype(float).tolist() if not result_df.empty else []
factor_names = result_df["因子"].tolist() if not result_df.empty else []
ic_series = result_df["IC均值"].astype(float) if not result_df.empty else pd.Series(dtype=float)
if "有效样本数" in result_df:
sample_series = result_df["有效样本数"].astype(int)
ic_series = ic_series.where(sample_series > 0)
ic_means = ic_series.tolist()
chart_df = pd.DataFrame({
"因子": [r["因子"] for r in results],
"因子": factor_names,
"IC均值": ic_means
})
st.bar_chart(chart_df.set_index("因子"))
if not ic_means:
if not factor_names:
st.info("暂无足够的 IC 数据,无法生成股票评分。")
return
ic_array = np.array(ic_means, dtype=float)
usable_indices = [idx for idx, value in enumerate(ic_array) if np.isfinite(value)]
if not usable_indices:
st.info("所有因子 IC 均值均不可用,请先补充因子数据再评估。")
return
usable_factors = [factor_names[idx] for idx in usable_indices]
usable_ic = ic_array[usable_indices]
dropped_factors = [factor_names[idx] for idx, value in enumerate(ic_array) if not np.isfinite(value)]
if dropped_factors:
st.caption(f"已忽略缺少有效 IC 数据的因子:{', '.join(dropped_factors)}")
with st.spinner("正在生成股票评分..."):
if all(mean == 0 for mean in ic_means):
factor_weights = [1.0 / len(ic_means)] * len(ic_means)
LOGGER.info("所有因子IC均值均为零使用均匀权重", extra=LOG_EXTRA)
if np.all(np.abs(usable_ic) <= 1e-9):
factor_weights = np.full(usable_ic.shape, 1.0 / usable_ic.size, dtype=float)
LOGGER.info("因子IC均值均为零使用均匀权重", extra=LOG_EXTRA)
else:
abs_sum = sum(abs(m) for m in ic_means) or 1.0
factor_weights = [m / abs_sum for m in ic_means]
LOGGER.info("使用IC均值作为权重: %s", factor_weights, extra=LOG_EXTRA)
abs_sum = float(np.sum(np.abs(usable_ic)))
if abs_sum <= 1e-9:
factor_weights = np.full(usable_ic.shape, 1.0 / usable_ic.size, dtype=float)
LOGGER.info("有效因子IC均值绝对和过小使用均匀权重", extra=LOG_EXTRA)
else:
factor_weights = usable_ic / abs_sum
LOGGER.info("使用IC均值作为权重: %s", factor_weights.tolist(), extra=LOG_EXTRA)
weight_mask = np.abs(factor_weights) > 1e-6
filtered_factors = [name for name, flag in zip(usable_factors, weight_mask) if flag]
filtered_weights = [float(weight) for weight, flag in zip(factor_weights, weight_mask) if flag]
if not filtered_factors:
st.info("因子权重有效值均为零,无法生成股票评分。")
return
if len(filtered_factors) < len(usable_factors):
dropped_names = [name for name, flag in zip(usable_factors, weight_mask) if not flag]
LOGGER.info("已忽略权重为零的因子:%s", dropped_names, extra=LOG_EXTRA)
scores = _calculate_stock_scores(
universe,
selected_factors,
filtered_factors,
end_date,
factor_weights
filtered_weights,
)
if scores:
@ -319,6 +355,18 @@ def _calculate_stock_scores(
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "stock_evaluation"}
if not factors:
LOGGER.warning("因子列表为空,无法计算股票评分", extra=LOG_EXTRA)
return []
if len(factors) != len(factor_weights):
LOGGER.error(
"因子数量与权重数量不一致 factors=%s weights=%s",
len(factors),
len(factor_weights),
extra=LOG_EXTRA,
)
return []
broker = DataBroker()
trade_date_str = eval_date.strftime("%Y%m%d")

View File

@ -1413,20 +1413,21 @@ class DataBroker:
try:
# 获取股票基本信息
info = self.fetch_latest(
raw_info = self.fetch_latest(
ts_code=ts_code,
trade_date=trade_date,
fields=["stock_basic.name", "stock_basic.industry"]
)
if not info:
if not raw_info:
return None
# 添加股票代码
result = {"ts_code": ts_code}
result.update(info)
return result
info: Dict[str, Any] = {"ts_code": ts_code}
for key, value in raw_info.items():
if key == "ts_code":
continue
alias = key.split(".", 1)[-1] if isinstance(key, str) and "." in key else key
info[alias] = value
return info
except Exception as exc:
LOGGER.debug(
"获取股票信息失败 ts_code=%s err=%s",