refactor factor evaluation and add sample size tracking

This commit is contained in:
Your Name 2025-10-10 21:47:00 +08:00
parent 43c70f3f7f
commit 8aa8efb651
6 changed files with 333 additions and 153 deletions

View File

@ -1,5 +1,5 @@
"""Factor performance evaluation utilities.""" """Factor performance evaluation utilities."""
from datetime import date, timedelta from datetime import date
from typing import Dict, List, Optional, Sequence, Tuple from typing import Dict, List, Optional, Sequence, Tuple
import numpy as np import numpy as np
@ -7,12 +7,10 @@ from scipy import stats
from app.features.factors import ( from app.features.factors import (
DEFAULT_FACTORS, DEFAULT_FACTORS,
FactorResult,
FactorSpec, FactorSpec,
compute_factor_range,
lookup_factor_spec, lookup_factor_spec,
) )
from app.utils.data_access import DataBroker from app.utils.db import db_session
from app.utils.logging import get_logger from app.utils.logging import get_logger
LOGGER = get_logger(__name__) LOGGER = get_logger(__name__)
@ -29,6 +27,7 @@ class FactorPerformance:
self.return_spreads: List[float] = [] self.return_spreads: List[float] = []
self.sharpe_ratio: Optional[float] = None self.sharpe_ratio: Optional[float] = None
self.turnover_rate: Optional[float] = None self.turnover_rate: Optional[float] = None
self.sample_size: int = 0
@property @property
def ic_mean(self) -> float: def ic_mean(self) -> float:
@ -58,7 +57,8 @@ class FactorPerformance:
"ic_ir": self.ic_ir, "ic_ir": self.ic_ir,
"rank_ic_mean": self.rank_ic_mean, "rank_ic_mean": self.rank_ic_mean,
"sharpe_ratio": self.sharpe_ratio or 0.0, "sharpe_ratio": self.sharpe_ratio or 0.0,
"turnover_rate": self.turnover_rate or 0.0 "turnover_rate": self.turnover_rate or 0.0,
"sample_size": float(self.sample_size),
} }
@ -80,132 +80,105 @@ def evaluate_factor(
因子表现评估结果 因子表现评估结果
""" """
performance = FactorPerformance(factor_name) performance = FactorPerformance(factor_name)
spec = lookup_factor_spec(factor_name)
factor_column = factor_name
# 导入进度状态模块 if spec is None:
from app.ui.progress_state import factor_progress LOGGER.warning("未找到因子定义,仍尝试从数据库读取 factor=%s", factor_name, extra=LOG_EXTRA)
# 开始因子计算进度在异步线程中不直接访问factor_progress normalized_universe = _normalize_universe(universe)
# factor_progress.start_calculation( start_str = start_date.strftime("%Y%m%d")
# total_securities=len(universe) if universe else 0, end_str = end_date.strftime("%Y%m%d")
# message=f"开始评估因子 {factor_name}"
# )
try: with db_session(read_only=True) as conn:
spec = lookup_factor_spec(factor_name) or FactorSpec(factor_name, 0) if not _has_factor_column(conn, factor_column):
LOGGER.warning("factors 表缺少列 %s,跳过评估", factor_column, extra=LOG_EXTRA)
return performance
trade_dates = _list_factor_dates(conn, start_str, end_str, normalized_universe)
factor_results = compute_factor_range( if not trade_dates:
start_date, LOGGER.info("指定区间内未找到可用因子数据 factor=%s", factor_name, extra=LOG_EXTRA)
end_date, return performance
factors=[spec],
ts_codes=universe,
skip_existing=True,
)
# 因子计算完成在异步线程中不直接访问factor_progress usable_trade_dates: List[str] = []
# factor_progress.complete_calculation(
# message=f"因子 {factor_name} 评估完成"
# )
except Exception as e: for trade_date_str in trade_dates:
# 因子计算失败在异步线程中不直接访问factor_progress with db_session(read_only=True) as conn:
# factor_progress.complete_calculation( factor_map = _fetch_factor_cross_section(conn, factor_column, trade_date_str, normalized_universe)
# message=f"因子 {factor_name} 评估失败: {str(e)}", if not factor_map:
# success=False
# )
raise
# 按日期分组
date_groups: Dict[date, List[FactorResult]] = {}
for result in factor_results:
if result.trade_date not in date_groups:
date_groups[result.trade_date] = []
date_groups[result.trade_date].append(result)
# 计算每日IC值和RankIC值
broker = DataBroker()
for curr_date, results in sorted(date_groups.items()):
next_date = curr_date + timedelta(days=1)
# 获取因子值和次日收益率
factor_values = []
next_returns = []
for result in results:
factor_val = result.values.get(factor_name)
if factor_val is None:
continue continue
next_trade = _next_trade_date(conn, trade_date_str)
if not next_trade:
continue
curr_close = _fetch_close_map(conn, trade_date_str, factor_map.keys())
next_close = _fetch_close_map(conn, next_trade, factor_map.keys())
# 获取次日收益率 factor_values: List[float] = []
next_close = broker.fetch_latest( returns: List[float] = []
result.ts_code, for ts_code, value in factor_map.items():
next_date.strftime("%Y%m%d"), curr = curr_close.get(ts_code)
["daily.close"] nxt = next_close.get(ts_code)
).get("daily.close") if curr is None or nxt is None or curr <= 0:
continue
factor_values.append(value)
returns.append((nxt - curr) / curr)
curr_close = broker.fetch_latest( if len(factor_values) < 20:
result.ts_code, continue
curr_date.strftime("%Y%m%d"),
["daily.close"]
).get("daily.close")
if next_close and curr_close and curr_close > 0: values_array = np.array(factor_values, dtype=float)
ret = (next_close - curr_close) / curr_close returns_array = np.array(returns, dtype=float)
factor_values.append(factor_val) if np.ptp(values_array) <= 1e-9 or np.ptp(returns_array) <= 1e-9:
next_returns.append(ret) LOGGER.debug(
"因子/收益序列波动不足,跳过 date=%s span_factor=%.6f span_return=%.6f",
trade_date_str,
float(np.ptp(values_array)),
float(np.ptp(returns_array)),
extra=LOG_EXTRA,
)
continue
if len(factor_values) >= 20: # 需要足够多的样本 try:
# 计算IC ic, _ = stats.pearsonr(values_array, returns_array)
ic, _ = stats.pearsonr(factor_values, next_returns) rank_ic, _ = stats.spearmanr(values_array, returns_array)
performance.ic_series.append(ic) except Exception as exc: # noqa: BLE001
LOGGER.debug("IC 计算失败 date=%s err=%s", trade_date_str, exc, extra=LOG_EXTRA)
continue
# 计算RankIC if not (np.isfinite(ic) and np.isfinite(rank_ic)):
rank_ic, _ = stats.spearmanr(factor_values, next_returns) LOGGER.debug(
performance.rank_ic_series.append(rank_ic) "相关系数结果无效 date=%s ic=%s rank_ic=%s",
trade_date_str,
ic,
rank_ic,
extra=LOG_EXTRA,
)
continue
# 计算多空组合收益 performance.ic_series.append(ic)
sorted_pairs = sorted(zip(factor_values, next_returns), performance.rank_ic_series.append(rank_ic)
key=lambda x: x[0]) usable_trade_dates.append(trade_date_str)
n = len(sorted_pairs) // 5 # 五分位
if n > 0: sorted_pairs = sorted(zip(values_array.tolist(), returns_array.tolist()), key=lambda item: item[0])
top_returns = [r for _, r in sorted_pairs[-n:]] quantile = len(sorted_pairs) // 5
bottom_returns = [r for _, r in sorted_pairs[:n]] if quantile > 0:
spread = np.mean(top_returns) - np.mean(bottom_returns) top_returns = [ret for _, ret in sorted_pairs[-quantile:]]
performance.return_spreads.append(spread) bottom_returns = [ret for _, ret in sorted_pairs[:quantile]]
spread = float(np.mean(top_returns) - np.mean(bottom_returns))
performance.return_spreads.append(spread)
# 计算Sharpe比率
if performance.return_spreads: if performance.return_spreads:
annual_factor = np.sqrt(252) # 交易日数 returns_mean = float(np.mean(performance.return_spreads))
returns_mean = np.mean(performance.return_spreads) returns_std = float(np.std(performance.return_spreads))
returns_std = np.std(performance.return_spreads)
if returns_std > 0: if returns_std > 0:
performance.sharpe_ratio = returns_mean / returns_std * annual_factor performance.sharpe_ratio = returns_mean / returns_std * np.sqrt(252.0)
# 估算换手率
if factor_results:
dates = sorted(date_groups.keys())
turnovers = []
for i in range(1, len(dates)):
prev_results = date_groups[dates[i-1]]
curr_results = date_groups[dates[i]]
# 计算组合变化
prev_top = {r.ts_code for r in prev_results
if r.values.get(factor_name, float('-inf')) > np.percentile(
[res.values.get(factor_name, float('-inf'))
for res in prev_results], 80)}
curr_top = {r.ts_code for r in curr_results
if r.values.get(factor_name, float('-inf')) > np.percentile(
[res.values.get(factor_name, float('-inf'))
for res in curr_results], 80)}
# 计算换手率
if prev_top and curr_top:
turnover = len(prev_top ^ curr_top) / len(prev_top | curr_top)
turnovers.append(turnover)
if turnovers:
performance.turnover_rate = np.mean(turnovers)
performance.sample_size = len(usable_trade_dates)
performance.turnover_rate = _estimate_turnover_rate(
factor_column,
usable_trade_dates,
normalized_universe,
)
return performance return performance
@ -233,3 +206,136 @@ def combine_factors(
) )
return FactorSpec(name, window) return FactorSpec(name, window)
def _normalize_universe(universe: Optional[Sequence[str]]) -> Optional[Tuple[str, ...]]:
    """Return the universe as a tuple of unique, upper-cased security codes.

    Args:
        universe: Security codes as supplied by the caller. ``None`` entries
            and blank strings are ignored. A bare string is treated as a
            single code rather than being iterated character by character.

    Returns:
        The cleaned codes in first-seen order, or ``None`` when nothing
        usable remains.
    """
    if not universe:
        return None
    if isinstance(universe, str):
        # A lone code passed by mistake would otherwise be split into letters.
        universe = (universe,)
    cleaned = ((code or "").strip().upper() for code in universe)
    # dict.fromkeys de-duplicates while keeping insertion order.
    unique = tuple(dict.fromkeys(value for value in cleaned if value))
    return unique or None


