"""Stock screening and evaluation view.

Streamlit page that evaluates user-selected factors over a lookback window,
turns their IC means into factor weights, scores the chosen stock universe,
and can write the top-ranked stocks into the ``investment_pool`` table.
"""
import json
import sqlite3
from datetime import date, datetime, timedelta
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
import streamlit as st

from app.features.evaluation import evaluate_factor
from app.features.factors import DEFAULT_FACTORS
from app.features.validation import check_data_sufficiency
from app.utils.config import get_config
from app.utils.data_access import DataBroker
from app.utils.db import db_session
from app.utils.logging import get_logger

LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "stock_eval"}

# Log context for the page-rendering code.
_UI_LOG_EXTRA = {"stage": "stock_evaluation_ui"}
# Log context for the per-stock scoring and pool-writing routines.
_SCORING_LOG_EXTRA = {"stage": "stock_evaluation"}

# UI factor groups: group label -> factor-name prefix used to pick members
# from DEFAULT_FACTORS. Groups are rendered in this order.
_FACTOR_GROUP_PREFIXES: Dict[str, str] = {
    "动量类因子": "mom_",
    "波动率类因子": "volat_",
    "换手率类因子": "turn_",
    "估值类因子": "val_",
    "量价类因子": "volume_",
    "市场类因子": "market_",
}

# Factors that are pre-checked in the UI. Note: "risk_penalty" matches none of
# the group prefixes above, so it is never rendered as a checkbox.
_DEFAULT_SELECTED_FACTORS = frozenset(
    {
        "mom_5",  # 5-day momentum
        "mom_20",  # 20-day momentum
        "mom_60",  # 60-day momentum
        "volat_20",  # 20-day volatility
        "turn_5",  # 5-day turnover
        "turn_20",  # 20-day turnover
        "val_pe_score",  # PE score
        "val_pb_score",  # PB score
        "volume_ratio_score",  # volume-ratio score
        "risk_penalty",  # risk penalty term
    }
)

# Index-based stock pools and the index code whose constituents define them.
_INDEX_CODES: Dict[str, str] = {
    "沪深300": "000300.SH",
    "中证500": "000905.SH",
    "中证1000": "000852.SH",
}

# Columns that may be missing from older investment_pool tables, each paired
# with the DDL statements to try in order (the first that succeeds wins).
_OPTIONAL_POOL_COLUMNS: Tuple[Tuple[str, Tuple[str, ...]], ...] = (
    ("name", ("ALTER TABLE investment_pool ADD COLUMN name TEXT",)),
    ("industry", ("ALTER TABLE investment_pool ADD COLUMN industry TEXT",)),
    (
        "created_at",
        (
            "ALTER TABLE investment_pool ADD COLUMN created_at TEXT DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now'))",
            # SQLite rejects parenthesised (non-constant) defaults in
            # ADD COLUMN, so fall back to a plain TEXT column.
            "ALTER TABLE investment_pool ADD COLUMN created_at TEXT",
        ),
    ),
)


def _ensure_investment_pool_schema(conn: sqlite3.Connection) -> None:
    """Add optional columns that an older ``investment_pool`` table lacks.

    Best-effort: every SQLite error is swallowed so schema upgrades never block
    the caller. Does nothing when the table does not exist (``PRAGMA
    table_info`` then yields no rows).
    """
    try:
        info = conn.execute("PRAGMA table_info(investment_pool)").fetchall()
    except sqlite3.Error:
        return
    existing = {
        row["name"] if isinstance(row, sqlite3.Row) else row[1]
        for row in info
        if row is not None
    }
    if not existing:
        return
    for column, statements in _OPTIONAL_POOL_COLUMNS:
        if column in existing:
            continue
        for ddl in statements:
            try:
                conn.execute(ddl)
            except sqlite3.Error:
                continue  # try the next variant, if any
            break


