diff --git a/README.md b/README.md index 0e87572..01e0a4e 100644 --- a/README.md +++ b/README.md @@ -151,5 +151,5 @@ Streamlit `自检测试` 页签提供: TODO 1. 在选股时,因子都已经提前算好,不需要再计算了,直接用就行。 2. 因子计算的公式再确认下 -3. 审查整个项目的代码逻辑,从main.py开始,逐字逐句检查。如一些重复的检查可以去掉;未实现的功能请标记TODO,并给出实现思路;错误的、低效率的调用请修正;代码结构性的问题请指出。 +3. 审查整个项目的代码逻辑,从app/ui/streamlit_app.py开始,逐字逐句检查。如一些重复的安全检查可以去掉;明显果实的临时性代码请删除掉;未实现的功能请标记TODO,并给出实现思路;错误的、低效率的调用请修正;代码结构性的问题请指出并尝试修正;复杂不清晰的代码结构请尝试重构; 4. 梳理整个项目的所有业务逻辑。针对每个业务,从业务实现角度评估代码功能是否存在问题,是否需要优化,是否需要重构。 diff --git a/app/features/evaluation.py b/app/features/evaluation.py index 5eea557..3a83d11 100644 --- a/app/features/evaluation.py +++ b/app/features/evaluation.py @@ -9,7 +9,8 @@ from app.features.factors import ( DEFAULT_FACTORS, FactorResult, FactorSpec, - compute_factor_range + compute_factor_range, + lookup_factor_spec, ) from app.utils.data_access import DataBroker from app.utils.logging import get_logger @@ -90,14 +91,14 @@ def evaluate_factor( # ) try: - # 计算因子值 - # 设置 skip_existing=False,确保即使因子已存在也会重新计算 + spec = lookup_factor_spec(factor_name) or FactorSpec(factor_name, 0) + factor_results = compute_factor_range( start_date, end_date, - factors=[FactorSpec(factor_name, 0)], + factors=[spec], ts_codes=universe, - skip_existing=False + skip_existing=True, ) # 因子计算完成(在异步线程中不直接访问factor_progress) diff --git a/app/features/factors.py b/app/features/factors.py index 4669e0b..63146db 100644 --- a/app/features/factors.py +++ b/app/features/factors.py @@ -90,6 +90,17 @@ DEFAULT_FACTORS: List[FactorSpec] = [ FactorSpec("risk_penalty", 0), # 风险惩罚因子 ] +_FACTOR_SPEC_MAP: Dict[str, FactorSpec] = {spec.name: spec for spec in DEFAULT_FACTORS} + + +def lookup_factor_spec(name: str) -> Optional[FactorSpec]: + """Return a copy of the registered ``FactorSpec`` for ``name`` if available.""" + + base = _FACTOR_SPEC_MAP.get(name) + if base is None: + return None + return FactorSpec(name=base.name, window=base.window) + def compute_factors( trade_date: date, @@ -304,30 +315,33 @@ def _existing_factor_codes_with_factors(trade_date: str, factor_names: List[str] if not factor_names: return {} - # 构建检查条件 - conditions = [] - for name in factor_names: - conditions.append(f"json_extract(factors, '$.{name}') IS NOT NULL") - condition_str = " AND ".join(conditions) - - # 构建SQL查询 - query = """ - SELECT ts_code - FROM factors - WHERE trade_date = ? - AND """ + condition_str + """ - GROUP BY ts_code - """ - + valid_names = [ + name + for name in factor_names + if isinstance(name, str) and _IDENTIFIER_RE.match(name) + ] + if not valid_names: + return {} + with db_session(read_only=True) as conn: + columns = { + row["name"] + for row in conn.execute("PRAGMA table_info(factors)").fetchall() + } + selected = [name for name in valid_names if name in columns] + if not selected: + return {} + + predicates = " AND ".join(f"{col} IS NOT NULL" for col in selected) + query = ( + "SELECT ts_code FROM factors " + "WHERE trade_date = ? AND " + f"{predicates} " + "GROUP BY ts_code" + ) rows = conn.execute(query, (trade_date,)).fetchall() - - # 返回结果 - result = {} - for row in rows: - result[row["ts_code"]] = True - - return result + + return {row["ts_code"]: True for row in rows if row and row["ts_code"]} def _list_trade_dates( diff --git a/app/ui/progress_state.py b/app/ui/progress_state.py index 46012ea..1d1833f 100644 --- a/app/ui/progress_state.py +++ b/app/ui/progress_state.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import Optional, Dict, Any +import time import streamlit as st @@ -25,7 +26,7 @@ class FactorProgressState: 'status': 'idle', # idle, running, completed, error 'message': '', 'start_time': None, - 'elapsed_time': 0 + 'elapsed_time': 0.0, } def start_calculation(self, total_securities: int, total_batches: int) -> None: @@ -35,16 +36,17 @@ class FactorProgressState: total_securities: 总证券数量 total_batches: 总批次数 """ + now = time.time() st.session_state.factor_progress.update({ 'current': 0, - 'total': total_securities, + 'total': max(total_securities, 0), 'percentage': 0.0, 'current_batch': 0, - 'total_batches': total_batches, + 'total_batches': max(total_batches, 0), 'status': 'running', 'message': '开始因子计算...', - 'start_time': st.session_state.get('factor_progress', {}).get('start_time'), - 'elapsed_time': 0 + 'start_time': now, + 'elapsed_time': 0.0, }) def update_progress(self, current_securities: int, current_batch: int, @@ -59,18 +61,26 @@ class FactorProgressState: progress = st.session_state.factor_progress # 计算百分比 - if progress['total'] > 0: - percentage = (current_securities / progress['total']) * 100 + total = progress.get('total', 0) or 0 + if total > 0: + percentage = (current_securities / total) * 100 else: percentage = 0.0 + + start_time = progress.get('start_time') + if isinstance(start_time, (int, float)): + elapsed = max(0.0, time.time() - start_time) + else: + elapsed = 0.0 # 更新状态 progress.update({ 'current': current_securities, 'current_batch': current_batch, 'percentage': percentage, - 'message': message or f'处理批次 {current_batch}/{progress["total_batches"]}', - 'status': 'running' + 'message': message or f'处理批次 {current_batch}/{progress["total_batches"] or 1}', + 'status': 'running', + 'elapsed_time': elapsed, }) def complete_calculation(self, message: str = '因子计算完成') -> None: @@ -80,11 +90,17 @@ class FactorProgressState: message: 完成消息 """ progress = st.session_state.factor_progress + start_time = progress.get('start_time') + if isinstance(start_time, (int, float)): + elapsed = max(0.0, time.time() - start_time) + else: + elapsed = progress.get('elapsed_time', 0.0) or 0.0 progress.update({ - 'current': progress['total'], - 'percentage': 100.0, + 'current': progress.get('total', 0), + 'percentage': 100.0 if progress.get('total', 0) else progress.get('percentage', 0.0), 'status': 'completed', - 'message': message + 'message': message, + 'elapsed_time': elapsed, }) def error_occurred(self, error_message: str) -> None: @@ -93,9 +109,16 @@ class FactorProgressState: Args: error_message: 错误消息 """ - st.session_state.factor_progress.update({ + progress = st.session_state.factor_progress + start_time = progress.get('start_time') + if isinstance(start_time, (int, float)): + elapsed = max(0.0, time.time() - start_time) + else: + elapsed = progress.get('elapsed_time', 0.0) or 0.0 + progress.update({ 'status': 'error', - 'message': f'错误: {error_message}' + 'message': f'错误: {error_message}', + 'elapsed_time': elapsed, }) def get_progress_info(self) -> Dict[str, Any]: @@ -118,7 +141,7 @@ class FactorProgressState: 'status': 'idle', 'message': '', 'start_time': None, - 'elapsed_time': 0 + 'elapsed_time': 0.0, } @@ -199,4 +222,4 @@ def is_factor_calculation_running() -> bool: Returns: 是否正在进行因子计算 """ - return factor_progress.get_progress_info()['status'] == 'running' \ No newline at end of file + return factor_progress.get_progress_info()['status'] == 'running' diff --git a/app/ui/views/backtest.py b/app/ui/views/backtest.py index 1b205b4..fe474f2 100644 --- a/app/ui/views/backtest.py +++ b/app/ui/views/backtest.py @@ -2,16 +2,11 @@ from __future__ import annotations import json -import uuid -from dataclasses import asdict from datetime import date, datetime from typing import Dict, List, Optional -import numpy as np import pandas as pd import plotly.express as px -import requests -from requests.exceptions import RequestException import streamlit as st from app.agents.base import AgentContext diff --git a/app/ui/views/factor_calculation.py b/app/ui/views/factor_calculation.py index ab787ff..f1eb5ee 100644 --- a/app/ui/views/factor_calculation.py +++ b/app/ui/views/factor_calculation.py @@ -1,19 +1,20 @@ """因子计算页面。""" -from datetime import datetime, timedelta -from typing import List, Optional +from datetime import date, datetime, timedelta +from typing import List, Optional, Sequence import streamlit as st -from app.features.factors import compute_factors, DEFAULT_FACTORS, FactorSpec +from app.features.factors import DEFAULT_FACTORS, FactorSpec, compute_factor_range from app.ui.progress_state import factor_progress +from app.ui.shared import LOGGER, LOG_EXTRA from app.utils.data_access import DataBroker from app.utils.db import db_session def _get_latest_trading_date() -> datetime.date: """获取数据库中的最新交易日期""" - with db_session() as session: - result = session.execute( + with db_session(read_only=True) as conn: + result = conn.execute( """ SELECT trade_date FROM daily_basic @@ -34,9 +35,9 @@ def _get_all_stocks() -> List[str]: """获取所有股票代码""" try: # 从daily表获取所有股票代码 - with db_session() as session: + with db_session(read_only=True) as conn: latest_date = _get_latest_trading_date() - result = session.execute( + result = conn.execute( """ SELECT DISTINCT ts_code FROM daily @@ -45,12 +46,88 @@ def _get_all_stocks() -> List[str]: {"trade_date": latest_date.strftime("%Y%m%d")} ).fetchall() - return [row[0] for row in result] if result else [] - except Exception as e: - st.error(f"获取股票列表失败: {str(e)}") + return [row["ts_code"] for row in result if row and row["ts_code"]] if result else [] + except Exception as exc: + LOGGER.exception("获取股票列表失败", extra={**LOG_EXTRA, "error": str(exc)}) + st.error(f"获取股票列表失败: {exc}") return [] +def _normalize_universe(universe: Optional[Sequence[str]]) -> List[str]: + """去重并规范股票代码格式。""" + if not universe: + return [] + seen: dict[str, None] = {} + for code in universe: + normalized = code.strip().upper() + if normalized and normalized not in seen: + seen[normalized] = None + return list(seen.keys()) + + +def _get_trade_dates_between( + start: date, + end: date, + universe: Optional[Sequence[str]] = None, +) -> List[date]: + """获取区间内存在行情数据的交易日期列表。""" + + if end < start: + return [] + + start_str = start.strftime("%Y%m%d") + end_str = end.strftime("%Y%m%d") + params: List[str] = [start_str, end_str] + query = ( + "SELECT DISTINCT trade_date FROM daily " + "WHERE trade_date BETWEEN ? AND ?" + ) + scoped_universe = _normalize_universe(universe) + if scoped_universe: + placeholders = ", ".join("?" for _ in scoped_universe) + query += f" AND ts_code IN ({placeholders})" + params.extend(scoped_universe) + query += " ORDER BY trade_date" + + with db_session(read_only=True) as conn: + rows = conn.execute(query, params).fetchall() + + return [ + datetime.strptime(str(row["trade_date"]), "%Y%m%d").date() + for row in rows + if row and row["trade_date"] + ] + + +def _estimate_total_workload( + trade_dates: Sequence[date], + universe: Optional[Sequence[str]], +) -> int: + """估算本次计算需要处理的证券数量,用于驱动进度条。""" + + trade_days = list(trade_dates) + if not trade_days: + return 0 + + scoped_universe = _normalize_universe(universe) + if scoped_universe: + return len(scoped_universe) * len(trade_days) + + start_str = min(trade_days).strftime("%Y%m%d") + end_str = max(trade_days).strftime("%Y%m%d") + with db_session(read_only=True) as conn: + row = conn.execute( + """ + SELECT COUNT(DISTINCT ts_code) AS cnt + FROM daily + WHERE trade_date BETWEEN ? AND ? + """, + (start_str, end_str), + ).fetchone() + universe_size = int(row["cnt"]) if row and row["cnt"] is not None else 0 + return universe_size * len(trade_days) + + def render_factor_calculation() -> None: """渲染因子计算页面。""" st.subheader("📊 因子计算") @@ -153,97 +230,55 @@ def render_factor_calculation() -> None: help="如果勾选,将跳过数据库中已存在的因子计算结果" ) - # 5. 同步计算函数 - def run_factor_calculation_sync(): - """同步执行因子计算""" - # 计算参数 - total_stocks = len(universe) if universe else len(_get_all_stocks()) - total_batches = len(selected_factors) - - try: - # 执行因子计算 - results = [] - for i, factor in enumerate(selected_factors): - # 更新批次进度 - factor_progress.update_progress( - current_securities=0, - current_batch=i+1, - message=f"正在计算因子: {factor.name}" - ) - - # 计算单个交易日的因子 - current_date = start_date - while current_date <= end_date: - try: - # 计算指定日期的因子 - daily_results = compute_factors( - current_date, - [factor], - ts_codes=universe, - skip_existing=skip_existing - ) - results.extend(daily_results) - - except Exception as e: - # 记录错误但不中断计算 - error_msg = f"计算因子 {factor.name} 在日期 {current_date} 时出错: {str(e)}" - print(f"ERROR: {error_msg}") - - current_date += timedelta(days=1) - - # 计算完成 - factor_progress.complete_calculation(f"因子计算完成!共计算 {len(results)} 条因子记录") - - return { - 'success': True, - 'results': results, - 'factors': [f.name for f in selected_factors], - 'date_range': f"{start_date} 至 {end_date}", - 'stock_count': len(set(r.ts_code for r in results)) if results else 0, - 'message': f"因子计算完成!共计算 {len(results)} 条因子记录" - } - - except Exception as e: - # 计算失败 - factor_progress.error_occurred(f"因子计算失败: {str(e)}") - return { - 'success': False, - 'error': str(e), - 'message': f"因子计算失败: {str(e)}" - } - - # 6. 开始计算按钮 + # 5. 开始计算按钮 if st.button("开始计算因子", disabled=not selected_factors): # 重置状态 - if 'factor_calculation_results' in st.session_state: - st.session_state.factor_calculation_results = None - if 'factor_calculation_error' in st.session_state: - st.session_state.factor_calculation_error = None - - # 初始化进度状态 - total_stocks = len(universe) if universe else len(_get_all_stocks()) + st.session_state.pop('factor_calculation_results', None) + st.session_state.pop('factor_calculation_error', None) + factor_progress.reset() + + scoped_universe = _normalize_universe(universe) or None + trade_dates = _get_trade_dates_between(start_date, end_date, scoped_universe) + if not trade_dates: + st.warning("所选时间窗口内无可用交易日数据,请先执行数据同步。") + return + + total_workload = _estimate_total_workload(trade_dates, scoped_universe) factor_progress.start_calculation( - total_securities=total_stocks, - total_batches=len(selected_factors) + total_securities=max(total_workload, 1), + total_batches=len(trade_dates), ) - - # 直接调用同步计算函数 - result = run_factor_calculation_sync() - - # 处理计算结果 - if result['success']: - st.session_state.factor_calculation_results = { - 'results': result['results'], - 'factors': result['factors'], - 'date_range': result['date_range'], - 'stock_count': result['stock_count'] - } - st.success("✅ 因子计算完成!") - else: - st.session_state.factor_calculation_error = result['error'] - st.error(f"❌ 因子计算失败: {result['error']}") + + with st.spinner("正在计算因子..."): + try: + results = compute_factor_range( + start=min(trade_dates), + end=max(trade_dates), + factors=selected_factors, + ts_codes=scoped_universe, + skip_existing=skip_existing, + ) + except Exception as exc: + LOGGER.exception("因子计算失败", extra={**LOG_EXTRA, "error": str(exc)}) + factor_progress.error_occurred(str(exc)) + st.session_state.factor_calculation_error = str(exc) + st.error(f"❌ 因子计算失败: {exc}") + else: + factor_progress.complete_calculation( + f"因子计算完成,共生成 {len(results)} 条因子记录" + ) + factor_names = [spec.name for spec in selected_factors] + stock_count = len({item.ts_code for item in results}) if results else 0 + st.session_state.factor_calculation_results = { + 'results': results, + 'factors': factor_names, + 'date_range': f"{trade_dates[0]} 至 {trade_dates[-1]}", + 'stock_count': stock_count, + 'trade_days': len(trade_dates), + } + st.success("✅ 因子计算完成!") - # 7. 显示计算结果 + # 6. 显示计算结果 if 'factor_calculation_results' in st.session_state and st.session_state.factor_calculation_results: results = st.session_state.factor_calculation_results @@ -255,7 +290,8 @@ def render_factor_calculation() -> None: with col2: st.metric("涉及股票数量", results['stock_count']) with col3: - st.metric("计算时间范围", results['date_range']) + st.metric("交易日数量", results.get('trade_days', 0)) + st.caption(f"时间范围:{results['date_range']}") # 显示计算详情 with st.expander("查看计算详情"): @@ -279,8 +315,6 @@ def render_factor_calculation() -> None: else: st.info("没有找到因子计算结果") - # 8. 移除异步线程检查逻辑(已改为同步模式) - - # 9. 显示错误信息 + # 7. 显示错误信息 if 'factor_calculation_error' in st.session_state and st.session_state.factor_calculation_error: - st.error(f"❌ 因子计算失败: {st.session_state.factor_calculation_error}") \ No newline at end of file + st.error(f"❌ 因子计算失败: {st.session_state.factor_calculation_error}") diff --git a/app/ui/views/stock_eval.py b/app/ui/views/stock_eval.py index 901db8d..18e8558 100644 --- a/app/ui/views/stock_eval.py +++ b/app/ui/views/stock_eval.py @@ -1,6 +1,7 @@ """股票筛选与评估视图。""" -from datetime import datetime, timedelta -from typing import Dict, List, Optional, Tuple +from datetime import date, datetime, timedelta +from typing import Dict, List, Optional +import json import numpy as np import pandas as pd @@ -15,11 +16,10 @@ from app.utils.db import db_session from app.utils.logging import get_logger -def _get_latest_trading_date() -> datetime.date: +def _get_latest_trading_date() -> date: """获取数据库中的最新交易日期""" - with db_session() as session: - # 获取当前日期的上一个有效交易日 - result = session.execute( + with db_session(read_only=True) as conn: + result = conn.execute( """ SELECT trade_date FROM daily_basic @@ -35,6 +35,19 @@ def _get_latest_trading_date() -> datetime.date: return datetime.strptime(str(result[0]), "%Y%m%d").date() return datetime.now().date() - timedelta(days=1) # 如果查询失败才返回昨天 + +def _normalize_universe(universe: Optional[List[str]]) -> List[str]: + """标准化股票代码列表,去重并转为大写。""" + + if not universe: + return [] + normalized: Dict[str, None] = {} + for code in universe: + candidate = (code or "").strip().upper() + if candidate and candidate not in normalized: + normalized[candidate] = None + return list(normalized.keys()) + def render_stock_evaluation() -> None: """渲染股票筛选与评估页面。""" LOGGER = get_logger(__name__) @@ -141,6 +154,9 @@ def render_stock_evaluation() -> None: index_code, end_date.strftime("%Y%m%d") ) + universe = _normalize_universe(universe) + if universe == []: + universe = None # 4. 评估结果 @@ -167,11 +183,12 @@ def render_stock_evaluation() -> None: ) st.session_state.evaluation_status = 'running' + st.session_state.pop('evaluation_error', None) results = [] for i, factor_name in enumerate(selected_factors): st.session_state.current_factor = factor_name - st.session_state.progress = (i / len(selected_factors)) * 100 + st.session_state.progress = ((i + 1) / len(selected_factors)) * 100 performance = evaluate_factor( factor_name, @@ -181,11 +198,11 @@ def render_stock_evaluation() -> None: ) results.append({ "因子": factor_name, - "IC均值": f"{performance.ic_mean:.4f}", - "RankIC均值": f"{performance.rank_ic_mean:.4f}", - "IC信息比率": f"{performance.ic_ir:.4f}", - "夏普比率": f"{performance.sharpe_ratio:.4f}" if performance.sharpe_ratio else "N/A", - "换手率": f"{performance.turnover_rate*100:.1f}%" if performance.turnover_rate else "N/A" + "IC均值": performance.ic_mean, + "RankIC均值": performance.rank_ic_mean, + "IC信息比率": performance.ic_ir, + "夏普比率": performance.sharpe_ratio, + "换手率": performance.turnover_rate, }) st.session_state.evaluation_results = results @@ -221,71 +238,89 @@ def render_stock_evaluation() -> None: st.markdown("##### 因子评估结果") result_df = pd.DataFrame(results) - st.dataframe( - result_df, - hide_index=True, - width="stretch" - ) + if not result_df.empty: + display_df = result_df.copy() + for col in ["IC均值", "RankIC均值", "IC信息比率"]: + if col in display_df: + display_df[col] = display_df[col].map(lambda v: f"{v:.4f}") + if "夏普比率" in display_df: + display_df["夏普比率"] = display_df["夏普比率"].map( + lambda v: "N/A" if v is None else f"{v:.4f}" + ) + if "换手率" in display_df: + display_df["换手率"] = display_df["换手率"].map( + lambda v: "N/A" if v is None else f"{v * 100:.1f}%" + ) + st.dataframe( + display_df, + hide_index=True, + width="stretch" + ) + else: + st.info("未产生任何因子评估结果。") # 绘制IC均值分布 - ic_means = [float(r["IC均值"]) for r in results] + ic_means = result_df["IC均值"].astype(float).tolist() if not result_df.empty else [] chart_df = pd.DataFrame({ "因子": [r["因子"] for r in results], "IC均值": ic_means }) st.bar_chart(chart_df.set_index("因子")) - - # 生成股票评分 + + if not ic_means: + st.info("暂无足够的 IC 数据,无法生成股票评分。") + return + with st.spinner("正在生成股票评分..."): - # 使用IC均值作为权重,但如果IC均值全为零,则使用均匀分布 if all(mean == 0 for mean in ic_means): factor_weights = [1.0 / len(ic_means)] * len(ic_means) LOGGER.info("所有因子IC均值均为零,使用均匀权重", extra=LOG_EXTRA) else: - # 将IC均值归一化为权重 - abs_sum = sum(abs(m) for m in ic_means) + abs_sum = sum(abs(m) for m in ic_means) or 1.0 factor_weights = [m / abs_sum for m in ic_means] LOGGER.info("使用IC均值作为权重: %s", factor_weights, extra=LOG_EXTRA) - + scores = _calculate_stock_scores( universe, selected_factors, end_date, factor_weights ) - - if scores: - st.markdown("##### 股票综合评分 (Top 20)") - score_df = pd.DataFrame(scores).sort_values( - "综合评分", - ascending=False - ).head(20) - st.dataframe( - score_df, - hide_index=True, - width="stretch" - ) - - # 添加入池功能 - if st.button("将Top 20股票加入股票池"): - _add_to_stock_pool( - score_df["股票代码"].tolist(), - end_date - ) - st.success("已成功将选中股票加入股票池!") + + if scores: + st.markdown("##### 股票综合评分 (Top 20)") + score_df = pd.DataFrame(scores).sort_values( + "综合评分", + ascending=False + ) + top_df = score_df.head(20).reset_index(drop=True) + display_scores = top_df.copy() + display_scores["综合评分"] = display_scores["综合评分"].map(lambda v: f"{v:.4f}") + st.dataframe( + display_scores, + hide_index=True, + width="stretch" + ) + + if st.button("将Top 20股票加入股票池"): + _add_to_stock_pool(top_df, end_date) + st.success("已成功将选中股票加入股票池!") + else: + st.info("无法根据当前因子权重生成有效的股票评分结果。") def _calculate_stock_scores( universe: Optional[List[str]], factors: List[str], - eval_date: datetime.date, + eval_date: date, factor_weights: List[float] -) -> List[Dict[str, str]]: +) -> List[Dict[str, object]]: """计算股票的综合评分。""" LOGGER = get_logger(__name__) LOG_EXTRA = {"stage": "stock_evaluation"} broker = DataBroker() + trade_date_str = eval_date.strftime("%Y%m%d") # 记录评估开始 LOGGER.info( @@ -297,7 +332,7 @@ def _calculate_stock_scores( ) # 标准化权重 - weights = np.array(factor_weights) + weights = np.array(factor_weights, dtype=float) abs_sum = np.sum(np.abs(weights)) if abs_sum > 0: # 避免除以零 weights = weights / abs_sum @@ -306,7 +341,10 @@ def _calculate_stock_scores( weights = np.ones_like(weights) / len(weights) # 获取所有股票的因子值 - stocks = universe or broker.get_all_stocks(eval_date.strftime("%Y%m%d")) + stocks = universe or broker.get_all_stocks(trade_date_str) + if not stocks: + LOGGER.warning("股票列表为空,无法生成评分", extra=LOG_EXTRA) + return [] # 记录股票列表信息 LOGGER.info( @@ -320,42 +358,54 @@ def _calculate_stock_scores( evaluated_count = 0 skipped_count = 0 + factor_fields = [f"factors.{name}" for name in factors] for ts_code in stocks: - # 检查数据是否充分 - if not check_data_sufficiency(ts_code, eval_date.strftime("%Y%m%d")): + if not check_data_sufficiency(ts_code, trade_date_str): skipped_count += 1 continue - - # 获取股票信息 - info = broker.get_stock_info(ts_code) + + latest_payload = broker.fetch_latest( + ts_code, + trade_date_str, + factor_fields, + auto_refresh=False, + ) + + if not latest_payload: + skipped_count += 1 + continue + + factor_values: List[float] = [] + missing = False + for field in factor_fields: + value = latest_payload.get(field) + if value is None: + missing = True + break + try: + factor_values.append(float(value)) + except (TypeError, ValueError): + missing = True + break + + if missing or len(factor_values) != len(factors): + skipped_count += 1 + continue + + info = broker.get_stock_info(ts_code, trade_date_str) if not info: skipped_count += 1 continue - - # 获取因子值 - factor_values = [] - for factor in factors: - value = broker.fetch_latest_factor(ts_code, factor, eval_date) - if value is None: - skipped_count += 1 - break - factor_values.append(value) - - # 检查是否所有因子值都已获取 - if len(factor_values) != len(factors): - skipped_count += 1 - continue - - # 计算综合评分 - score = np.dot(factor_values, weights) + + score = float(np.dot(factor_values, weights)) evaluated_count += 1 results.append({ "股票代码": ts_code, "股票名称": info.get("name", ""), "行业": info.get("industry", ""), - "综合评分": f"{score:.4f}" + "综合评分": score, }) # 记录评估完成信息 @@ -372,36 +422,51 @@ def _calculate_stock_scores( def _add_to_stock_pool( - ts_codes: List[str], - eval_date: datetime.date + score_df: pd.DataFrame, + eval_date: date ) -> None: - """将股票添加到股票池。""" - with db_session() as session: - # 删除已有记录 - session.execute( - """ - DELETE FROM stock_pool - WHERE entry_date = :entry_date - """, - {"entry_date": eval_date} - ) - - # 插入新记录 - values = [ + """将股票评分结果写入投资池。""" + + trade_date = eval_date.strftime("%Y%m%d") + payload: List[tuple] = [] + ranked_df = score_df.reset_index(drop=True) + + for rank, row in ranked_df.iterrows(): + tags = json.dumps(["stock_evaluation", "top20"], ensure_ascii=False) + metadata = json.dumps( { - "ts_code": code, - "entry_date": eval_date, - "entry_reason": "factor_evaluation" - } - for code in ts_codes - ] - - session.execute( - """ - INSERT INTO stock_pool (ts_code, entry_date, entry_reason) - VALUES (:ts_code, :entry_date, :entry_reason) - """, - values + "source": "stock_evaluation", + "rank": rank + 1, + "score": float(row["综合评分"]), + }, + ensure_ascii=False, ) - - session.commit() \ No newline at end of file + payload.append( + ( + trade_date, + row["股票代码"], + float(row["综合评分"]), + "candidate", + "factor_evaluation_top20", + tags, + metadata, + ) + ) + + with db_session() as conn: + conn.execute("DELETE FROM investment_pool WHERE trade_date = ?", (trade_date,)) + if payload: + conn.executemany( + """ + INSERT INTO investment_pool ( + trade_date, + ts_code, + score, + status, + rationale, + tags, + metadata + ) VALUES (?, ?, ?, ?, ?, ?, ?) + """, + payload, + )