def _has_factor_column(conn, column: str) -> bool:
    """Report whether the ``factors`` table has a column named ``column``.

    Args:
        conn: Open database connection whose rows can be indexed by column
            name (``row["name"]``).
        column: Column name to look for; the comparison is exact.

    Returns:
        ``True`` when ``PRAGMA table_info(factors)`` lists the column.
    """
    rows = conn.execute("PRAGMA table_info(factors)").fetchall()
    return any(row["name"] == column for row in rows)


def _list_factor_dates(conn, start: str, end: str, universe: Optional[Tuple[str, ...]]) -> List[str]:
    """List the distinct trade dates present in ``factors`` within a range.

    Args:
        conn: Open database connection whose rows can be indexed by column
            name.
        start: Inclusive lower bound, compared against ``trade_date``.
        end: Inclusive upper bound, compared against ``trade_date``.
        universe: Optional security codes; when given, only rows whose
            ``ts_code`` is in this collection are considered.

    Returns:
        Trade dates in ascending order, with NULL/empty dates dropped.
    """
    query = "SELECT DISTINCT trade_date FROM factors WHERE trade_date BETWEEN ? AND ?"
    params: List[str] = [start, end]
    if universe:
        # NOTE(review): one placeholder per code; a very large universe may
        # exceed the database's bound-parameter limit -- confirm expected size.
        placeholders = ",".join("?" * len(universe))
        query += f" AND ts_code IN ({placeholders})"
        params.extend(universe)
    query += " ORDER BY trade_date"
    rows = conn.execute(query, params).fetchall()
    return [row["trade_date"] for row in rows if row and row["trade_date"]]