def _get_latest_trading_date() -> date:
    """Return the most recent ``daily_basic`` trade date not after today.

    Falls back to yesterday when no such row exists, the query raises a
    SQLite error, or the stored value is not a ``YYYYMMDD`` string/number.
    """
    now = datetime.now()
    try:
        with db_session(read_only=True) as conn:
            row = conn.execute(
                """
                SELECT trade_date FROM daily_basic
                WHERE trade_date <= :today
                GROUP BY trade_date
                ORDER BY trade_date DESC
                LIMIT 1
                """,
                {"today": now.strftime("%Y%m%d")},
            ).fetchone()
    except sqlite3.Error:
        LOGGER.warning("查询最新交易日失败,使用昨日作为默认日期", exc_info=True, extra=_UI_LOG_EXTRA)
        row = None
    if row and row[0]:
        try:
            return datetime.strptime(str(row[0]), "%Y%m%d").date()
        except ValueError:
            LOGGER.warning("无法解析交易日期 %r", row[0], extra=_UI_LOG_EXTRA)
    return now.date() - timedelta(days=1)


def _normalize_universe(universe: Optional[List[str]]) -> List[str]:
    """Strip, upper-case and de-duplicate stock codes, keeping first-seen order.

    Blank or ``None`` entries are dropped; ``None``/empty input yields ``[]``.
    """
    if not universe:
        return []
    cleaned = ((code or "").strip().upper() for code in universe)
    return list(dict.fromkeys(code for code in cleaned if code))


def _format_metric(value: object, template: str = "{:.4f}", scale: float = 1.0) -> str:
    """Format a numeric metric; ``None``, non-numeric and NaN/inf become ``N/A``.

    pandas converts ``None`` to NaN inside numeric columns, so a plain
    ``is None`` check is not enough to catch missing values.
    """
    try:
        number = float(value)
    except (TypeError, ValueError):
        return "N/A"
    if not np.isfinite(number):
        return "N/A"
    return template.format(number * scale)


def _clean_text(value: object) -> str:
    """Return ``value`` as text, mapping ``None`` and float NaN to ``""``."""
    if value is None:
        return ""
    if isinstance(value, float) and not np.isfinite(value):
        return ""
    return str(value)


def _render_factor_selection() -> List[str]:
    """Render one checkbox per factor (grouped by prefix); return checked names."""
    st.markdown("##### 评估因子选择")
    selected: List[str] = []
    for group_name, prefix in _FACTOR_GROUP_PREFIXES.items():
        members = [factor for factor in DEFAULT_FACTORS if factor.name.startswith(prefix)]
        if not members:
            continue
        st.markdown(f"###### {group_name}")
        cols = st.columns(3)
        for i, factor in enumerate(members):
            checked = cols[i % 3].checkbox(
                factor.name,
                value=factor.name in _DEFAULT_SELECTED_FACTORS,
                help=getattr(factor, "description", None),
            )
            if checked:
                selected.append(factor.name)
    return selected


def _render_universe_selection(end_date: date) -> Tuple[Optional[List[str]], bool]:
    """Render the stock-pool picker and resolve it to a list of codes.

    Returns ``(universe, ok)``. ``universe`` is ``None`` only for "全部A股"
    (score every stock). ``ok`` is False when a restricted pool resolved to no
    codes; the caller must then stop instead of silently falling back to the
    whole market.
    """
    st.markdown("##### 股票池范围")
    pool_type = st.radio(
        "选择股票池",
        ["沪深300", "中证500", "中证1000", "全部A股", "自定义"],
        index=0,  # CSI 300 by default
        horizontal=True,
    )
    if pool_type == "全部A股":
        return None, True

    if pool_type == "自定义":
        custom_codes = st.text_area(
            "输入股票代码列表(每行一个)",
            help="请输入股票代码,每行一个,例如: 000001.SZ",
        )
        universe = _normalize_universe(custom_codes.splitlines() if custom_codes else [])
        if not universe:
            st.info("请输入至少一个股票代码")
            return None, False
        return universe, True

    # Index constituents as of the evaluation end date.
    broker = DataBroker()
    universe = _normalize_universe(
        broker.get_index_stocks(_INDEX_CODES[pool_type], end_date.strftime("%Y%m%d"))
    )
    if not universe:
        st.warning(f"未获取到{pool_type}成分股数据,无法按该股票池评估。")
        return None, False
    return universe, True


