refactor factor evaluation and add sample size tracking
This commit is contained in:
parent
43c70f3f7f
commit
8aa8efb651
@ -1,5 +1,5 @@
|
||||
"""Factor performance evaluation utilities."""
|
||||
from datetime import date, timedelta
|
||||
from datetime import date
|
||||
from typing import Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
@ -7,12 +7,10 @@ from scipy import stats
|
||||
|
||||
from app.features.factors import (
|
||||
DEFAULT_FACTORS,
|
||||
FactorResult,
|
||||
FactorSpec,
|
||||
compute_factor_range,
|
||||
lookup_factor_spec,
|
||||
)
|
||||
from app.utils.data_access import DataBroker
|
||||
from app.utils.db import db_session
|
||||
from app.utils.logging import get_logger
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
@ -29,6 +27,7 @@ class FactorPerformance:
|
||||
self.return_spreads: List[float] = []
|
||||
self.sharpe_ratio: Optional[float] = None
|
||||
self.turnover_rate: Optional[float] = None
|
||||
self.sample_size: int = 0
|
||||
|
||||
@property
|
||||
def ic_mean(self) -> float:
|
||||
@ -58,7 +57,8 @@ class FactorPerformance:
|
||||
"ic_ir": self.ic_ir,
|
||||
"rank_ic_mean": self.rank_ic_mean,
|
||||
"sharpe_ratio": self.sharpe_ratio or 0.0,
|
||||
"turnover_rate": self.turnover_rate or 0.0
|
||||
"turnover_rate": self.turnover_rate or 0.0,
|
||||
"sample_size": float(self.sample_size),
|
||||
}
|
||||
|
||||
|
||||
@ -80,132 +80,105 @@ def evaluate_factor(
|
||||
因子表现评估结果
|
||||
"""
|
||||
performance = FactorPerformance(factor_name)
|
||||
spec = lookup_factor_spec(factor_name)
|
||||
factor_column = factor_name
|
||||
|
||||
# 导入进度状态模块
|
||||
from app.ui.progress_state import factor_progress
|
||||
if spec is None:
|
||||
LOGGER.warning("未找到因子定义,仍尝试从数据库读取 factor=%s", factor_name, extra=LOG_EXTRA)
|
||||
|
||||
# 开始因子计算进度(在异步线程中不直接访问factor_progress)
|
||||
# factor_progress.start_calculation(
|
||||
# total_securities=len(universe) if universe else 0,
|
||||
# message=f"开始评估因子 {factor_name}"
|
||||
# )
|
||||
normalized_universe = _normalize_universe(universe)
|
||||
start_str = start_date.strftime("%Y%m%d")
|
||||
end_str = end_date.strftime("%Y%m%d")
|
||||
|
||||
try:
|
||||
spec = lookup_factor_spec(factor_name) or FactorSpec(factor_name, 0)
|
||||
with db_session(read_only=True) as conn:
|
||||
if not _has_factor_column(conn, factor_column):
|
||||
LOGGER.warning("factors 表缺少列 %s,跳过评估", factor_column, extra=LOG_EXTRA)
|
||||
return performance
|
||||
trade_dates = _list_factor_dates(conn, start_str, end_str, normalized_universe)
|
||||
|
||||
factor_results = compute_factor_range(
|
||||
start_date,
|
||||
end_date,
|
||||
factors=[spec],
|
||||
ts_codes=universe,
|
||||
skip_existing=True,
|
||||
)
|
||||
if not trade_dates:
|
||||
LOGGER.info("指定区间内未找到可用因子数据 factor=%s", factor_name, extra=LOG_EXTRA)
|
||||
return performance
|
||||
|
||||
# 因子计算完成(在异步线程中不直接访问factor_progress)
|
||||
# factor_progress.complete_calculation(
|
||||
# message=f"因子 {factor_name} 评估完成"
|
||||
# )
|
||||
usable_trade_dates: List[str] = []
|
||||
|
||||
except Exception as e:
|
||||
# 因子计算失败(在异步线程中不直接访问factor_progress)
|
||||
# factor_progress.complete_calculation(
|
||||
# message=f"因子 {factor_name} 评估失败: {str(e)}",
|
||||
# success=False
|
||||
# )
|
||||
raise
|
||||
for trade_date_str in trade_dates:
|
||||
with db_session(read_only=True) as conn:
|
||||
factor_map = _fetch_factor_cross_section(conn, factor_column, trade_date_str, normalized_universe)
|
||||
if not factor_map:
|
||||
continue
|
||||
next_trade = _next_trade_date(conn, trade_date_str)
|
||||
if not next_trade:
|
||||
continue
|
||||
curr_close = _fetch_close_map(conn, trade_date_str, factor_map.keys())
|
||||
next_close = _fetch_close_map(conn, next_trade, factor_map.keys())
|
||||
|
||||
# 按日期分组
|
||||
date_groups: Dict[date, List[FactorResult]] = {}
|
||||
for result in factor_results:
|
||||
if result.trade_date not in date_groups:
|
||||
date_groups[result.trade_date] = []
|
||||
date_groups[result.trade_date].append(result)
|
||||
factor_values: List[float] = []
|
||||
returns: List[float] = []
|
||||
for ts_code, value in factor_map.items():
|
||||
curr = curr_close.get(ts_code)
|
||||
nxt = next_close.get(ts_code)
|
||||
if curr is None or nxt is None or curr <= 0:
|
||||
continue
|
||||
factor_values.append(value)
|
||||
returns.append((nxt - curr) / curr)
|
||||
|
||||
# 计算每日IC值和RankIC值
|
||||
broker = DataBroker()
|
||||
for curr_date, results in sorted(date_groups.items()):
|
||||
next_date = curr_date + timedelta(days=1)
|
||||
|
||||
# 获取因子值和次日收益率
|
||||
factor_values = []
|
||||
next_returns = []
|
||||
|
||||
for result in results:
|
||||
factor_val = result.values.get(factor_name)
|
||||
if factor_val is None:
|
||||
if len(factor_values) < 20:
|
||||
continue
|
||||
|
||||
# 获取次日收益率
|
||||
next_close = broker.fetch_latest(
|
||||
result.ts_code,
|
||||
next_date.strftime("%Y%m%d"),
|
||||
["daily.close"]
|
||||
).get("daily.close")
|
||||
values_array = np.array(factor_values, dtype=float)
|
||||
returns_array = np.array(returns, dtype=float)
|
||||
if np.ptp(values_array) <= 1e-9 or np.ptp(returns_array) <= 1e-9:
|
||||
LOGGER.debug(
|
||||
"因子/收益序列波动不足,跳过 date=%s span_factor=%.6f span_return=%.6f",
|
||||
trade_date_str,
|
||||
float(np.ptp(values_array)),
|
||||
float(np.ptp(returns_array)),
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
continue
|
||||
|
||||
curr_close = broker.fetch_latest(
|
||||
result.ts_code,
|
||||
curr_date.strftime("%Y%m%d"),
|
||||
["daily.close"]
|
||||
).get("daily.close")
|
||||
try:
|
||||
ic, _ = stats.pearsonr(values_array, returns_array)
|
||||
rank_ic, _ = stats.spearmanr(values_array, returns_array)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
LOGGER.debug("IC 计算失败 date=%s err=%s", trade_date_str, exc, extra=LOG_EXTRA)
|
||||
continue
|
||||
|
||||
if next_close and curr_close and curr_close > 0:
|
||||
ret = (next_close - curr_close) / curr_close
|
||||
factor_values.append(factor_val)
|
||||
next_returns.append(ret)
|
||||
if not (np.isfinite(ic) and np.isfinite(rank_ic)):
|
||||
LOGGER.debug(
|
||||
"相关系数结果无效 date=%s ic=%s rank_ic=%s",
|
||||
trade_date_str,
|
||||
ic,
|
||||
rank_ic,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
continue
|
||||
|
||||
if len(factor_values) >= 20: # 需要足够多的样本
|
||||
# 计算IC
|
||||
ic, _ = stats.pearsonr(factor_values, next_returns)
|
||||
performance.ic_series.append(ic)
|
||||
|
||||
# 计算RankIC
|
||||
rank_ic, _ = stats.spearmanr(factor_values, next_returns)
|
||||
performance.rank_ic_series.append(rank_ic)
|
||||
usable_trade_dates.append(trade_date_str)
|
||||
|
||||
# 计算多空组合收益
|
||||
sorted_pairs = sorted(zip(factor_values, next_returns),
|
||||
key=lambda x: x[0])
|
||||
n = len(sorted_pairs) // 5 # 五分位
|
||||
if n > 0:
|
||||
top_returns = [r for _, r in sorted_pairs[-n:]]
|
||||
bottom_returns = [r for _, r in sorted_pairs[:n]]
|
||||
spread = np.mean(top_returns) - np.mean(bottom_returns)
|
||||
sorted_pairs = sorted(zip(values_array.tolist(), returns_array.tolist()), key=lambda item: item[0])
|
||||
quantile = len(sorted_pairs) // 5
|
||||
if quantile > 0:
|
||||
top_returns = [ret for _, ret in sorted_pairs[-quantile:]]
|
||||
bottom_returns = [ret for _, ret in sorted_pairs[:quantile]]
|
||||
spread = float(np.mean(top_returns) - np.mean(bottom_returns))
|
||||
performance.return_spreads.append(spread)
|
||||
|
||||
# 计算Sharpe比率
|
||||
if performance.return_spreads:
|
||||
annual_factor = np.sqrt(252) # 交易日数
|
||||
returns_mean = np.mean(performance.return_spreads)
|
||||
returns_std = np.std(performance.return_spreads)
|
||||
returns_mean = float(np.mean(performance.return_spreads))
|
||||
returns_std = float(np.std(performance.return_spreads))
|
||||
if returns_std > 0:
|
||||
performance.sharpe_ratio = returns_mean / returns_std * annual_factor
|
||||
|
||||
# 估算换手率
|
||||
if factor_results:
|
||||
dates = sorted(date_groups.keys())
|
||||
turnovers = []
|
||||
for i in range(1, len(dates)):
|
||||
prev_results = date_groups[dates[i-1]]
|
||||
curr_results = date_groups[dates[i]]
|
||||
|
||||
# 计算组合变化
|
||||
prev_top = {r.ts_code for r in prev_results
|
||||
if r.values.get(factor_name, float('-inf')) > np.percentile(
|
||||
[res.values.get(factor_name, float('-inf'))
|
||||
for res in prev_results], 80)}
|
||||
curr_top = {r.ts_code for r in curr_results
|
||||
if r.values.get(factor_name, float('-inf')) > np.percentile(
|
||||
[res.values.get(factor_name, float('-inf'))
|
||||
for res in curr_results], 80)}
|
||||
|
||||
# 计算换手率
|
||||
if prev_top and curr_top:
|
||||
turnover = len(prev_top ^ curr_top) / len(prev_top | curr_top)
|
||||
turnovers.append(turnover)
|
||||
|
||||
if turnovers:
|
||||
performance.turnover_rate = np.mean(turnovers)
|
||||
performance.sharpe_ratio = returns_mean / returns_std * np.sqrt(252.0)
|
||||
|
||||
performance.sample_size = len(usable_trade_dates)
|
||||
performance.turnover_rate = _estimate_turnover_rate(
|
||||
factor_column,
|
||||
usable_trade_dates,
|
||||
normalized_universe,
|
||||
)
|
||||
return performance
|
||||
|
||||
|
||||
@ -233,3 +206,136 @@ def combine_factors(
|
||||
)
|
||||
|
||||
return FactorSpec(name, window)
|
||||
|
||||
|
||||
def _normalize_universe(universe: Optional[Sequence[str]]) -> Optional[Tuple[str, ...]]:
    """Return the cleaned, de-duplicated security codes in ``universe``.

    Each code is stripped of surrounding whitespace and upper-cased. Empty or
    ``None`` entries are dropped, and the first occurrence of each code keeps
    its position.

    Args:
        universe: Security codes to normalise. ``None`` or an empty sequence
            means "no restriction". A bare string is treated as one code.

    Returns:
        A tuple of unique normalised codes, or ``None`` when nothing usable
        remains.
    """
    if not universe:
        return None
    if isinstance(universe, str):
        # A string is itself a sequence of characters; iterating it would
        # yield one bogus single-letter "code" per character.
        universe = (universe,)

    cleaned = ((code or "").strip().upper() for code in universe)
    # dict.fromkeys de-duplicates while preserving first-seen order.
    unique = tuple(dict.fromkeys(value for value in cleaned if value))
    return unique or None
|
||||
|
||||
|
||||
def _has_factor_column(conn, column: str) -> bool:
    """Report whether the ``factors`` table has a column named ``column``.

    Column names are compared case-insensitively, matching how SQLite itself
    resolves identifiers, so a name that a later query can resolve is not
    rejected because of letter case.

    Args:
        conn: Database connection whose rows support access by column name
            (``row["name"]``).
        column: Column name to look for.

    Returns:
        ``True`` if the column exists. ``False`` otherwise, including when
        ``column`` is empty or the table has no columns to report.
    """
    if not column:
        return False
    wanted = column.lower()
    rows = conn.execute("PRAGMA table_info(factors)").fetchall()
    return any(str(row["name"]).lower() == wanted for row in rows)
|
||||
|
||||
|
||||
def _list_factor_dates(conn, start: str, end: str, universe: Optional[Tuple[str, ...]]) -> List[str]:
    """List the distinct trade dates in ``factors`` within ``[start, end]``.

    Args:
        conn: Database connection whose rows support access by column name.
        start: Inclusive lower bound, compared against ``trade_date``.
        end: Inclusive upper bound, compared against ``trade_date``.
        universe: Optional security codes; when given, only rows whose
            ``ts_code`` is in the universe are considered.

    Returns:
        The distinct non-empty trade dates in ascending order.
    """
    base_query = (
        "SELECT DISTINCT trade_date FROM factors "
        "WHERE trade_date BETWEEN ? AND ?"
    )
    # One bound parameter is used per code; SQLite caps the number of host
    # parameters per statement (999 on older builds), so large universes are
    # queried in batches instead of one oversized IN (...) list.
    batch_size = 500
    codes = list(universe) if universe else []

    if len(codes) <= batch_size:
        query = base_query
        params: List[str] = [start, end]
        if codes:
            placeholders = ",".join("?" for _ in codes)
            query += f" AND ts_code IN ({placeholders})"
            params.extend(codes)
        query += " ORDER BY trade_date"
        rows = conn.execute(query, params).fetchall()
        return [row["trade_date"] for row in rows if row and row["trade_date"]]

    found = set()
    for offset in range(0, len(codes), batch_size):
        batch = codes[offset:offset + batch_size]
        placeholders = ",".join("?" for _ in batch)
        rows = conn.execute(
            f"{base_query} AND ts_code IN ({placeholders})",
            [start, end, *batch],
        ).fetchall()
        found.update(row["trade_date"] for row in rows if row and row["trade_date"])
    # The same date can come back from several batches; merge, then order.
    return sorted(found)
|
||||
|
||||
|
||||
def _fetch_factor_cross_section(
    conn,
    column: str,
    trade_date: str,
    universe: Optional[Tuple[str, ...]],
) -> Dict[str, float]:
    """Load one factor's values for every security on a single trade date.

    Args:
        conn: Database connection whose rows support access by column name.
        column: Name of the factor column in the ``factors`` table. It is
            placed directly in the SQL text, so it must be a plain identifier.
        trade_date: Trade date to select.
        universe: Optional security codes; when given, only those codes are
            returned.

    Returns:
        Mapping of ``ts_code`` to the factor value as a float. Rows with a
        missing code, a NULL or non-numeric value, or a non-finite value are
        left out.

    Raises:
        ValueError: If ``column`` is not a plain identifier.
    """
    # Identifiers cannot be bound as SQL parameters, so the column name is
    # validated before being interpolated into the statement.
    if not isinstance(column, str) or not column.isidentifier():
        raise ValueError(f"invalid factor column name: {column!r}")

    base_query = (
        f"SELECT ts_code, {column} AS value FROM factors "
        f"WHERE trade_date = ? AND {column} IS NOT NULL"
    )
    # One bound parameter is used per code; SQLite caps the number of host
    # parameters per statement (999 on older builds), so batch large universes.
    batch_size = 500
    codes = list(universe) if universe else []
    if codes:
        batches = [codes[i:i + batch_size] for i in range(0, len(codes), batch_size)]
    else:
        batches = [[]]  # a single unrestricted query

    result: Dict[str, float] = {}
    for batch in batches:
        query = base_query
        params: List[str] = [trade_date]
        if batch:
            placeholders = ",".join("?" for _ in batch)
            query += f" AND ts_code IN ({placeholders})"
            params.extend(batch)
        for row in conn.execute(query, params).fetchall():
            ts_code = row["ts_code"]
            value = row["value"]
            if ts_code is None or value is None:
                continue
            try:
                numeric = float(value)
            except (TypeError, ValueError):
                continue
            # Infinite or NaN values would poison the correlation statistics.
            if not np.isfinite(numeric):
                continue
            result[ts_code] = numeric
    return result
|
||||
|
||||
|
||||
def _next_trade_date(conn, trade_date: str) -> Optional[str]:
    """Return the earliest ``daily.trade_date`` strictly after ``trade_date``.

    Args:
        conn: Database connection whose rows support access by column name.
        trade_date: Reference trade date.

    Returns:
        The smallest later trade date found in ``daily``, or ``None`` when
        there is none.
    """
    cursor = conn.execute(
        "SELECT MIN(trade_date) AS next_date FROM daily WHERE trade_date > ?",
        (trade_date,),
    )
    record = cursor.fetchone()
    if not record:
        return None
    # MIN() over zero matching rows yields NULL, which surfaces here as None.
    return record["next_date"]
|
||||
|
||||
|
||||
def _fetch_close_map(conn, trade_date: str, codes: Sequence[str]) -> Dict[str, float]:
    """Load closing prices from ``daily`` for the given codes on one date.

    Args:
        conn: Database connection whose rows support access by column name.
        trade_date: Trade date to select.
        codes: Security codes to look up. Any iterable is accepted (list,
            tuple, ``dict.keys()``, generator); it is materialised once.

    Returns:
        Mapping of ``ts_code`` to the close as a float. Rows with a missing
        code, a NULL or non-numeric close, or a non-finite close are left out.
    """
    # Materialise once: the codes are needed both for their count and as
    # bound parameters, which a one-shot iterable could not supply twice.
    code_list = list(codes) if codes is not None else []
    if not code_list:
        return {}

    # One bound parameter is used per code; SQLite caps the number of host
    # parameters per statement (999 on older builds), so query in batches.
    batch_size = 500
    result: Dict[str, float] = {}
    for offset in range(0, len(code_list), batch_size):
        batch = code_list[offset:offset + batch_size]
        placeholders = ",".join("?" for _ in batch)
        rows = conn.execute(
            f"""
            SELECT ts_code, close
            FROM daily
            WHERE trade_date = ?
              AND ts_code IN ({placeholders})
              AND close IS NOT NULL
            """,
            [trade_date, *batch],
        ).fetchall()
        for row in rows:
            ts_code = row["ts_code"]
            value = row["close"]
            if ts_code is None or value is None:
                continue
            try:
                close = float(value)
            except (TypeError, ValueError):
                continue
            # A NaN or infinite close would turn the derived return into NaN,
            # so drop it here as the factor fetch does for factor values.
            if not np.isfinite(close):
                continue
            result[ts_code] = close
    return result
|
||||
|
||||
|
||||
def _estimate_turnover_rate(
    factor_name: str,
    trade_dates: Sequence[str],
    universe: Optional[Tuple[str, ...]],
) -> Optional[float]:
    """Estimate the average turnover of the factor's top group between dates.

    For each pair of consecutive dates, the securities at or above the 80th
    percentile of the factor form the "top" set. Turnover for the pair is the
    size of the symmetric difference of the two top sets divided by the size
    of their union. Pairs where either date has no factor data are skipped.

    Args:
        factor_name: Factor column to read from the ``factors`` table.
        trade_dates: Trade dates in the order they should be compared.
        universe: Optional security codes restricting the cross-sections.

    Returns:
        The mean turnover over all usable consecutive pairs, or ``None`` when
        no pair could be evaluated.
    """
    if not trade_dates or len(trade_dates) < 2:
        return None

    turnovers: List[float] = []
    previous_top = set()
    # One read-only session for the whole scan, and each date's cross-section
    # is loaded once, then reused as the "previous" side of the next pair.
    with db_session(read_only=True) as conn:
        for position, trade_date in enumerate(trade_dates):
            cross_section = _fetch_factor_cross_section(conn, factor_name, trade_date, universe)
            current_top = _top_quintile_codes(cross_section)
            if position > 0 and previous_top and current_top:
                changed = len(previous_top ^ current_top)
                turnovers.append(changed / len(previous_top | current_top))
            previous_top = current_top

    if turnovers:
        return float(np.mean(turnovers))
    return None


def _top_quintile_codes(cross_section: Dict[str, float]) -> set:
    """Return the codes whose value is at or above the 80th percentile."""
    if not cross_section:
        return set()
    threshold = np.percentile(list(cross_section.values()), 80)
    # ">=" keeps the maximum, so a non-empty input gives a non-empty set.
    return {code for code, value in cross_section.items() if value >= threshold}
|
||||
|
||||
@ -706,7 +706,9 @@ def _compute_security_factors(
|
||||
"daily_basic.pe",
|
||||
"daily_basic.pb",
|
||||
"daily_basic.ps",
|
||||
"daily_basic.turnover_rate",
|
||||
"daily_basic.volume_ratio",
|
||||
"daily.close",
|
||||
"daily.amount",
|
||||
"daily.vol",
|
||||
"daily_basic.dv_ratio", # 股息率用于扩展因子
|
||||
|
||||
@ -37,12 +37,14 @@ class FactorProgressState:
|
||||
total_batches: 总批次数
|
||||
"""
|
||||
now = time.time()
|
||||
normalized_total = max(total_securities, 0)
|
||||
normalized_batches = max(total_batches, 1) if total_batches else 1
|
||||
st.session_state.factor_progress.update({
|
||||
'current': 0,
|
||||
'total': max(total_securities, 0),
|
||||
'total': normalized_total,
|
||||
'percentage': 0.0,
|
||||
'current_batch': 0,
|
||||
'total_batches': max(total_batches, 0),
|
||||
'total_batches': normalized_batches,
|
||||
'status': 'running',
|
||||
'message': '开始因子计算...',
|
||||
'start_time': now,
|
||||
@ -74,8 +76,9 @@ class FactorProgressState:
|
||||
elapsed = 0.0
|
||||
|
||||
# 更新状态
|
||||
clamped_current = max(0, min(current_securities, total)) if total > 0 else max(0, current_securities)
|
||||
progress.update({
|
||||
'current': current_securities,
|
||||
'current': clamped_current,
|
||||
'current_batch': current_batch,
|
||||
'percentage': percentage,
|
||||
'message': message or f'处理批次 {current_batch}/{progress["total_batches"] or 1}',
|
||||
@ -95,9 +98,10 @@ class FactorProgressState:
|
||||
elapsed = max(0.0, time.time() - start_time)
|
||||
else:
|
||||
elapsed = progress.get('elapsed_time', 0.0) or 0.0
|
||||
total = progress.get('total', 0)
|
||||
progress.update({
|
||||
'current': progress.get('total', 0),
|
||||
'percentage': 100.0 if progress.get('total', 0) else progress.get('percentage', 0.0),
|
||||
'current': total,
|
||||
'percentage': 100.0 if total else progress.get('percentage', 0.0),
|
||||
'status': 'completed',
|
||||
'message': message,
|
||||
'elapsed_time': elapsed,
|
||||
|
||||
@ -230,6 +230,25 @@ def render_factor_calculation() -> None:
|
||||
help="如果勾选,将跳过数据库中已存在的因子计算结果"
|
||||
)
|
||||
|
||||
st.markdown("##### 数据维护")
|
||||
maintenance_col1, maintenance_col2 = st.columns([1, 2])
|
||||
with maintenance_col1:
|
||||
clear_confirm = st.checkbox("确认清空因子表", key="factor_clear_confirm")
|
||||
with maintenance_col2:
|
||||
if st.button("清空因子表数据", disabled=not clear_confirm):
|
||||
try:
|
||||
with db_session() as conn:
|
||||
conn.execute("DELETE FROM factors")
|
||||
st.session_state.pop('factor_calculation_results', None)
|
||||
st.session_state.pop('factor_calculation_error', None)
|
||||
factor_progress.reset()
|
||||
st.success("因子表数据已清空。")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
LOGGER.exception("清空因子表失败", extra={**LOG_EXTRA, "error": str(exc)})
|
||||
st.error(f"清空因子表失败:{exc}")
|
||||
finally:
|
||||
st.session_state['factor_clear_confirm'] = False
|
||||
|
||||
# 5. 开始计算按钮
|
||||
if st.button("开始计算因子", disabled=not selected_factors):
|
||||
# 重置状态
|
||||
|
||||
@ -203,6 +203,7 @@ def render_stock_evaluation() -> None:
|
||||
"IC信息比率": performance.ic_ir,
|
||||
"夏普比率": performance.sharpe_ratio,
|
||||
"换手率": performance.turnover_rate,
|
||||
"有效样本数": performance.sample_size,
|
||||
})
|
||||
|
||||
st.session_state.evaluation_results = results
|
||||
@ -251,6 +252,8 @@ def render_stock_evaluation() -> None:
|
||||
display_df["换手率"] = display_df["换手率"].map(
|
||||
lambda v: "N/A" if v is None else f"{v * 100:.1f}%"
|
||||
)
|
||||
if "有效样本数" in display_df:
|
||||
display_df["有效样本数"] = display_df["有效样本数"].astype(int)
|
||||
st.dataframe(
|
||||
display_df,
|
||||
hide_index=True,
|
||||
@ -260,31 +263,64 @@ def render_stock_evaluation() -> None:
|
||||
st.info("未产生任何因子评估结果。")
|
||||
|
||||
# 绘制IC均值分布
|
||||
ic_means = result_df["IC均值"].astype(float).tolist() if not result_df.empty else []
|
||||
factor_names = result_df["因子"].tolist() if not result_df.empty else []
|
||||
ic_series = result_df["IC均值"].astype(float) if not result_df.empty else pd.Series(dtype=float)
|
||||
if "有效样本数" in result_df:
|
||||
sample_series = result_df["有效样本数"].astype(int)
|
||||
ic_series = ic_series.where(sample_series > 0)
|
||||
ic_means = ic_series.tolist()
|
||||
chart_df = pd.DataFrame({
|
||||
"因子": [r["因子"] for r in results],
|
||||
"因子": factor_names,
|
||||
"IC均值": ic_means
|
||||
})
|
||||
st.bar_chart(chart_df.set_index("因子"))
|
||||
|
||||
if not ic_means:
|
||||
if not factor_names:
|
||||
st.info("暂无足够的 IC 数据,无法生成股票评分。")
|
||||
return
|
||||
|
||||
ic_array = np.array(ic_means, dtype=float)
|
||||
usable_indices = [idx for idx, value in enumerate(ic_array) if np.isfinite(value)]
|
||||
if not usable_indices:
|
||||
st.info("所有因子 IC 均值均不可用,请先补充因子数据再评估。")
|
||||
return
|
||||
|
||||
usable_factors = [factor_names[idx] for idx in usable_indices]
|
||||
usable_ic = ic_array[usable_indices]
|
||||
|
||||
dropped_factors = [factor_names[idx] for idx, value in enumerate(ic_array) if not np.isfinite(value)]
|
||||
if dropped_factors:
|
||||
st.caption(f"已忽略缺少有效 IC 数据的因子:{', '.join(dropped_factors)}")
|
||||
|
||||
with st.spinner("正在生成股票评分..."):
|
||||
if all(mean == 0 for mean in ic_means):
|
||||
factor_weights = [1.0 / len(ic_means)] * len(ic_means)
|
||||
LOGGER.info("所有因子IC均值均为零,使用均匀权重", extra=LOG_EXTRA)
|
||||
if np.all(np.abs(usable_ic) <= 1e-9):
|
||||
factor_weights = np.full(usable_ic.shape, 1.0 / usable_ic.size, dtype=float)
|
||||
LOGGER.info("有效因子IC均值均为零,使用均匀权重", extra=LOG_EXTRA)
|
||||
else:
|
||||
abs_sum = sum(abs(m) for m in ic_means) or 1.0
|
||||
factor_weights = [m / abs_sum for m in ic_means]
|
||||
LOGGER.info("使用IC均值作为权重: %s", factor_weights, extra=LOG_EXTRA)
|
||||
abs_sum = float(np.sum(np.abs(usable_ic)))
|
||||
if abs_sum <= 1e-9:
|
||||
factor_weights = np.full(usable_ic.shape, 1.0 / usable_ic.size, dtype=float)
|
||||
LOGGER.info("有效因子IC均值绝对和过小,使用均匀权重", extra=LOG_EXTRA)
|
||||
else:
|
||||
factor_weights = usable_ic / abs_sum
|
||||
LOGGER.info("使用IC均值作为权重: %s", factor_weights.tolist(), extra=LOG_EXTRA)
|
||||
|
||||
weight_mask = np.abs(factor_weights) > 1e-6
|
||||
filtered_factors = [name for name, flag in zip(usable_factors, weight_mask) if flag]
|
||||
filtered_weights = [float(weight) for weight, flag in zip(factor_weights, weight_mask) if flag]
|
||||
|
||||
if not filtered_factors:
|
||||
st.info("因子权重有效值均为零,无法生成股票评分。")
|
||||
return
|
||||
if len(filtered_factors) < len(usable_factors):
|
||||
dropped_names = [name for name, flag in zip(usable_factors, weight_mask) if not flag]
|
||||
LOGGER.info("已忽略权重为零的因子:%s", dropped_names, extra=LOG_EXTRA)
|
||||
|
||||
scores = _calculate_stock_scores(
|
||||
universe,
|
||||
selected_factors,
|
||||
filtered_factors,
|
||||
end_date,
|
||||
factor_weights
|
||||
filtered_weights,
|
||||
)
|
||||
|
||||
if scores:
|
||||
@ -319,6 +355,18 @@ def _calculate_stock_scores(
|
||||
LOGGER = get_logger(__name__)
|
||||
LOG_EXTRA = {"stage": "stock_evaluation"}
|
||||
|
||||
if not factors:
|
||||
LOGGER.warning("因子列表为空,无法计算股票评分", extra=LOG_EXTRA)
|
||||
return []
|
||||
if len(factors) != len(factor_weights):
|
||||
LOGGER.error(
|
||||
"因子数量与权重数量不一致 factors=%s weights=%s",
|
||||
len(factors),
|
||||
len(factor_weights),
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
return []
|
||||
|
||||
broker = DataBroker()
|
||||
trade_date_str = eval_date.strftime("%Y%m%d")
|
||||
|
||||
|
||||
@ -1413,20 +1413,21 @@ class DataBroker:
|
||||
|
||||
try:
|
||||
# 获取股票基本信息
|
||||
info = self.fetch_latest(
|
||||
raw_info = self.fetch_latest(
|
||||
ts_code=ts_code,
|
||||
trade_date=trade_date,
|
||||
fields=["stock_basic.name", "stock_basic.industry"]
|
||||
)
|
||||
|
||||
if not info:
|
||||
if not raw_info:
|
||||
return None
|
||||
|
||||
# 添加股票代码
|
||||
result = {"ts_code": ts_code}
|
||||
result.update(info)
|
||||
|
||||
return result
|
||||
info: Dict[str, Any] = {"ts_code": ts_code}
|
||||
for key, value in raw_info.items():
|
||||
if key == "ts_code":
|
||||
continue
|
||||
alias = key.split(".", 1)[-1] if isinstance(key, str) and "." in key else key
|
||||
info[alias] = value
|
||||
return info
|
||||
except Exception as exc:
|
||||
LOGGER.debug(
|
||||
"获取股票信息失败 ts_code=%s err=%s",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user