def _fetch_factor_cross_section(
    conn,
    column: str,
    trade_date: str,
    universe: Optional[Tuple[str, ...]],
) -> Dict[str, float]:
    """Load one day's values of a factor column, keyed by security code.

    Args:
        conn: Open database connection whose rows can be indexed by column
            name.
        column: Name of the factor column in the ``factors`` table. It is
            spliced into the SQL text, so it must be a plain identifier.
        trade_date: Trade date to select (bound as a query parameter).
        universe: Optional security codes restricting the selection.

    Returns:
        Mapping ``ts_code -> value`` containing only finite numeric values;
        NULLs, non-numeric entries and NaN/inf are dropped.

    Raises:
        ValueError: If ``column`` is not a plain identifier.
    """
    # Column names cannot be bound as parameters, so validate before
    # interpolating to keep arbitrary SQL out of the statement.
    if not isinstance(column, str) or not column.isidentifier():
        raise ValueError(f"invalid factor column name: {column!r}")

    query = f"SELECT ts_code, {column} AS value FROM factors WHERE trade_date = ? AND {column} IS NOT NULL"
    params: List[str] = [trade_date]
    if universe:
        # NOTE(review): one placeholder per code; a very large universe may
        # exceed the database's bound-parameter limit -- confirm expected size.
        placeholders = ",".join("?" * len(universe))
        query += f" AND ts_code IN ({placeholders})"
        params.extend(universe)

    result: Dict[str, float] = {}
    for row in conn.execute(query, params).fetchall():
        ts_code = row["ts_code"]
        if ts_code is None:
            continue
        try:
            numeric = float(row["value"])
        except (TypeError, ValueError):
            # None or a non-numeric cell: not usable as a factor value.
            continue
        if np.isfinite(numeric):
            result[ts_code] = numeric
    return result


