refactor factor calculation and improve progress tracking

This commit is contained in:
Your Name 2025-10-10 21:10:51 +08:00
parent 44adc836fa
commit 43c70f3f7f
7 changed files with 383 additions and 251 deletions

View File

@ -151,5 +151,5 @@ Streamlit `自检测试` 页签提供:
TODO TODO
1. 在选股时,因子都已经提前算好,不需要再计算了,直接用就行。 1. 在选股时,因子都已经提前算好,不需要再计算了,直接用就行。
2. 因子计算的公式再确认下 2. 因子计算的公式再确认下
3. 审查整个项目的代码逻辑，从main.py开始逐字逐句检查。如一些重复的检查可以去掉未实现的功能请标记TODO并给出实现思路错误的、低效率的调用请修正代码结构性的问题请指出 3. 审查整个项目的代码逻辑，从app/ui/streamlit_app.py开始逐字逐句检查。如一些重复的安全检查可以去掉；明显过时的临时性代码请删除掉；未实现的功能请标记TODO并给出实现思路；错误的、低效率的调用请修正；代码结构性的问题请指出并尝试修正；复杂不清晰的代码结构请尝试重构；
4. 梳理整个项目的所有业务逻辑。针对每个业务,从业务实现角度评估代码功能是否存在问题,是否需要优化,是否需要重构。 4. 梳理整个项目的所有业务逻辑。针对每个业务,从业务实现角度评估代码功能是否存在问题,是否需要优化,是否需要重构。

View File

@ -9,7 +9,8 @@ from app.features.factors import (
DEFAULT_FACTORS, DEFAULT_FACTORS,
FactorResult, FactorResult,
FactorSpec, FactorSpec,
compute_factor_range compute_factor_range,
lookup_factor_spec,
) )
from app.utils.data_access import DataBroker from app.utils.data_access import DataBroker
from app.utils.logging import get_logger from app.utils.logging import get_logger
@ -90,14 +91,14 @@ def evaluate_factor(
# ) # )
try: try:
# 计算因子值 spec = lookup_factor_spec(factor_name) or FactorSpec(factor_name, 0)
# 设置 skip_existing=False确保即使因子已存在也会重新计算
factor_results = compute_factor_range( factor_results = compute_factor_range(
start_date, start_date,
end_date, end_date,
factors=[FactorSpec(factor_name, 0)], factors=[spec],
ts_codes=universe, ts_codes=universe,
skip_existing=False skip_existing=True,
) )
# 因子计算完成在异步线程中不直接访问factor_progress # 因子计算完成在异步线程中不直接访问factor_progress

View File

@ -90,6 +90,17 @@ DEFAULT_FACTORS: List[FactorSpec] = [
FactorSpec("risk_penalty", 0), # 风险惩罚因子 FactorSpec("risk_penalty", 0), # 风险惩罚因子
] ]
_FACTOR_SPEC_MAP: Dict[str, FactorSpec] = {spec.name: spec for spec in DEFAULT_FACTORS}
def lookup_factor_spec(name: str) -> Optional[FactorSpec]:
    """Look up the registered spec for ``name`` and hand back a fresh copy.

    Args:
        name: Factor name used as the key into ``_FACTOR_SPEC_MAP``.

    Returns:
        A new ``FactorSpec`` built from the registered entry's ``name`` and
        ``window``, or ``None`` when nothing is registered under ``name``.
    """
    registered = _FACTOR_SPEC_MAP.get(name)
    # Build a new instance so callers never share the registry's object.
    return (
        None
        if registered is None
        else FactorSpec(name=registered.name, window=registered.window)
    )
