From 8aa8efb65193a02e753673d50482cf4b29822261 Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 10 Oct 2025 21:47:00 +0800 Subject: [PATCH] refactor factor evaluation and add sample size tracking --- app/features/evaluation.py | 360 +++++++++++++++++++---------- app/features/factors.py | 2 + app/ui/progress_state.py | 14 +- app/ui/views/factor_calculation.py | 21 +- app/ui/views/stock_eval.py | 70 +++++- app/utils/data_access.py | 19 +- 6 files changed, 333 insertions(+), 153 deletions(-) diff --git a/app/features/evaluation.py b/app/features/evaluation.py index 3a83d11..35bee02 100644 --- a/app/features/evaluation.py +++ b/app/features/evaluation.py @@ -1,5 +1,5 @@ """Factor performance evaluation utilities.""" -from datetime import date, timedelta +from datetime import date from typing import Dict, List, Optional, Sequence, Tuple import numpy as np @@ -7,12 +7,10 @@ from scipy import stats from app.features.factors import ( DEFAULT_FACTORS, - FactorResult, FactorSpec, - compute_factor_range, lookup_factor_spec, ) -from app.utils.data_access import DataBroker +from app.utils.db import db_session from app.utils.logging import get_logger LOGGER = get_logger(__name__) @@ -29,6 +27,7 @@ class FactorPerformance: self.return_spreads: List[float] = [] self.sharpe_ratio: Optional[float] = None self.turnover_rate: Optional[float] = None + self.sample_size: int = 0 @property def ic_mean(self) -> float: @@ -58,7 +57,8 @@ class FactorPerformance: "ic_ir": self.ic_ir, "rank_ic_mean": self.rank_ic_mean, "sharpe_ratio": self.sharpe_ratio or 0.0, - "turnover_rate": self.turnover_rate or 0.0 + "turnover_rate": self.turnover_rate or 0.0, + "sample_size": float(self.sample_size), } @@ -80,132 +80,105 @@ def evaluate_factor( 因子表现评估结果 """ performance = FactorPerformance(factor_name) - - # 导入进度状态模块 - from app.ui.progress_state import factor_progress - - # 开始因子计算进度(在异步线程中不直接访问factor_progress) - # factor_progress.start_calculation( - # total_securities=len(universe) if universe else 0, - # message=f"开始评估因子 {factor_name}" - # ) - - try: - spec = lookup_factor_spec(factor_name) or FactorSpec(factor_name, 0) + spec = lookup_factor_spec(factor_name) + factor_column = factor_name - factor_results = compute_factor_range( - start_date, - end_date, - factors=[spec], - ts_codes=universe, - skip_existing=True, - ) - - # 因子计算完成(在异步线程中不直接访问factor_progress) - # factor_progress.complete_calculation( - # message=f"因子 {factor_name} 评估完成" - # ) - - except Exception as e: - # 因子计算失败(在异步线程中不直接访问factor_progress) - # factor_progress.complete_calculation( - # message=f"因子 {factor_name} 评估失败: {str(e)}", - # success=False - # ) - raise - - # 按日期分组 - date_groups: Dict[date, List[FactorResult]] = {} - for result in factor_results: - if result.trade_date not in date_groups: - date_groups[result.trade_date] = [] - date_groups[result.trade_date].append(result) - - # 计算每日IC值和RankIC值 - broker = DataBroker() - for curr_date, results in sorted(date_groups.items()): - next_date = curr_date + timedelta(days=1) - - # 获取因子值和次日收益率 - factor_values = [] - next_returns = [] - - for result in results: - factor_val = result.values.get(factor_name) - if factor_val is None: + if spec is None: + LOGGER.warning("未找到因子定义,仍尝试从数据库读取 factor=%s", factor_name, extra=LOG_EXTRA) + + normalized_universe = _normalize_universe(universe) + start_str = start_date.strftime("%Y%m%d") + end_str = end_date.strftime("%Y%m%d") + + with db_session(read_only=True) as conn: + if not _has_factor_column(conn, factor_column): + LOGGER.warning("factors 表缺少列 %s,跳过评估", factor_column, extra=LOG_EXTRA) + return performance + trade_dates = _list_factor_dates(conn, start_str, end_str, normalized_universe) + + if not trade_dates: + LOGGER.info("指定区间内未找到可用因子数据 factor=%s", factor_name, extra=LOG_EXTRA) + return performance + + usable_trade_dates: List[str] = [] + + for trade_date_str in trade_dates: + with db_session(read_only=True) as conn: + factor_map = _fetch_factor_cross_section(conn, factor_column, trade_date_str, normalized_universe) + if not factor_map: continue - - # 获取次日收益率 - next_close = broker.fetch_latest( - result.ts_code, - next_date.strftime("%Y%m%d"), - ["daily.close"] - ).get("daily.close") - - curr_close = broker.fetch_latest( - result.ts_code, - curr_date.strftime("%Y%m%d"), - ["daily.close"] - ).get("daily.close") - - if next_close and curr_close and curr_close > 0: - ret = (next_close - curr_close) / curr_close - factor_values.append(factor_val) - next_returns.append(ret) - - if len(factor_values) >= 20: # 需要足够多的样本 - # 计算IC - ic, _ = stats.pearsonr(factor_values, next_returns) - performance.ic_series.append(ic) - - # 计算RankIC - rank_ic, _ = stats.spearmanr(factor_values, next_returns) - performance.rank_ic_series.append(rank_ic) - - # 计算多空组合收益 - sorted_pairs = sorted(zip(factor_values, next_returns), - key=lambda x: x[0]) - n = len(sorted_pairs) // 5 # 五分位 - if n > 0: - top_returns = [r for _, r in sorted_pairs[-n:]] - bottom_returns = [r for _, r in sorted_pairs[:n]] - spread = np.mean(top_returns) - np.mean(bottom_returns) - performance.return_spreads.append(spread) - - # 计算Sharpe比率 + next_trade = _next_trade_date(conn, trade_date_str) + if not next_trade: + continue + curr_close = _fetch_close_map(conn, trade_date_str, factor_map.keys()) + next_close = _fetch_close_map(conn, next_trade, factor_map.keys()) + + factor_values: List[float] = [] + returns: List[float] = [] + for ts_code, value in factor_map.items(): + curr = curr_close.get(ts_code) + nxt = next_close.get(ts_code) + if curr is None or nxt is None or curr <= 0: + continue + factor_values.append(value) + returns.append((nxt - curr) / curr) + + if len(factor_values) < 20: + continue + + values_array = np.array(factor_values, dtype=float) + returns_array = np.array(returns, dtype=float) + if np.ptp(values_array) <= 1e-9 or np.ptp(returns_array) <= 1e-9: + LOGGER.debug( + "因子/收益序列波动不足,跳过 date=%s span_factor=%.6f span_return=%.6f", + trade_date_str, + float(np.ptp(values_array)), + float(np.ptp(returns_array)), + extra=LOG_EXTRA, + ) + continue + + try: + ic, _ = stats.pearsonr(values_array, returns_array) + rank_ic, _ = stats.spearmanr(values_array, returns_array) + except Exception as exc: # noqa: BLE001 + LOGGER.debug("IC 计算失败 date=%s err=%s", trade_date_str, exc, extra=LOG_EXTRA) + continue + + if not (np.isfinite(ic) and np.isfinite(rank_ic)): + LOGGER.debug( + "相关系数结果无效 date=%s ic=%s rank_ic=%s", + trade_date_str, + ic, + rank_ic, + extra=LOG_EXTRA, + ) + continue + + performance.ic_series.append(ic) + performance.rank_ic_series.append(rank_ic) + usable_trade_dates.append(trade_date_str) + + sorted_pairs = sorted(zip(values_array.tolist(), returns_array.tolist()), key=lambda item: item[0]) + quantile = len(sorted_pairs) // 5 + if quantile > 0: + top_returns = [ret for _, ret in sorted_pairs[-quantile:]] + bottom_returns = [ret for _, ret in sorted_pairs[:quantile]] + spread = float(np.mean(top_returns) - np.mean(bottom_returns)) + performance.return_spreads.append(spread) + if performance.return_spreads: - annual_factor = np.sqrt(252) # 交易日数 - returns_mean = np.mean(performance.return_spreads) - returns_std = np.std(performance.return_spreads) + returns_mean = float(np.mean(performance.return_spreads)) + returns_std = float(np.std(performance.return_spreads)) if returns_std > 0: - performance.sharpe_ratio = returns_mean / returns_std * annual_factor - - # 估算换手率 - if factor_results: - dates = sorted(date_groups.keys()) - turnovers = [] - for i in range(1, len(dates)): - prev_results = date_groups[dates[i-1]] - curr_results = date_groups[dates[i]] - - # 计算组合变化 - prev_top = {r.ts_code for r in prev_results - if r.values.get(factor_name, float('-inf')) > np.percentile( - [res.values.get(factor_name, float('-inf')) - for res in prev_results], 80)} - curr_top = {r.ts_code for r in curr_results - if r.values.get(factor_name, float('-inf')) > np.percentile( - [res.values.get(factor_name, float('-inf')) - for res in curr_results], 80)} - - # 计算换手率 - if prev_top and curr_top: - turnover = len(prev_top ^ curr_top) / len(prev_top | curr_top) - turnovers.append(turnover) - - if turnovers: - performance.turnover_rate = np.mean(turnovers) - + performance.sharpe_ratio = returns_mean / returns_std * np.sqrt(252.0) + + performance.sample_size = len(usable_trade_dates) + performance.turnover_rate = _estimate_turnover_rate( + factor_column, + usable_trade_dates, + normalized_universe, + ) return performance @@ -233,3 +206,136 @@ def combine_factors( ) return FactorSpec(name, window) + + +def _normalize_universe(universe: Optional[Sequence[str]]) -> Optional[Tuple[str, ...]]: + if not universe: + return None + unique: Dict[str, None] = {} + for code in universe: + value = (code or "").strip().upper() + if value: + unique.setdefault(value, None) + return tuple(unique.keys()) if unique else None + + +def _has_factor_column(conn, column: str) -> bool: + rows = conn.execute("PRAGMA table_info(factors)").fetchall() + available = {row["name"] for row in rows} + return column in available + + +def _list_factor_dates(conn, start: str, end: str, universe: Optional[Tuple[str, ...]]) -> List[str]: + params: List[str] = [start, end] + query = ( + "SELECT DISTINCT trade_date FROM factors " + "WHERE trade_date BETWEEN ? AND ?" + ) + if universe: + placeholders = ",".join("?" for _ in universe) + query += f" AND ts_code IN ({placeholders})" + params.extend(universe) + query += " ORDER BY trade_date" + rows = conn.execute(query, params).fetchall() + return [row["trade_date"] for row in rows if row and row["trade_date"]] + + +def _fetch_factor_cross_section( + conn, + column: str, + trade_date: str, + universe: Optional[Tuple[str, ...]], +) -> Dict[str, float]: + params: List[str] = [trade_date] + query = f"SELECT ts_code, {column} AS value FROM factors WHERE trade_date = ? AND {column} IS NOT NULL" + if universe: + placeholders = ",".join("?" for _ in universe) + query += f" AND ts_code IN ({placeholders})" + params.extend(universe) + rows = conn.execute(query, params).fetchall() + result: Dict[str, float] = {} + for row in rows: + ts_code = row["ts_code"] + value = row["value"] + if ts_code is None or value is None: + continue + try: + numeric = float(value) + except (TypeError, ValueError): + continue + if not np.isfinite(numeric): + continue + result[ts_code] = numeric + return result + + +def _next_trade_date(conn, trade_date: str) -> Optional[str]: + row = conn.execute( + "SELECT MIN(trade_date) AS next_date FROM daily WHERE trade_date > ?", + (trade_date,), + ).fetchone() + next_date = row["next_date"] if row else None + return next_date + + +def _fetch_close_map(conn, trade_date: str, codes: Sequence[str]) -> Dict[str, float]: + if not codes: + return {} + placeholders = ",".join("?" for _ in codes) + params = [trade_date, *codes] + rows = conn.execute( + f""" + SELECT ts_code, close + FROM daily + WHERE trade_date = ? + AND ts_code IN ({placeholders}) + AND close IS NOT NULL + """, + params, + ).fetchall() + result: Dict[str, float] = {} + for row in rows: + ts_code = row["ts_code"] + value = row["close"] + if ts_code is None or value is None: + continue + try: + result[ts_code] = float(value) + except (TypeError, ValueError): + continue + return result + + +def _estimate_turnover_rate( + factor_name: str, + trade_dates: Sequence[str], + universe: Optional[Tuple[str, ...]], +) -> Optional[float]: + if not trade_dates: + return None + turnovers: List[float] = [] + for idx in range(1, len(trade_dates)): + prev_date = trade_dates[idx - 1] + curr_date = trade_dates[idx] + with db_session(read_only=True) as conn: + prev_map = _fetch_factor_cross_section(conn, factor_name, prev_date, universe) + curr_map = _fetch_factor_cross_section(conn, factor_name, curr_date, universe) + + if not prev_map or not curr_map: + continue + + prev_threshold = np.percentile(list(prev_map.values()), 80) + curr_threshold = np.percentile(list(curr_map.values()), 80) + prev_top = {code for code, value in prev_map.items() if value >= prev_threshold} + curr_top = {code for code, value in curr_map.items() if value >= curr_threshold} + if not prev_top and not curr_top: + continue + union = prev_top | curr_top + if not union: + continue + turnover = len(prev_top ^ curr_top) / len(union) + turnovers.append(turnover) + + if turnovers: + return float(np.mean(turnovers)) + return None diff --git a/app/features/factors.py b/app/features/factors.py index 63146db..e7f9389 100644 --- a/app/features/factors.py +++ b/app/features/factors.py @@ -706,7 +706,9 @@ def _compute_security_factors( "daily_basic.pe", "daily_basic.pb", "daily_basic.ps", + "daily_basic.turnover_rate", "daily_basic.volume_ratio", + "daily.close", "daily.amount", "daily.vol", "daily_basic.dv_ratio", # 股息率用于扩展因子 diff --git a/app/ui/progress_state.py b/app/ui/progress_state.py index 1d1833f..230512a 100644 --- a/app/ui/progress_state.py +++ b/app/ui/progress_state.py @@ -37,12 +37,14 @@ class FactorProgressState: total_batches: 总批次数 """ now = time.time() + normalized_total = max(total_securities, 0) + normalized_batches = max(total_batches, 1) if total_batches else 1 st.session_state.factor_progress.update({ 'current': 0, - 'total': max(total_securities, 0), + 'total': normalized_total, 'percentage': 0.0, 'current_batch': 0, - 'total_batches': max(total_batches, 0), + 'total_batches': normalized_batches, 'status': 'running', 'message': '开始因子计算...', 'start_time': now, @@ -74,8 +76,9 @@ class FactorProgressState: elapsed = 0.0 # 更新状态 + clamped_current = max(0, min(current_securities, total)) if total > 0 else max(0, current_securities) progress.update({ - 'current': current_securities, + 'current': clamped_current, 'current_batch': current_batch, 'percentage': percentage, 'message': message or f'处理批次 {current_batch}/{progress["total_batches"] or 1}', @@ -95,9 +98,10 @@ class FactorProgressState: elapsed = max(0.0, time.time() - start_time) else: elapsed = progress.get('elapsed_time', 0.0) or 0.0 + total = progress.get('total', 0) progress.update({ - 'current': progress.get('total', 0), - 'percentage': 100.0 if progress.get('total', 0) else progress.get('percentage', 0.0), + 'current': total, + 'percentage': 100.0 if total else progress.get('percentage', 0.0), 'status': 'completed', 'message': message, 'elapsed_time': elapsed, diff --git a/app/ui/views/factor_calculation.py b/app/ui/views/factor_calculation.py index f1eb5ee..3807605 100644 --- a/app/ui/views/factor_calculation.py +++ b/app/ui/views/factor_calculation.py @@ -229,7 +229,26 @@ def render_factor_calculation() -> None: value=True, help="如果勾选,将跳过数据库中已存在的因子计算结果" ) - + + st.markdown("##### 数据维护") + maintenance_col1, maintenance_col2 = st.columns([1, 2]) + with maintenance_col1: + clear_confirm = st.checkbox("确认清空因子表", key="factor_clear_confirm") + with maintenance_col2: + if st.button("清空因子表数据", disabled=not clear_confirm): + try: + with db_session() as conn: + conn.execute("DELETE FROM factors") + st.session_state.pop('factor_calculation_results', None) + st.session_state.pop('factor_calculation_error', None) + factor_progress.reset() + st.success("因子表数据已清空。") + except Exception as exc: # noqa: BLE001 + LOGGER.exception("清空因子表失败", extra={**LOG_EXTRA, "error": str(exc)}) + st.error(f"清空因子表失败:{exc}") + finally: + st.session_state['factor_clear_confirm'] = False + # 5. 开始计算按钮 if st.button("开始计算因子", disabled=not selected_factors): # 重置状态 diff --git a/app/ui/views/stock_eval.py b/app/ui/views/stock_eval.py index 18e8558..49cb9d3 100644 --- a/app/ui/views/stock_eval.py +++ b/app/ui/views/stock_eval.py @@ -203,6 +203,7 @@ def render_stock_evaluation() -> None: "IC信息比率": performance.ic_ir, "夏普比率": performance.sharpe_ratio, "换手率": performance.turnover_rate, + "有效样本数": performance.sample_size, }) st.session_state.evaluation_results = results @@ -251,6 +252,8 @@ def render_stock_evaluation() -> None: display_df["换手率"] = display_df["换手率"].map( lambda v: "N/A" if v is None else f"{v * 100:.1f}%" ) + if "有效样本数" in display_df: + display_df["有效样本数"] = display_df["有效样本数"].astype(int) st.dataframe( display_df, hide_index=True, @@ -260,31 +263,64 @@ def render_stock_evaluation() -> None: st.info("未产生任何因子评估结果。") # 绘制IC均值分布 - ic_means = result_df["IC均值"].astype(float).tolist() if not result_df.empty else [] + factor_names = result_df["因子"].tolist() if not result_df.empty else [] + ic_series = result_df["IC均值"].astype(float) if not result_df.empty else pd.Series(dtype=float) + if "有效样本数" in result_df: + sample_series = result_df["有效样本数"].astype(int) + ic_series = ic_series.where(sample_series > 0) + ic_means = ic_series.tolist() chart_df = pd.DataFrame({ - "因子": [r["因子"] for r in results], + "因子": factor_names, "IC均值": ic_means }) st.bar_chart(chart_df.set_index("因子")) - if not ic_means: + if not factor_names: st.info("暂无足够的 IC 数据,无法生成股票评分。") return + ic_array = np.array(ic_means, dtype=float) + usable_indices = [idx for idx, value in enumerate(ic_array) if np.isfinite(value)] + if not usable_indices: + st.info("所有因子 IC 均值均不可用,请先补充因子数据再评估。") + return + + usable_factors = [factor_names[idx] for idx in usable_indices] + usable_ic = ic_array[usable_indices] + + dropped_factors = [factor_names[idx] for idx, value in enumerate(ic_array) if not np.isfinite(value)] + if dropped_factors: + st.caption(f"已忽略缺少有效 IC 数据的因子:{', '.join(dropped_factors)}") + with st.spinner("正在生成股票评分..."): - if all(mean == 0 for mean in ic_means): - factor_weights = [1.0 / len(ic_means)] * len(ic_means) - LOGGER.info("所有因子IC均值均为零,使用均匀权重", extra=LOG_EXTRA) + if np.all(np.abs(usable_ic) <= 1e-9): + factor_weights = np.full(usable_ic.shape, 1.0 / usable_ic.size, dtype=float) + LOGGER.info("有效因子IC均值均为零,使用均匀权重", extra=LOG_EXTRA) else: - abs_sum = sum(abs(m) for m in ic_means) or 1.0 - factor_weights = [m / abs_sum for m in ic_means] - LOGGER.info("使用IC均值作为权重: %s", factor_weights, extra=LOG_EXTRA) + abs_sum = float(np.sum(np.abs(usable_ic))) + if abs_sum <= 1e-9: + factor_weights = np.full(usable_ic.shape, 1.0 / usable_ic.size, dtype=float) + LOGGER.info("有效因子IC均值绝对和过小,使用均匀权重", extra=LOG_EXTRA) + else: + factor_weights = usable_ic / abs_sum + LOGGER.info("使用IC均值作为权重: %s", factor_weights.tolist(), extra=LOG_EXTRA) + + weight_mask = np.abs(factor_weights) > 1e-6 + filtered_factors = [name for name, flag in zip(usable_factors, weight_mask) if flag] + filtered_weights = [float(weight) for weight, flag in zip(factor_weights, weight_mask) if flag] + + if not filtered_factors: + st.info("因子权重有效值均为零,无法生成股票评分。") + return + if len(filtered_factors) < len(usable_factors): + dropped_names = [name for name, flag in zip(usable_factors, weight_mask) if not flag] + LOGGER.info("已忽略权重为零的因子:%s", dropped_names, extra=LOG_EXTRA) scores = _calculate_stock_scores( universe, - selected_factors, + filtered_factors, end_date, - factor_weights + filtered_weights, ) if scores: @@ -319,6 +355,18 @@ def _calculate_stock_scores( LOGGER = get_logger(__name__) LOG_EXTRA = {"stage": "stock_evaluation"} + if not factors: + LOGGER.warning("因子列表为空,无法计算股票评分", extra=LOG_EXTRA) + return [] + if len(factors) != len(factor_weights): + LOGGER.error( + "因子数量与权重数量不一致 factors=%s weights=%s", + len(factors), + len(factor_weights), + extra=LOG_EXTRA, + ) + return [] + broker = DataBroker() trade_date_str = eval_date.strftime("%Y%m%d") diff --git a/app/utils/data_access.py b/app/utils/data_access.py index 9525d70..c7102b5 100644 --- a/app/utils/data_access.py +++ b/app/utils/data_access.py @@ -1413,20 +1413,21 @@ class DataBroker: try: # 获取股票基本信息 - info = self.fetch_latest( + raw_info = self.fetch_latest( ts_code=ts_code, trade_date=trade_date, fields=["stock_basic.name", "stock_basic.industry"] ) - - if not info: + if not raw_info: return None - - # 添加股票代码 - result = {"ts_code": ts_code} - result.update(info) - - return result + + info: Dict[str, Any] = {"ts_code": ts_code} + for key, value in raw_info.items(): + if key == "ts_code": + continue + alias = key.split(".", 1)[-1] if isinstance(key, str) and "." in key else key + info[alias] = value + return info except Exception as exc: LOGGER.debug( "获取股票信息失败 ts_code=%s err=%s",