def _next_trade_date(conn, trade_date: str) -> Optional[str]:
    """Return the earliest ``daily.trade_date`` strictly after ``trade_date``.

    Args:
        conn: Open database connection whose rows can be indexed by column
            name.
        trade_date: Reference trade date (bound as a query parameter).

    Returns:
        The next date found in the ``daily`` table, or ``None`` when there
        is no later date.
    """
    row = conn.execute(
        "SELECT MIN(trade_date) AS next_date FROM daily WHERE trade_date > ?",
        (trade_date,),
    ).fetchone()
    # MIN() over no rows still yields one row whose value is NULL.
    return row["next_date"] if row else None


def _fetch_close_map(conn, trade_date: str, codes: Sequence[str]) -> Dict[str, float]:
    """Load closing prices from ``daily`` for the given codes on one date.

    Args:
        conn: Open database connection whose rows can be indexed by column
            name.
        trade_date: Trade date to select (bound as a query parameter).
        codes: Security codes to look up. Any iterable is accepted,
            including ``dict.keys()`` views.

    Returns:
        Mapping ``ts_code -> close`` for rows with a non-NULL, numeric
        close. Codes without such a row are simply absent.
    """
    code_list = list(codes) if codes else []
    if not code_list:
        return {}

    # The code list can be the whole market; querying in chunks keeps each
    # statement well below SQLite's bound-parameter limit (999 on older
    # builds).
    chunk_size = 500
    result: Dict[str, float] = {}
    for offset in range(0, len(code_list), chunk_size):
        chunk = code_list[offset:offset + chunk_size]
        placeholders = ",".join("?" * len(chunk))
        rows = conn.execute(
            f"""
            SELECT ts_code, close
            FROM daily
            WHERE trade_date = ?
              AND ts_code IN ({placeholders})
              AND close IS NOT NULL
            """,
            [trade_date, *chunk],
        ).fetchall()
        for row in rows:
            ts_code = row["ts_code"]
            if ts_code is None:
                continue
            try:
                result[ts_code] = float(row["close"])
            except (TypeError, ValueError):
                # None or a non-numeric cell: skip this code.
                continue
    return result