def compute_factors( def compute_factors(
trade_date: date, trade_date: date,
@ -304,30 +315,33 @@ def _existing_factor_codes_with_factors(trade_date: str, factor_names: List[str]
if not factor_names: if not factor_names:
return {} return {}
# 构建检查条件 valid_names = [
conditions = [] name
for name in factor_names: for name in factor_names
conditions.append(f"json_extract(factors, '$.{name}') IS NOT NULL") if isinstance(name, str) and _IDENTIFIER_RE.match(name)
condition_str = " AND ".join(conditions) ]
if not valid_names:
# 构建SQL查询 return {}
query = """
SELECT ts_code
FROM factors
WHERE trade_date = ?
AND """ + condition_str + """
GROUP BY ts_code
"""
with db_session(read_only=True) as conn: with db_session(read_only=True) as conn:
columns = {
row["name"]
for row in conn.execute("PRAGMA table_info(factors)").fetchall()
}
selected = [name for name in valid_names if name in columns]
if not selected:
return {}
predicates = " AND ".join(f"{col} IS NOT NULL" for col in selected)
query = (
"SELECT ts_code FROM factors "
"WHERE trade_date = ? AND "
f"{predicates} "
"GROUP BY ts_code"
)
rows = conn.execute(query, (trade_date,)).fetchall() rows = conn.execute(query, (trade_date,)).fetchall()
# 返回结果 return {row["ts_code"]: True for row in rows if row and row["ts_code"]}
result = {}
for row in rows:
result[row["ts_code"]] = True
return result
def _list_trade_dates( def _list_trade_dates(

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Optional, Dict, Any from typing import Optional, Dict, Any
import time
import streamlit as st import streamlit as st
@ -25,7 +26,7 @@ class FactorProgressState:
'status': 'idle', # idle, running, completed, error 'status': 'idle', # idle, running, completed, error
'message': '', 'message': '',
'start_time': None, 'start_time': None,
'elapsed_time': 0 'elapsed_time': 0.0,
} }
def start_calculation(self, total_securities: int, total_batches: int) -> None: def start_calculation(self, total_securities: int, total_batches: int) -> None:
@ -35,16 +36,17 @@ class FactorProgressState:
total_securities: 总证券数量 total_securities: 总证券数量
total_batches: 总批次数 total_batches: 总批次数
""" """
now = time.time()
st.session_state.factor_progress.update({ st.session_state.factor_progress.update({
'current': 0, 'current': 0,
'total': total_securities, 'total': max(total_securities, 0),
'percentage': 0.0, 'percentage': 0.0,
'current_batch': 0, 'current_batch': 0,
'total_batches': total_batches, 'total_batches': max(total_batches, 0),
'status': 'running', 'status': 'running',
'message': '开始因子计算...', 'message': '开始因子计算...',
'start_time': st.session_state.get('factor_progress', {}).get('start_time'), 'start_time': now,
'elapsed_time': 0 'elapsed_time': 0.0,
}) })
def update_progress(self, current_securities: int, current_batch: int, def update_progress(self, current_securities: int, current_batch: int,
@ -59,18 +61,26 @@ class FactorProgressState:
progress = st.session_state.factor_progress progress = st.session_state.factor_progress
# 计算百分比 # 计算百分比
if progress['total'] > 0: total = progress.get('total', 0) or 0
percentage = (current_securities / progress['total']) * 100 if total > 0:
percentage = (current_securities / total) * 100
else: else:
percentage = 0.0 percentage = 0.0
start_time = progress.get('start_time')
if isinstance(start_time, (int, float)):
elapsed = max(0.0, time.time() - start_time)
else:
elapsed = 0.0
# 更新状态 # 更新状态
progress.update({ progress.update({
'current': current_securities, 'current': current_securities,
'current_batch': current_batch, 'current_batch': current_batch,
'percentage': percentage, 'percentage': percentage,
'message': message or f'处理批次 {current_batch}/{progress["total_batches"]}', 'message': message or f'处理批次 {current_batch}/{progress["total_batches"] or 1}',
'status': 'running' 'status': 'running',
'elapsed_time': elapsed,
}) })
def complete_calculation(self, message: str = '因子计算完成') -> None: def complete_calculation(self, message: str = '因子计算完成') -> None:
@ -80,11 +90,17 @@ class FactorProgressState:
message: 完成消息 message: 完成消息
""" """
progress = st.session_state.factor_progress progress = st.session_state.factor_progress
start_time = progress.get('start_time')
if isinstance(start_time, (int, float)):
elapsed = max(0.0, time.time() - start_time)
else:
elapsed = progress.get('elapsed_time', 0.0) or 0.0
progress.update({ progress.update({
'current': progress['total'], 'current': progress.get('total', 0),
'percentage': 100.0, 'percentage': 100.0 if progress.get('total', 0) else progress.get('percentage', 0.0),
'status': 'completed', 'status': 'completed',
'message': message 'message': message,
'elapsed_time': elapsed,
}) })
def error_occurred(self, error_message: str) -> None: def error_occurred(self, error_message: str) -> None:
@ -93,9 +109,16 @@ class FactorProgressState:
Args: Args:
error_message: 错误消息 error_message: 错误消息
""" """
st.session_state.factor_progress.update({ progress = st.session_state.factor_progress
start_time = progress.get('start_time')
if isinstance(start_time, (int, float)):
elapsed = max(0.0, time.time() - start_time)
else:
elapsed = progress.get('elapsed_time', 0.0) or 0.0
progress.update({
'status': 'error', 'status': 'error',
'message': f'错误: {error_message}' 'message': f'错误: {error_message}',
'elapsed_time': elapsed,
}) })
def get_progress_info(self) -> Dict[str, Any]: def get_progress_info(self) -> Dict[str, Any]:
@ -118,7 +141,7 @@ class FactorProgressState:
'status': 'idle', 'status': 'idle',
'message': '', 'message': '',
'start_time': None, 'start_time': None,
'elapsed_time': 0 'elapsed_time': 0.0,
} }

View File

@ -2,16 +2,11 @@
from __future__ import annotations from __future__ import annotations
import json import json
import uuid
from dataclasses import asdict
from datetime import date, datetime from datetime import date, datetime
from typing import Dict, List, Optional from typing import Dict, List, Optional
import numpy as np
import pandas as pd import pandas as pd
import plotly.express as px import plotly.express as px
import requests
from requests.exceptions import RequestException
import streamlit as st import streamlit as st
from app.agents.base import AgentContext from app.agents.base import AgentContext

View File

@ -1,19 +1,20 @@
"""因子计算页面。""" """因子计算页面。"""
from datetime import datetime, timedelta from datetime import date, datetime, timedelta
from typing import List, Optional from typing import List, Optional, Sequence
import streamlit as st import streamlit as st
from app.features.factors import compute_factors, DEFAULT_FACTORS, FactorSpec from app.features.factors import DEFAULT_FACTORS, FactorSpec, compute_factor_range
from app.ui.progress_state import factor_progress from app.ui.progress_state import factor_progress
from app.ui.shared import LOGGER, LOG_EXTRA
from app.utils.data_access import DataBroker from app.utils.data_access import DataBroker
from app.utils.db import db_session from app.utils.db import db_session
def _get_latest_trading_date() -> datetime.date: def _get_latest_trading_date() -> datetime.date:
"""获取数据库中的最新交易日期""" """获取数据库中的最新交易日期"""
with db_session() as session: with db_session(read_only=True) as conn:
result = session.execute( result = conn.execute(
""" """
SELECT trade_date SELECT trade_date
FROM daily_basic FROM daily_basic
@ -34,9 +35,9 @@ def _get_all_stocks() -> List[str]:
"""获取所有股票代码""" """获取所有股票代码"""
try: try:
# 从daily表获取所有股票代码 # 从daily表获取所有股票代码
with db_session() as session: with db_session(read_only=True) as conn:
latest_date = _get_latest_trading_date() latest_date = _get_latest_trading_date()
result = session.execute( result = conn.execute(
""" """
SELECT DISTINCT ts_code SELECT DISTINCT ts_code
FROM daily FROM daily
@ -45,12 +46,88 @@ def _get_all_stocks() -> List[str]:
{"trade_date": latest_date.strftime("%Y%m%d")} {"trade_date": latest_date.strftime("%Y%m%d")}
).fetchall() ).fetchall()
return [row[0] for row in result] if result else [] return [row["ts_code"] for row in result if row and row["ts_code"]] if result else []
except Exception as e: except Exception as exc:
st.error(f"获取股票列表失败: {str(e)}") LOGGER.exception("获取股票列表失败", extra={**LOG_EXTRA, "error": str(exc)})
st.error(f"获取股票列表失败: {exc}")
return [] return []
def _normalize_universe(universe: Optional[Sequence[str]]) -> List[str]:
"""去重并规范股票代码格式。"""
if not universe:
return []
seen: dict[str, None] = {}
for code in universe:
normalized = code.strip().upper()
if normalized and normalized not in seen:
seen[normalized] = None
return list(seen.keys())
def _get_trade_dates_between(
start: date,
end: date,
universe: Optional[Sequence[str]] = None,
) -> List[date]:
"""获取区间内存在行情数据的交易日期列表。"""
if end < start:
return []
start_str = start.strftime("%Y%m%d")
end_str = end.strftime("%Y%m%d")
params: List[str] = [start_str, end_str]
query = (
"SELECT DISTINCT trade_date FROM daily "
"WHERE trade_date BETWEEN ? AND ?"
)
scoped_universe = _normalize_universe(universe)
if scoped_universe:
placeholders = ", ".join("?" for _ in scoped_universe)
query += f" AND ts_code IN ({placeholders})"
params.extend(scoped_universe)
query += " ORDER BY trade_date"
with db_session(read_only=True) as conn:
rows = conn.execute(query, params).fetchall()
return [
datetime.strptime(str(row["trade_date"]), "%Y%m%d").date()
for row in rows
if row and row["trade_date"]
]
def _estimate_total_workload(
trade_dates: Sequence[date],
universe: Optional[Sequence[str]],
) -> int:
"""估算本次计算需要处理的证券数量,用于驱动进度条。"""
trade_days = list(trade_dates)
if not trade_days:
return 0
scoped_universe = _normalize_universe(universe)
if scoped_universe:
return len(scoped_universe) * len(trade_days)
start_str = min(trade_days).strftime("%Y%m%d")
end_str = max(trade_days).strftime("%Y%m%d")
with db_session(read_only=True) as conn:
row = conn.execute(
"""
SELECT COUNT(DISTINCT ts_code) AS cnt
FROM daily
WHERE trade_date BETWEEN ? AND ?
""",
(start_str, end_str),
).fetchone()
universe_size = int(row["cnt"]) if row and row["cnt"] is not None else 0
return universe_size * len(trade_days)
def render_factor_calculation() -> None: def render_factor_calculation() -> None:
"""渲染因子计算页面。""" """渲染因子计算页面。"""
st.subheader("📊 因子计算") st.subheader("📊 因子计算")
@ -153,97 +230,55 @@ def render_factor_calculation() -> None:
help="如果勾选,将跳过数据库中已存在的因子计算结果" help="如果勾选,将跳过数据库中已存在的因子计算结果"
) )
# 5. 同步计算函数 # 5. 开始计算按钮
def run_factor_calculation_sync():
"""同步执行因子计算"""
# 计算参数
total_stocks = len(universe) if universe else len(_get_all_stocks())
total_batches = len(selected_factors)
try:
# 执行因子计算
results = []
for i, factor in enumerate(selected_factors):
# 更新批次进度
factor_progress.update_progress(
current_securities=0,
current_batch=i+1,
message=f"正在计算因子: {factor.name}"
)
# 计算单个交易日的因子
current_date = start_date
while current_date <= end_date:
try:
# 计算指定日期的因子
daily_results = compute_factors(
current_date,
[factor],
ts_codes=universe,
skip_existing=skip_existing
)
results.extend(daily_results)
except Exception as e:
# 记录错误但不中断计算
error_msg = f"计算因子 {factor.name} 在日期 {current_date} 时出错: {str(e)}"
print(f"ERROR: {error_msg}")
current_date += timedelta(days=1)
# 计算完成
factor_progress.complete_calculation(f"因子计算完成!共计算 {len(results)} 条因子记录")
return {
'success': True,
'results': results,
'factors': [f.name for f in selected_factors],
'date_range': f"{start_date}{end_date}",
'stock_count': len(set(r.ts_code for r in results)) if results else 0,
'message': f"因子计算完成!共计算 {len(results)} 条因子记录"
}
except Exception as e:
# 计算失败
factor_progress.error_occurred(f"因子计算失败: {str(e)}")
return {
'success': False,
'error': str(e),
'message': f"因子计算失败: {str(e)}"
}
# 6. 开始计算按钮
if st.button("开始计算因子", disabled=not selected_factors): if st.button("开始计算因子", disabled=not selected_factors):
# 重置状态 # 重置状态
if 'factor_calculation_results' in st.session_state: st.session_state.pop('factor_calculation_results', None)
st.session_state.factor_calculation_results = None st.session_state.pop('factor_calculation_error', None)
if 'factor_calculation_error' in st.session_state: factor_progress.reset()
st.session_state.factor_calculation_error = None
# 初始化进度状态 scoped_universe = _normalize_universe(universe) or None
total_stocks = len(universe) if universe else len(_get_all_stocks()) trade_dates = _get_trade_dates_between(start_date, end_date, scoped_universe)
if not trade_dates:
st.warning("所选时间窗口内无可用交易日数据,请先执行数据同步。")
return
total_workload = _estimate_total_workload(trade_dates, scoped_universe)
factor_progress.start_calculation( factor_progress.start_calculation(
total_securities=total_stocks, total_securities=max(total_workload, 1),
total_batches=len(selected_factors) total_batches=len(trade_dates),
) )
# 直接调用同步计算函数 with st.spinner("正在计算因子..."):
result = run_factor_calculation_sync() try:
results = compute_factor_range(
# 处理计算结果 start=min(trade_dates),
if result['success']: end=max(trade_dates),
factors=selected_factors,
ts_codes=scoped_universe,
skip_existing=skip_existing,
)
except Exception as exc:
LOGGER.exception("因子计算失败", extra={**LOG_EXTRA, "error": str(exc)})
factor_progress.error_occurred(str(exc))
st.session_state.factor_calculation_error = str(exc)
st.error(f"❌ 因子计算失败: {exc}")
else:
factor_progress.complete_calculation(
f"因子计算完成,共生成 {len(results)} 条因子记录"
)
factor_names = [spec.name for spec in selected_factors]
stock_count = len({item.ts_code for item in results}) if results else 0
st.session_state.factor_calculation_results = { st.session_state.factor_calculation_results = {
'results': result['results'], 'results': results,
'factors': result['factors'], 'factors': factor_names,
'date_range': result['date_range'], 'date_range': f"{trade_dates[0]}{trade_dates[-1]}",
'stock_count': result['stock_count'] 'stock_count': stock_count,
'trade_days': len(trade_dates),
} }
st.success("✅ 因子计算完成!") st.success("✅ 因子计算完成!")
else:
st.session_state.factor_calculation_error = result['error']
st.error(f"❌ 因子计算失败: {result['error']}")
# 7. 显示计算结果 # 6. 显示计算结果
if 'factor_calculation_results' in st.session_state and st.session_state.factor_calculation_results: if 'factor_calculation_results' in st.session_state and st.session_state.factor_calculation_results:
results = st.session_state.factor_calculation_results results = st.session_state.factor_calculation_results
@ -255,7 +290,8 @@ def render_factor_calculation() -> None:
with col2: with col2:
st.metric("涉及股票数量", results['stock_count']) st.metric("涉及股票数量", results['stock_count'])
with col3: with col3:
st.metric("计算时间范围", results['date_range']) st.metric("交易日数量", results.get('trade_days', 0))
st.caption(f"时间范围:{results['date_range']}")
# 显示计算详情 # 显示计算详情
with st.expander("查看计算详情"): with st.expander("查看计算详情"):
@ -279,8 +315,6 @@ def render_factor_calculation() -> None:
else: else:
st.info("没有找到因子计算结果") st.info("没有找到因子计算结果")
# 8. 移除异步线程检查逻辑(已改为同步模式) # 7. 显示错误信息
# 9. 显示错误信息
if 'factor_calculation_error' in st.session_state and st.session_state.factor_calculation_error: if 'factor_calculation_error' in st.session_state and st.session_state.factor_calculation_error:
st.error(f"❌ 因子计算失败: {st.session_state.factor_calculation_error}") st.error(f"❌ 因子计算失败: {st.session_state.factor_calculation_error}")

View File

@ -1,6 +1,7 @@
"""股票筛选与评估视图。""" """股票筛选与评估视图。"""
from datetime import datetime, timedelta from datetime import date, datetime, timedelta
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional
import json
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@ -15,11 +16,10 @@ from app.utils.db import db_session
from app.utils.logging import get_logger from app.utils.logging import get_logger
def _get_latest_trading_date() -> datetime.date: def _get_latest_trading_date() -> date:
"""获取数据库中的最新交易日期""" """获取数据库中的最新交易日期"""
with db_session() as session: with db_session(read_only=True) as conn:
# 获取当前日期的上一个有效交易日 result = conn.execute(
result = session.execute(
""" """
SELECT trade_date SELECT trade_date
FROM daily_basic FROM daily_basic
@ -35,6 +35,19 @@ def _get_latest_trading_date() -> datetime.date:
return datetime.strptime(str(result[0]), "%Y%m%d").date() return datetime.strptime(str(result[0]), "%Y%m%d").date()
return datetime.now().date() - timedelta(days=1) # 如果查询失败才返回昨天 return datetime.now().date() - timedelta(days=1) # 如果查询失败才返回昨天
def _normalize_universe(universe: Optional[List[str]]) -> List[str]:
"""标准化股票代码列表,去重并转为大写。"""
if not universe:
return []
normalized: Dict[str, None] = {}
for code in universe:
candidate = (code or "").strip().upper()
if candidate and candidate not in normalized:
normalized[candidate] = None
return list(normalized.keys())
def render_stock_evaluation() -> None: def render_stock_evaluation() -> None:
"""渲染股票筛选与评估页面。""" """渲染股票筛选与评估页面。"""
LOGGER = get_logger(__name__) LOGGER = get_logger(__name__)
@ -141,6 +154,9 @@ def render_stock_evaluation() -> None:
index_code, index_code,
end_date.strftime("%Y%m%d") end_date.strftime("%Y%m%d")
) )
universe = _normalize_universe(universe)
if universe == []:
universe = None
# 4. 评估结果 # 4. 评估结果
@ -167,11 +183,12 @@ def render_stock_evaluation() -> None:
) )
st.session_state.evaluation_status = 'running' st.session_state.evaluation_status = 'running'
st.session_state.pop('evaluation_error', None)
results = [] results = []
for i, factor_name in enumerate(selected_factors): for i, factor_name in enumerate(selected_factors):
st.session_state.current_factor = factor_name st.session_state.current_factor = factor_name
st.session_state.progress = (i / len(selected_factors)) * 100 st.session_state.progress = ((i + 1) / len(selected_factors)) * 100
performance = evaluate_factor( performance = evaluate_factor(
factor_name, factor_name,
@ -181,11 +198,11 @@ def render_stock_evaluation() -> None:
) )
results.append({ results.append({
"因子": factor_name, "因子": factor_name,
"IC均值": f"{performance.ic_mean:.4f}", "IC均值": performance.ic_mean,
"RankIC均值": f"{performance.rank_ic_mean:.4f}", "RankIC均值": performance.rank_ic_mean,
"IC信息比率": f"{performance.ic_ir:.4f}", "IC信息比率": performance.ic_ir,
"夏普比率": f"{performance.sharpe_ratio:.4f}" if performance.sharpe_ratio else "N/A", "夏普比率": performance.sharpe_ratio,
"换手率": f"{performance.turnover_rate*100:.1f}%" if performance.turnover_rate else "N/A" "换手率": performance.turnover_rate,
}) })
st.session_state.evaluation_results = results st.session_state.evaluation_results = results
@ -221,29 +238,45 @@ def render_stock_evaluation() -> None:
st.markdown("##### 因子评估结果") st.markdown("##### 因子评估结果")
result_df = pd.DataFrame(results) result_df = pd.DataFrame(results)
if not result_df.empty:
display_df = result_df.copy()
for col in ["IC均值", "RankIC均值", "IC信息比率"]:
if col in display_df:
display_df[col] = display_df[col].map(lambda v: f"{v:.4f}")
if "夏普比率" in display_df:
display_df["夏普比率"] = display_df["夏普比率"].map(
lambda v: "N/A" if v is None else f"{v:.4f}"
)
if "换手率" in display_df:
display_df["换手率"] = display_df["换手率"].map(
lambda v: "N/A" if v is None else f"{v * 100:.1f}%"
)
st.dataframe( st.dataframe(
result_df, display_df,
hide_index=True, hide_index=True,
width="stretch" width="stretch"
) )
else:
st.info("未产生任何因子评估结果。")
# 绘制IC均值分布 # 绘制IC均值分布
ic_means = [float(r["IC均值"]) for r in results] ic_means = result_df["IC均值"].astype(float).tolist() if not result_df.empty else []
chart_df = pd.DataFrame({ chart_df = pd.DataFrame({
"因子": [r["因子"] for r in results], "因子": [r["因子"] for r in results],
"IC均值": ic_means "IC均值": ic_means
}) })
st.bar_chart(chart_df.set_index("因子")) st.bar_chart(chart_df.set_index("因子"))
# 生成股票评分 if not ic_means:
st.info("暂无足够的 IC 数据,无法生成股票评分。")
return
with st.spinner("正在生成股票评分..."): with st.spinner("正在生成股票评分..."):
# 使用IC均值作为权重但如果IC均值全为零则使用均匀分布
if all(mean == 0 for mean in ic_means): if all(mean == 0 for mean in ic_means):
factor_weights = [1.0 / len(ic_means)] * len(ic_means) factor_weights = [1.0 / len(ic_means)] * len(ic_means)
LOGGER.info("所有因子IC均值均为零使用均匀权重", extra=LOG_EXTRA) LOGGER.info("所有因子IC均值均为零使用均匀权重", extra=LOG_EXTRA)
else: else:
# 将IC均值归一化为权重 abs_sum = sum(abs(m) for m in ic_means) or 1.0
abs_sum = sum(abs(m) for m in ic_means)
factor_weights = [m / abs_sum for m in ic_means] factor_weights = [m / abs_sum for m in ic_means]
LOGGER.info("使用IC均值作为权重: %s", factor_weights, extra=LOG_EXTRA) LOGGER.info("使用IC均值作为权重: %s", factor_weights, extra=LOG_EXTRA)
@ -259,33 +292,35 @@ def render_stock_evaluation() -> None:
score_df = pd.DataFrame(scores).sort_values( score_df = pd.DataFrame(scores).sort_values(
"综合评分", "综合评分",
ascending=False ascending=False
).head(20) )
top_df = score_df.head(20).reset_index(drop=True)
display_scores = top_df.copy()
display_scores["综合评分"] = display_scores["综合评分"].map(lambda v: f"{v:.4f}")
st.dataframe( st.dataframe(
score_df, display_scores,
hide_index=True, hide_index=True,
width="stretch" width="stretch"
) )
# 添加入池功能
if st.button("将Top 20股票加入股票池"): if st.button("将Top 20股票加入股票池"):
_add_to_stock_pool( _add_to_stock_pool(top_df, end_date)
score_df["股票代码"].tolist(),
end_date
)
st.success("已成功将选中股票加入股票池!") st.success("已成功将选中股票加入股票池!")
else:
st.info("无法根据当前因子权重生成有效的股票评分结果。")
def _calculate_stock_scores( def _calculate_stock_scores(
universe: Optional[List[str]], universe: Optional[List[str]],
factors: List[str], factors: List[str],
eval_date: datetime.date, eval_date: date,
factor_weights: List[float] factor_weights: List[float]
) -> List[Dict[str, str]]: ) -> List[Dict[str, object]]:
"""计算股票的综合评分。""" """计算股票的综合评分。"""
LOGGER = get_logger(__name__) LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "stock_evaluation"} LOG_EXTRA = {"stage": "stock_evaluation"}
broker = DataBroker() broker = DataBroker()
trade_date_str = eval_date.strftime("%Y%m%d")
# 记录评估开始 # 记录评估开始
LOGGER.info( LOGGER.info(
@ -297,7 +332,7 @@ def _calculate_stock_scores(
) )
# 标准化权重 # 标准化权重
weights = np.array(factor_weights) weights = np.array(factor_weights, dtype=float)
abs_sum = np.sum(np.abs(weights)) abs_sum = np.sum(np.abs(weights))
if abs_sum > 0: # 避免除以零 if abs_sum > 0: # 避免除以零
weights = weights / abs_sum weights = weights / abs_sum
@ -306,7 +341,10 @@ def _calculate_stock_scores(
weights = np.ones_like(weights) / len(weights) weights = np.ones_like(weights) / len(weights)
# 获取所有股票的因子值 # 获取所有股票的因子值
stocks = universe or broker.get_all_stocks(eval_date.strftime("%Y%m%d")) stocks = universe or broker.get_all_stocks(trade_date_str)
if not stocks:
LOGGER.warning("股票列表为空,无法生成评分", extra=LOG_EXTRA)
return []
# 记录股票列表信息 # 记录股票列表信息
LOGGER.info( LOGGER.info(
@ -320,42 +358,54 @@ def _calculate_stock_scores(
evaluated_count = 0 evaluated_count = 0
skipped_count = 0 skipped_count = 0
factor_fields = [f"factors.{name}" for name in factors]
for ts_code in stocks: for ts_code in stocks:
# 检查数据是否充分 if not check_data_sufficiency(ts_code, trade_date_str):
if not check_data_sufficiency(ts_code, eval_date.strftime("%Y%m%d")):
skipped_count += 1 skipped_count += 1
continue continue
# 获取股票信息 latest_payload = broker.fetch_latest(
info = broker.get_stock_info(ts_code) ts_code,
trade_date_str,
factor_fields,
auto_refresh=False,
)
if not latest_payload:
skipped_count += 1
continue
factor_values: List[float] = []
missing = False
for field in factor_fields:
value = latest_payload.get(field)
if value is None:
missing = True
break
try:
factor_values.append(float(value))
except (TypeError, ValueError):
missing = True
break
if missing or len(factor_values) != len(factors):
skipped_count += 1
continue
info = broker.get_stock_info(ts_code, trade_date_str)
if not info: if not info:
skipped_count += 1 skipped_count += 1
continue continue
# 获取因子值 score = float(np.dot(factor_values, weights))
factor_values = []
for factor in factors:
value = broker.fetch_latest_factor(ts_code, factor, eval_date)
if value is None:
skipped_count += 1
break
factor_values.append(value)
# 检查是否所有因子值都已获取
if len(factor_values) != len(factors):
skipped_count += 1
continue
# 计算综合评分
score = np.dot(factor_values, weights)
evaluated_count += 1 evaluated_count += 1
results.append({ results.append({
"股票代码": ts_code, "股票代码": ts_code,
"股票名称": info.get("name", ""), "股票名称": info.get("name", ""),
"行业": info.get("industry", ""), "行业": info.get("industry", ""),
"综合评分": f"{score:.4f}" "综合评分": score,
}) })
# 记录评估完成信息 # 记录评估完成信息
@ -372,36 +422,51 @@ def _calculate_stock_scores(
def _add_to_stock_pool( def _add_to_stock_pool(
ts_codes: List[str], score_df: pd.DataFrame,
eval_date: datetime.date eval_date: date
) -> None: ) -> None:
"""将股票添加到股票池。""" """将股票评分结果写入投资池。"""
with db_session() as session:
# 删除已有记录
session.execute(
"""
DELETE FROM stock_pool
WHERE entry_date = :entry_date
""",
{"entry_date": eval_date}
)
# 插入新记录 trade_date = eval_date.strftime("%Y%m%d")
values = [ payload: List[tuple] = []
ranked_df = score_df.reset_index(drop=True)
for rank, row in ranked_df.iterrows():
tags = json.dumps(["stock_evaluation", "top20"], ensure_ascii=False)
metadata = json.dumps(
{ {
"ts_code": code, "source": "stock_evaluation",
"entry_date": eval_date, "rank": rank + 1,
"entry_reason": "factor_evaluation" "score": float(row["综合评分"]),
} },
for code in ts_codes ensure_ascii=False,
] )
payload.append(
session.execute( (
""" trade_date,
INSERT INTO stock_pool (ts_code, entry_date, entry_reason) row["股票代码"],
VALUES (:ts_code, :entry_date, :entry_reason) float(row["综合评分"]),
""", "candidate",
values "factor_evaluation_top20",
tags,
metadata,
)
) )
session.commit() with db_session() as conn:
conn.execute("DELETE FROM investment_pool WHERE trade_date = ?", (trade_date,))
if payload:
conn.executemany(
"""
INSERT INTO investment_pool (
trade_date,
ts_code,
score,
status,
rationale,
tags,
metadata
) VALUES (?, ?, ?, ?, ?, ?, ?)
""",
payload,
)