def _init_evaluation_state() -> None:
    """Seed the session-state keys used by the evaluation flow."""
    defaults = {
        "evaluation_results": None,
        "evaluation_status": "idle",  # idle, running, completed, error
        "current_factor": "",
        "progress": 0,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
    # Evaluation runs synchronously inside a single script pass, so it can
    # never legitimately still be "running" when a new pass starts. A leftover
    # "running" means the previous pass was stopped mid-run; without this
    # reset the evaluate button used to stay disabled for the whole session.
    if st.session_state.evaluation_status == "running":
        st.session_state.evaluation_status = "idle"


def _render_evaluation_status(placeholder) -> None:
    """Draw the current evaluation status into ``placeholder`` (an ``st.empty``)."""
    status = st.session_state.get("evaluation_status", "idle")
    if status == "running":
        with placeholder.container():
            st.info(f"正在评估因子: {st.session_state.current_factor}")
            st.progress(st.session_state.progress / 100)
    elif status == "completed":
        placeholder.success("因子评估完成!")
    elif status == "error":
        placeholder.error(f"评估失败: {st.session_state.get('evaluation_error', '')}")
    else:
        placeholder.empty()


def _run_factor_evaluation(
    selected_factors: List[str],
    start_date: date,
    end_date: date,
    universe: Optional[List[str]],
    status_box,
) -> None:
    """Evaluate every selected factor and store the metrics in session state.

    Progress is drawn live into ``status_box``. Any exception is logged and
    recorded as ``evaluation_status == "error"`` (this is the UI boundary).
    """
    LOGGER.info(
        "开始因子评估 因子数量=%s 评估日期=%s 至 %s",
        len(selected_factors),
        start_date,
        end_date,
        extra=_UI_LOG_EXTRA,
    )
    st.session_state.evaluation_results = None
    st.session_state.evaluation_status = "running"
    st.session_state.progress = 0
    st.session_state.pop("evaluation_error", None)

    total = len(selected_factors)
    results: List[Dict[str, object]] = []
    try:
        for index, factor_name in enumerate(selected_factors):
            st.session_state.current_factor = factor_name
            # Fraction of factors already finished when this one starts.
            st.session_state.progress = (index / total) * 100
            _render_evaluation_status(status_box)
            performance = evaluate_factor(factor_name, start_date, end_date, universe=universe)
            results.append(
                {
                    "因子": factor_name,
                    "IC均值": performance.ic_mean,
                    "RankIC均值": performance.rank_ic_mean,
                    "IC信息比率": performance.ic_ir,
                    "夏普比率": performance.sharpe_ratio,
                    "换手率": performance.turnover_rate,
                    "有效样本数": performance.sample_size,
                }
            )
    except Exception as exc:  # UI boundary: report instead of crashing the page
        LOGGER.exception(
            "因子评估失败 因子=%s", st.session_state.current_factor, extra=_UI_LOG_EXTRA
        )
        st.session_state.evaluation_status = "error"
        st.session_state.evaluation_error = str(exc)
        return

    st.session_state.evaluation_results = results
    st.session_state.evaluation_status = "completed"
    st.session_state.progress = 100


def _format_evaluation_table(result_df: pd.DataFrame) -> pd.DataFrame:
    """Return a display copy of the factor metrics with formatted columns."""
    display_df = result_df.copy()
    for col in ("IC均值", "RankIC均值", "IC信息比率", "夏普比率"):
        if col in display_df:
            display_df[col] = display_df[col].map(_format_metric)
    if "换手率" in display_df:
        display_df["换手率"] = display_df["换手率"].map(
            lambda v: _format_metric(v, "{:.1f}%", 100.0)
        )
    if "有效样本数" in display_df:
        # Coerce first: astype(int) raises on NaN/None sample sizes.
        display_df["有效样本数"] = (
            pd.to_numeric(display_df["有效样本数"], errors="coerce").fillna(0).astype(int)
        )
    return display_df


def _ic_weights(usable_ic: np.ndarray) -> np.ndarray:
    """Map finite IC means to signed weights whose absolute values sum to 1.

    Uses uniform weights when every IC is numerically zero.
    """
    if np.all(np.abs(usable_ic) <= 1e-9):
        LOGGER.info("有效因子IC均值均为零,使用均匀权重", extra=_UI_LOG_EXTRA)
        return np.full(usable_ic.shape, 1.0 / usable_ic.size, dtype=float)
    weights = usable_ic / float(np.sum(np.abs(usable_ic)))
    LOGGER.info("使用IC均值作为权重: %s", weights.tolist(), extra=_UI_LOG_EXTRA)
    return weights


def _render_stock_scores(
    factor_names: List[str],
    ic_means: Sequence[float],
    universe: Optional[List[str]],
    end_date: date,
) -> None:
    """Weight factors by IC, score the universe and render the Top 20 table.

    Also renders the button that writes the Top 20 into the investment pool.
    """
    if not factor_names:
        st.info("暂无足够的 IC 数据,无法生成股票评分。")
        return

    ic_array = np.asarray(ic_means, dtype=float)
    finite_mask = np.isfinite(ic_array)
    if not finite_mask.any():
        st.info("所有因子 IC 均值均不可用,请先补充因子数据再评估。")
        return
    usable_factors = [name for name, ok in zip(factor_names, finite_mask) if ok]
    dropped_factors = [name for name, ok in zip(factor_names, finite_mask) if not ok]
    if dropped_factors:
        st.caption(f"已忽略缺少有效 IC 数据的因子:{', '.join(dropped_factors)}")

    with st.spinner("正在生成股票评分..."):
        factor_weights = _ic_weights(ic_array[finite_mask])
        # Negligible weights only add per-stock lookups, so drop them.
        keep_mask = np.abs(factor_weights) > 1e-6
        filtered_factors = [name for name, keep in zip(usable_factors, keep_mask) if keep]
        filtered_weights = [float(w) for w, keep in zip(factor_weights, keep_mask) if keep]
        if not filtered_factors:
            st.info("因子权重有效值均为零,无法生成股票评分。")
            return
        if len(filtered_factors) < len(usable_factors):
            LOGGER.info(
                "已忽略权重为零的因子:%s",
                [name for name, keep in zip(usable_factors, keep_mask) if not keep],
                extra=_UI_LOG_EXTRA,
            )
        scores = _calculate_stock_scores(universe, filtered_factors, end_date, filtered_weights)

    if not scores:
        st.info("无法根据当前因子权重生成有效的股票评分结果。")
        return

    st.markdown("##### 股票综合评分 (Top 20)")
    score_df = pd.DataFrame(scores).sort_values("综合评分", ascending=False)
    top_df = score_df.head(20).reset_index(drop=True)
    display_scores = top_df.copy()
    display_scores["综合评分"] = display_scores["综合评分"].map(_format_metric)
    st.dataframe(display_scores, hide_index=True, width="stretch")

    if st.button("将Top 20股票加入股票池"):
        try:
            _add_to_stock_pool(top_df, end_date)
        except Exception:  # UI boundary: surface DB/broker failures to the user
            LOGGER.exception("写入投资池失败", extra=_UI_LOG_EXTRA)
            st.error("加入股票池失败,请查看日志。")
        else:
            st.success("已成功将选中股票加入股票池!")


def _render_evaluation_results(
    results: List[Dict[str, object]],
    universe: Optional[List[str]],
    end_date: date,
) -> None:
    """Render the factor metrics table, the IC bar chart and the stock scores."""
    st.markdown("##### 因子评估结果")
    result_df = pd.DataFrame(results)
    if result_df.empty:
        st.info("未产生任何因子评估结果。")
        return
    st.dataframe(_format_evaluation_table(result_df), hide_index=True, width="stretch")

    # IC means feed both the chart and the weights; factors with no valid
    # samples are masked to NaN so they are excluded from weighting.
    factor_names = result_df["因子"].tolist()
    ic_series = pd.to_numeric(result_df["IC均值"], errors="coerce").astype(float)
    if "有效样本数" in result_df:
        sample_series = pd.to_numeric(result_df["有效样本数"], errors="coerce").fillna(0)
        ic_series = ic_series.where(sample_series > 0)
    ic_means = ic_series.tolist()
    chart_df = pd.DataFrame({"因子": factor_names, "IC均值": ic_means})
    st.bar_chart(chart_df.set_index("因子"))

    _render_stock_scores(factor_names, ic_means, universe, end_date)


def render_stock_evaluation() -> None:
    """Render the stock screening & evaluation page."""
    st.subheader("股票筛选与评估")
    LOGGER.info("股票筛选与评估页面已加载", extra=_UI_LOG_EXTRA)

    # 1. Evaluation window.
    col1, col2 = st.columns(2)
    with col1:
        latest_date = _get_latest_trading_date()
        end_date = st.date_input("评估截止日期", value=latest_date, help="选择评估的截止日期")
    with col2:
        lookback_days = st.slider(
            "回溯天数",
            min_value=30,
            max_value=360,
            value=180,
            step=30,
            help="选择评估的历史数据长度",
        )
    start_date = end_date - timedelta(days=lookback_days)

    # 2. Factor selection.
    selected_factors = _render_factor_selection()
    if not selected_factors:
        st.warning("请至少选择一个评估因子")
        return

    # 3. Stock universe.
    universe, universe_ok = _render_universe_selection(end_date)
    if not universe_ok:
        return

    # 4. Evaluation. The status placeholder sits above the button and is
    # filled during and after the run, so progress and the final message show
    # up in the same script pass instead of one rerun later.
    _init_evaluation_state()
    status_box = st.empty()
    if st.button("开始评估"):
        _run_factor_evaluation(selected_factors, start_date, end_date, universe, status_box)
    _render_evaluation_status(status_box)

    results = st.session_state.evaluation_results
    if results:
        _render_evaluation_results(results, universe, end_date)


def _extract_factor_values(
    payload: Dict[str, object], fields: Sequence[str]
) -> Optional[List[float]]:
    """Return the finite float value of every field, or ``None`` if any is unusable.

    A field is unusable when it is missing/``None``, not convertible to float,
    or NaN/inf (which would otherwise yield a NaN composite score).
    ``payload`` is whatever ``DataBroker.fetch_latest`` returns; it is only
    accessed via ``.get`` here (presumably a dict — confirm).
    """
    values: List[float] = []
    for field in fields:
        raw = payload.get(field)
        if raw is None:
            return None
        try:
            number = float(raw)
        except (TypeError, ValueError):
            return None
        if not np.isfinite(number):
            return None
        values.append(number)
    return values


def _calculate_stock_scores(
    universe: Optional[List[str]],
    factors: List[str],
    eval_date: date,
    factor_weights: List[float],
) -> List[Dict[str, object]]:
    """Compute a weighted composite score for each stock.

    Args:
        universe: Stock codes to score; falsy means every stock returned by
            ``DataBroker.get_all_stocks`` for ``eval_date``.
        factors: Factor names, read as ``factors.<name>`` fields.
        eval_date: Date whose latest factor values are used.
        factor_weights: One weight per factor; re-normalised so absolute
            values sum to 1 (uniform if they are all zero).

    Returns:
        One dict per scored stock with keys ``股票代码``, ``股票名称``,
        ``行业`` and ``综合评分``. Stocks with insufficient data, missing or
        non-finite factor values, or no stock info are skipped.
    """
    log_extra = _SCORING_LOG_EXTRA
    if not factors:
        LOGGER.warning("因子列表为空,无法计算股票评分", extra=log_extra)
        return []
    if len(factors) != len(factor_weights):
        LOGGER.error(
            "因子数量与权重数量不一致 factors=%s weights=%s",
            len(factors),
            len(factor_weights),
            extra=log_extra,
        )
        return []

    broker = DataBroker()
    trade_date_str = eval_date.strftime("%Y%m%d")
    LOGGER.info(
        "开始股票评估评估日期=%s 因子数量=%d 权重=%s",
        eval_date.strftime("%Y-%m-%d"),
        len(factors),
        factor_weights,
        extra=log_extra,
    )

    weights = np.asarray(factor_weights, dtype=float)
    abs_sum = np.sum(np.abs(weights))
    if abs_sum > 0:
        weights = weights / abs_sum
    else:
        # All-zero (or NaN) weights: fall back to a uniform blend.
        weights = np.full(weights.shape, 1.0 / weights.size)

    stocks = universe or broker.get_all_stocks(trade_date_str)
    if not stocks:
        LOGGER.warning("股票列表为空,无法生成评分", extra=log_extra)
        return []
    LOGGER.info(
        "获取股票列表 universe_size=%d total_stocks=%d",
        len(universe) if universe else 0,
        len(stocks),
        extra=log_extra,
    )

    results: List[Dict[str, object]] = []
    evaluated_count = 0
    skipped_count = 0
    factor_fields = [f"factors.{name}" for name in factors]
    for ts_code in stocks:
        if not check_data_sufficiency(ts_code, trade_date_str):
            skipped_count += 1
            continue
        latest_payload = broker.fetch_latest(
            ts_code,
            trade_date_str,
            factor_fields,
            auto_refresh=False,
        )
        factor_values = (
            _extract_factor_values(latest_payload, factor_fields) if latest_payload else None
        )
        if factor_values is None:
            skipped_count += 1
            continue
        info = broker.get_stock_info(ts_code, trade_date_str)
        if not info:
            skipped_count += 1
            continue
        evaluated_count += 1
        results.append(
            {
                "股票代码": ts_code,
                "股票名称": info.get("name", ""),
                "行业": info.get("industry", ""),
                "综合评分": float(np.dot(factor_values, weights)),
            }
        )

    LOGGER.info(
        "股票评估完成 总股票数=%d 已评估=%d 跳过=%d 结果数=%d",
        len(stocks),
        evaluated_count,
        skipped_count,
        len(results),
        extra=log_extra,
    )
    return results


def _add_to_stock_pool(score_df: pd.DataFrame, eval_date: date) -> None:
    """Write ranked stock scores into ``investment_pool`` for ``eval_date``.

    Rows are ranked in ``score_df`` order (rank 1 = first row). Names and
    industries are taken from the ``股票名称``/``行业`` columns when present,
    otherwise looked up via ``DataBroker.get_stock_info``. An empty
    ``score_df`` is a no-op, so it cannot wipe the day's existing pool.
    """
    trade_date = eval_date.strftime("%Y%m%d")
    ranked_df = score_df.reset_index(drop=True)
    if ranked_df.empty:
        LOGGER.warning("评分结果为空,未写入投资池", extra=_SCORING_LOG_EXTRA)
        return

    tags = json.dumps(["stock_evaluation", "top20"], ensure_ascii=False)
    broker: Optional[DataBroker] = None  # created lazily, only if a lookup is needed
    payload: List[tuple] = []
    for rank, (_, row) in enumerate(ranked_df.iterrows(), start=1):
        ts_code = row["股票代码"]
        score = float(row["综合评分"])
        metadata = json.dumps(
            {"source": "stock_evaluation", "rank": rank, "score": score},
            ensure_ascii=False,
        )
        stock_name = _clean_text(row.get("股票名称"))
        stock_industry = _clean_text(row.get("行业"))
        if not stock_name and not stock_industry:
            if broker is None:
                broker = DataBroker()
            stock_info = broker.get_stock_info(ts_code, trade_date) or {}
            stock_name = _clean_text(stock_info.get("name", ""))
            stock_industry = _clean_text(stock_info.get("industry", ""))
        payload.append(
            (
                trade_date,
                ts_code,
                score,
                "candidate",
                "factor_evaluation_top20",
                tags,
                metadata,
                stock_name,
                stock_industry,
            )
        )

    with db_session() as conn:
        _ensure_investment_pool_schema(conn)
        # NOTE(review): this clears *every* investment_pool row for the trade
        # date, not only rows written by this page — confirm that replacing the
        # whole day's pool is intended.
        conn.execute("DELETE FROM investment_pool WHERE trade_date = ?", (trade_date,))
        conn.executemany(
            """
            INSERT INTO investment_pool (
                trade_date, ts_code, score, status, rationale, tags, metadata, name, industry
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            payload,
        )