def _estimate_turnover_rate(
    factor_name: str,
    trade_dates: Sequence[str],
    universe: Optional[Tuple[str, ...]],
) -> Optional[float]:
    """Estimate how much the top-quintile membership changes between dates.

    For every pair of consecutive dates the codes whose factor value is at
    or above that day's 80th percentile are collected, and the turnover is
    ``|prev XOR curr| / |prev OR curr|``. Pairs where either day has no
    factor data are skipped.

    Args:
        factor_name: Factor column to read from the ``factors`` table.
        trade_dates: Trade dates in the order they should be compared.
        universe: Optional security codes restricting each cross-section.

    Returns:
        The mean turnover over all usable date pairs, or ``None`` when no
        pair could be evaluated (including fewer than two dates).
    """
    if not trade_dates or len(trade_dates) < 2:
        return None

    turnovers: List[float] = []
    # Top set of the previous date; carried forward so each cross-section is
    # fetched and ranked once instead of twice.
    prev_top = None
    for trade_date in trade_dates:
        with db_session(read_only=True) as conn:
            cross_section = _fetch_factor_cross_section(conn, factor_name, trade_date, universe)

        curr_top = set()
        if cross_section:
            threshold = np.percentile(list(cross_section.values()), 80)
            # ">=" guarantees a non-empty set whenever there is any data.
            curr_top = {code for code, value in cross_section.items() if value >= threshold}

        if prev_top and curr_top:
            turnovers.append(len(prev_top ^ curr_top) / len(prev_top | curr_top))
        prev_top = curr_top

    if turnovers:
        return float(np.mean(turnovers))
    return None

View File

@ -706,7 +706,9 @@ def _compute_security_factors(
"daily_basic.pe", "daily_basic.pe",
"daily_basic.pb", "daily_basic.pb",
"daily_basic.ps", "daily_basic.ps",
"daily_basic.turnover_rate",
"daily_basic.volume_ratio", "daily_basic.volume_ratio",
"daily.close",
"daily.amount", "daily.amount",
"daily.vol", "daily.vol",
"daily_basic.dv_ratio", # 股息率用于扩展因子 "daily_basic.dv_ratio", # 股息率用于扩展因子

View File

@ -37,12 +37,14 @@ class FactorProgressState:
total_batches: 总批次数 total_batches: 总批次数
""" """
now = time.time() now = time.time()
normalized_total = max(total_securities, 0)
normalized_batches = max(total_batches, 1) if total_batches else 1
st.session_state.factor_progress.update({ st.session_state.factor_progress.update({
'current': 0, 'current': 0,
'total': max(total_securities, 0), 'total': normalized_total,
'percentage': 0.0, 'percentage': 0.0,
'current_batch': 0, 'current_batch': 0,
'total_batches': max(total_batches, 0), 'total_batches': normalized_batches,
'status': 'running', 'status': 'running',
'message': '开始因子计算...', 'message': '开始因子计算...',
'start_time': now, 'start_time': now,
@ -74,8 +76,9 @@ class FactorProgressState:
elapsed = 0.0 elapsed = 0.0
# 更新状态 # 更新状态
clamped_current = max(0, min(current_securities, total)) if total > 0 else max(0, current_securities)
progress.update({ progress.update({
'current': current_securities, 'current': clamped_current,
'current_batch': current_batch, 'current_batch': current_batch,
'percentage': percentage, 'percentage': percentage,
'message': message or f'处理批次 {current_batch}/{progress["total_batches"] or 1}', 'message': message or f'处理批次 {current_batch}/{progress["total_batches"] or 1}',
@ -95,9 +98,10 @@ class FactorProgressState:
elapsed = max(0.0, time.time() - start_time) elapsed = max(0.0, time.time() - start_time)
else: else:
elapsed = progress.get('elapsed_time', 0.0) or 0.0 elapsed = progress.get('elapsed_time', 0.0) or 0.0
total = progress.get('total', 0)
progress.update({ progress.update({
'current': progress.get('total', 0), 'current': total,
'percentage': 100.0 if progress.get('total', 0) else progress.get('percentage', 0.0), 'percentage': 100.0 if total else progress.get('percentage', 0.0),
'status': 'completed', 'status': 'completed',
'message': message, 'message': message,
'elapsed_time': elapsed, 'elapsed_time': elapsed,

View File

@ -230,6 +230,25 @@ def render_factor_calculation() -> None:
help="如果勾选,将跳过数据库中已存在的因子计算结果" help="如果勾选,将跳过数据库中已存在的因子计算结果"
) )
st.markdown("##### 数据维护")
maintenance_col1, maintenance_col2 = st.columns([1, 2])
with maintenance_col1:
clear_confirm = st.checkbox("确认清空因子表", key="factor_clear_confirm")
with maintenance_col2:
if st.button("清空因子表数据", disabled=not clear_confirm):
try:
with db_session() as conn:
conn.execute("DELETE FROM factors")
st.session_state.pop('factor_calculation_results', None)
st.session_state.pop('factor_calculation_error', None)
factor_progress.reset()
st.success("因子表数据已清空。")
except Exception as exc: # noqa: BLE001
LOGGER.exception("清空因子表失败", extra={**LOG_EXTRA, "error": str(exc)})
st.error(f"清空因子表失败:{exc}")
finally:
st.session_state['factor_clear_confirm'] = False
# 5. 开始计算按钮 # 5. 开始计算按钮
if st.button("开始计算因子", disabled=not selected_factors): if st.button("开始计算因子", disabled=not selected_factors):
# 重置状态 # 重置状态

View File

@ -203,6 +203,7 @@ def render_stock_evaluation() -> None:
"IC信息比率": performance.ic_ir, "IC信息比率": performance.ic_ir,
"夏普比率": performance.sharpe_ratio, "夏普比率": performance.sharpe_ratio,
"换手率": performance.turnover_rate, "换手率": performance.turnover_rate,
"有效样本数": performance.sample_size,
}) })
st.session_state.evaluation_results = results st.session_state.evaluation_results = results
@ -251,6 +252,8 @@ def render_stock_evaluation() -> None:
display_df["换手率"] = display_df["换手率"].map( display_df["换手率"] = display_df["换手率"].map(
lambda v: "N/A" if v is None else f"{v * 100:.1f}%" lambda v: "N/A" if v is None else f"{v * 100:.1f}%"
) )
if "有效样本数" in display_df:
display_df["有效样本数"] = display_df["有效样本数"].astype(int)
st.dataframe( st.dataframe(
display_df, display_df,
hide_index=True, hide_index=True,
@ -260,31 +263,64 @@ def render_stock_evaluation() -> None:
st.info("未产生任何因子评估结果。") st.info("未产生任何因子评估结果。")
# 绘制IC均值分布 # 绘制IC均值分布
ic_means = result_df["IC均值"].astype(float).tolist() if not result_df.empty else [] factor_names = result_df["因子"].tolist() if not result_df.empty else []
ic_series = result_df["IC均值"].astype(float) if not result_df.empty else pd.Series(dtype=float)
if "有效样本数" in result_df:
sample_series = result_df["有效样本数"].astype(int)
ic_series = ic_series.where(sample_series > 0)
ic_means = ic_series.tolist()
chart_df = pd.DataFrame({ chart_df = pd.DataFrame({
"因子": [r["因子"] for r in results], "因子": factor_names,
"IC均值": ic_means "IC均值": ic_means
}) })
st.bar_chart(chart_df.set_index("因子")) st.bar_chart(chart_df.set_index("因子"))
if not ic_means: if not factor_names:
st.info("暂无足够的 IC 数据,无法生成股票评分。") st.info("暂无足够的 IC 数据,无法生成股票评分。")
return return
ic_array = np.array(ic_means, dtype=float)
usable_indices = [idx for idx, value in enumerate(ic_array) if np.isfinite(value)]
if not usable_indices:
st.info("所有因子 IC 均值均不可用,请先补充因子数据再评估。")
return
usable_factors = [factor_names[idx] for idx in usable_indices]
usable_ic = ic_array[usable_indices]
dropped_factors = [factor_names[idx] for idx, value in enumerate(ic_array) if not np.isfinite(value)]
if dropped_factors:
st.caption(f"已忽略缺少有效 IC 数据的因子:{', '.join(dropped_factors)}")
with st.spinner("正在生成股票评分..."): with st.spinner("正在生成股票评分..."):
if all(mean == 0 for mean in ic_means): if np.all(np.abs(usable_ic) <= 1e-9):
factor_weights = [1.0 / len(ic_means)] * len(ic_means) factor_weights = np.full(usable_ic.shape, 1.0 / usable_ic.size, dtype=float)
LOGGER.info("所有因子IC均值均为零使用均匀权重", extra=LOG_EXTRA) LOGGER.info("因子IC均值均为零使用均匀权重", extra=LOG_EXTRA)
else: else:
abs_sum = sum(abs(m) for m in ic_means) or 1.0 abs_sum = float(np.sum(np.abs(usable_ic)))
factor_weights = [m / abs_sum for m in ic_means] if abs_sum <= 1e-9:
LOGGER.info("使用IC均值作为权重: %s", factor_weights, extra=LOG_EXTRA) factor_weights = np.full(usable_ic.shape, 1.0 / usable_ic.size, dtype=float)
LOGGER.info("有效因子IC均值绝对和过小使用均匀权重", extra=LOG_EXTRA)
else:
factor_weights = usable_ic / abs_sum
LOGGER.info("使用IC均值作为权重: %s", factor_weights.tolist(), extra=LOG_EXTRA)
weight_mask = np.abs(factor_weights) > 1e-6
filtered_factors = [name for name, flag in zip(usable_factors, weight_mask) if flag]
filtered_weights = [float(weight) for weight, flag in zip(factor_weights, weight_mask) if flag]
if not filtered_factors:
st.info("因子权重有效值均为零,无法生成股票评分。")
return
if len(filtered_factors) < len(usable_factors):
dropped_names = [name for name, flag in zip(usable_factors, weight_mask) if not flag]
LOGGER.info("已忽略权重为零的因子:%s", dropped_names, extra=LOG_EXTRA)
scores = _calculate_stock_scores( scores = _calculate_stock_scores(
universe, universe,
selected_factors, filtered_factors,
end_date, end_date,
factor_weights filtered_weights,
) )
if scores: if scores:
@ -319,6 +355,18 @@ def _calculate_stock_scores(
LOGGER = get_logger(__name__) LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "stock_evaluation"} LOG_EXTRA = {"stage": "stock_evaluation"}
if not factors:
LOGGER.warning("因子列表为空,无法计算股票评分", extra=LOG_EXTRA)
return []
if len(factors) != len(factor_weights):
LOGGER.error(
"因子数量与权重数量不一致 factors=%s weights=%s",
len(factors),
len(factor_weights),
extra=LOG_EXTRA,
)
return []
broker = DataBroker() broker = DataBroker()
trade_date_str = eval_date.strftime("%Y%m%d") trade_date_str = eval_date.strftime("%Y%m%d")

View File

@ -1413,20 +1413,21 @@ class DataBroker:
try: try:
# 获取股票基本信息 # 获取股票基本信息
info = self.fetch_latest( raw_info = self.fetch_latest(
ts_code=ts_code, ts_code=ts_code,
trade_date=trade_date, trade_date=trade_date,
fields=["stock_basic.name", "stock_basic.industry"] fields=["stock_basic.name", "stock_basic.industry"]
) )
if not raw_info:
if not info:
return None return None
# 添加股票代码 info: Dict[str, Any] = {"ts_code": ts_code}
result = {"ts_code": ts_code} for key, value in raw_info.items():
result.update(info) if key == "ts_code":
continue
return result alias = key.split(".", 1)[-1] if isinstance(key, str) and "." in key else key
info[alias] = value
return info
except Exception as exc: except Exception as exc:
LOGGER.debug( LOGGER.debug(
"获取股票信息失败 ts_code=%s err=%s", "获取股票信息失败 ts_code=%s err=%s",