enhance factor evaluation with async progress tracking and UI updates
This commit is contained in:
parent
4e7a56567b
commit
07c76d7674
@ -80,6 +80,16 @@ def evaluate_factor(
|
||||
"""
|
||||
performance = FactorPerformance(factor_name)
|
||||
|
||||
# 导入进度状态模块
|
||||
from app.ui.progress_state import factor_progress
|
||||
|
||||
# 开始因子计算进度
|
||||
factor_progress.start_calculation(
|
||||
total_securities=len(universe) if universe else 0,
|
||||
message=f"开始评估因子 {factor_name}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 计算因子值
|
||||
factor_results = compute_factor_range(
|
||||
start_date,
|
||||
@ -88,6 +98,19 @@ def evaluate_factor(
|
||||
ts_codes=universe
|
||||
)
|
||||
|
||||
# 因子计算完成
|
||||
factor_progress.complete_calculation(
|
||||
message=f"因子 {factor_name} 评估完成"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# 因子计算失败
|
||||
factor_progress.complete_calculation(
|
||||
message=f"因子 {factor_name} 评估失败: {str(e)}",
|
||||
success=False
|
||||
)
|
||||
raise
|
||||
|
||||
# 按日期分组
|
||||
date_groups: Dict[date, List[FactorResult]] = {}
|
||||
for result in factor_results:
|
||||
|
||||
@ -469,51 +469,45 @@ class ExtendedFactors:
|
||||
raise ValueError(f"因子 {factor_name} 没有对应的计算实现")
|
||||
|
||||
def compute_all_factors(self,
|
||||
close_series: Sequence[float],
|
||||
volume_series: Sequence[float]) -> Dict[str, float]:
|
||||
"""计算所有已注册的扩展因子值
|
||||
close_series: List[float],
|
||||
volume_series: List[float],
|
||||
ts_code: str,
|
||||
trade_date: str) -> Dict[str, float | None]:
|
||||
"""计算所有扩展因子
|
||||
|
||||
Args:
|
||||
close_series: 收盘价序列,从新到旧排序
|
||||
volume_series: 成交量序列,从新到旧排序
|
||||
close_series: 收盘价序列
|
||||
volume_series: 成交量序列
|
||||
ts_code: 股票代码
|
||||
trade_date: 交易日期
|
||||
|
||||
Returns:
|
||||
Dict[str, float]: 因子名称到因子值的映射字典,
|
||||
只包含成功计算的因子值
|
||||
|
||||
Note:
|
||||
该方法会尝试计算所有已注册的因子,失败的因子将被忽略。
|
||||
如果所有因子计算都失败,将返回空字典。
|
||||
因子名称到因子值的映射
|
||||
"""
|
||||
results = {}
|
||||
success_count = 0
|
||||
total_count = len(self.factor_specs)
|
||||
|
||||
for factor_name in self.factor_specs:
|
||||
value = self.compute_factor(factor_name, close_series, volume_series)
|
||||
if value is not None:
|
||||
# 验证因子值是否在合理范围内
|
||||
validated_value = validate_factor_value(
|
||||
factor_name, value, "unknown", "unknown"
|
||||
)
|
||||
if validated_value is not None:
|
||||
results[factor_name] = validated_value
|
||||
success_count += 1
|
||||
else:
|
||||
for factor_spec in EXTENDED_FACTORS:
|
||||
try:
|
||||
factor_name = factor_spec.name
|
||||
factor_value = self.compute_factor(factor_name, close_series, volume_series)
|
||||
|
||||
# 验证因子值
|
||||
if factor_value is not None:
|
||||
# 使用真实的 ts_code 和 trade_date 进行验证
|
||||
validate_factor_value(factor_name, factor_value, ts_code, trade_date)
|
||||
|
||||
results[factor_name] = factor_value
|
||||
|
||||
except Exception as e:
|
||||
LOGGER.debug(
|
||||
"因子值验证失败 factor=%s value=%f",
|
||||
factor_name, value,
|
||||
extra=LOG_EXTRA
|
||||
"因子计算失败 factor=%s ts_code=%s date=%s err=%s",
|
||||
factor_spec.name,
|
||||
ts_code,
|
||||
trade_date,
|
||||
str(e),
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
|
||||
# 关闭因子计算完成日志打印
|
||||
# LOGGER.info(
|
||||
# "因子计算完成 total=%d success=%d failed=%d",
|
||||
# total_count,
|
||||
# success_count,
|
||||
# total_count - success_count,
|
||||
# extra=LOG_EXTRA
|
||||
# )
|
||||
results[factor_spec.name] = None
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@ -332,14 +332,12 @@ def _compute_batch_factors(
|
||||
# 批次化数据可用性检查
|
||||
available_codes = _check_batch_data_availability(broker, ts_codes, trade_date, specs)
|
||||
|
||||
# 更新UI进度状态
|
||||
# 更新UI进度状态 - 开始处理批次
|
||||
if total_securities > 0:
|
||||
current_progress = processed_securities + len(available_codes)
|
||||
progress_percentage = (current_progress / total_securities) * 100
|
||||
factor_progress.update_progress(
|
||||
current_securities=current_progress,
|
||||
current_securities=processed_securities,
|
||||
current_batch=batch_index + 1,
|
||||
message=f"处理批次 {batch_index + 1}/{total_batches} - 证券 {current_progress}/{total_securities} ({progress_percentage:.1f}%)"
|
||||
message=f"开始处理批次 {batch_index + 1}/{total_batches}"
|
||||
)
|
||||
|
||||
for i, ts_code in enumerate(ts_codes):
|
||||
@ -379,8 +377,8 @@ def _compute_batch_factors(
|
||||
else:
|
||||
validation_stats["skipped"] += 1
|
||||
|
||||
# 每处理10个证券更新一次进度
|
||||
if (i + 1) % 10 == 0 and total_securities > 0:
|
||||
# 每处理1个证券更新一次进度,确保实时性
|
||||
if total_securities > 0:
|
||||
current_progress = processed_securities + i + 1
|
||||
progress_percentage = (current_progress / total_securities) * 100
|
||||
factor_progress.update_progress(
|
||||
@ -397,6 +395,16 @@ def _compute_batch_factors(
|
||||
)
|
||||
validation_stats["skipped"] += 1
|
||||
|
||||
# 批次处理完成,更新最终进度
|
||||
if total_securities > 0:
|
||||
final_progress = processed_securities + len(ts_codes)
|
||||
progress_percentage = (final_progress / total_securities) * 100
|
||||
factor_progress.update_progress(
|
||||
current_securities=final_progress,
|
||||
current_batch=batch_index + 1,
|
||||
message=f"批次 {batch_index + 1}/{total_batches} 处理完成 - 证券 {final_progress}/{total_securities} ({progress_percentage:.1f}%)"
|
||||
)
|
||||
|
||||
return batch_results
|
||||
|
||||
|
||||
@ -676,7 +684,7 @@ def _compute_security_factors(
|
||||
|
||||
# 计算扩展因子值
|
||||
calculator = ExtendedFactors()
|
||||
extended_factors = calculator.compute_all_factors(close_series, volume_series)
|
||||
extended_factors = calculator.compute_all_factors(close_series, volume_series, ts_code, trade_date)
|
||||
results.update(extended_factors)
|
||||
|
||||
# 计算情感因子
|
||||
|
||||
@ -10,7 +10,12 @@ class FactorProgressState:
|
||||
|
||||
def __init__(self):
|
||||
"""初始化进度状态"""
|
||||
if 'factor_progress' not in st.session_state:
|
||||
# 确保session_state中有factor_progress属性
|
||||
self._ensure_initialized()
|
||||
|
||||
def _ensure_initialized(self) -> None:
|
||||
"""确保进度状态已初始化"""
|
||||
if not hasattr(st.session_state, 'factor_progress'):
|
||||
st.session_state.factor_progress = {
|
||||
'current': 0,
|
||||
'total': 0,
|
||||
@ -99,6 +104,7 @@ class FactorProgressState:
|
||||
Returns:
|
||||
进度信息字典
|
||||
"""
|
||||
self._ensure_initialized()
|
||||
return st.session_state.factor_progress.copy()
|
||||
|
||||
def reset(self) -> None:
|
||||
|
||||
@ -67,9 +67,6 @@ def main() -> None:
|
||||
|
||||
render_global_dashboard()
|
||||
|
||||
# 显示因子计算进度
|
||||
render_factor_progress()
|
||||
|
||||
tabs = st.tabs(["今日计划", "投资池/仓位", "回测与复盘", "行情可视化", "日志钻取", "数据与设置", "自检测试"])
|
||||
LOGGER.debug(
|
||||
"Tabs 初始化完成:%s",
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
"""股票筛选与评估视图。"""
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import threading
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@ -136,10 +138,32 @@ def render_stock_evaluation() -> None:
|
||||
)
|
||||
|
||||
# 4. 评估结果
|
||||
if st.button("开始评估", disabled=not selected_factors):
|
||||
with st.spinner("正在评估因子表现..."):
|
||||
|
||||
# 初始化会话状态
|
||||
if 'evaluation_thread' not in st.session_state:
|
||||
st.session_state.evaluation_thread = None
|
||||
if 'evaluation_results' not in st.session_state:
|
||||
st.session_state.evaluation_results = None
|
||||
if 'evaluation_status' not in st.session_state:
|
||||
st.session_state.evaluation_status = 'idle' # idle, running, completed, error
|
||||
if 'current_factor' not in st.session_state:
|
||||
st.session_state.current_factor = ''
|
||||
if 'progress' not in st.session_state:
|
||||
st.session_state.progress = 0
|
||||
|
||||
# 异步评估函数
|
||||
def run_evaluation_async():
|
||||
try:
|
||||
st.session_state.evaluation_status = 'running'
|
||||
results = []
|
||||
for factor_name in selected_factors:
|
||||
|
||||
for i, factor_name in enumerate(selected_factors):
|
||||
st.session_state.current_factor = factor_name
|
||||
st.session_state.progress = (i / len(selected_factors)) * 100
|
||||
|
||||
# 模拟进度更新(实际计算中进度会在evaluate_factor内部更新)
|
||||
time.sleep(0.1) # 让UI有机会更新
|
||||
|
||||
performance = evaluate_factor(
|
||||
factor_name,
|
||||
start_date,
|
||||
@ -155,7 +179,43 @@ def render_stock_evaluation() -> None:
|
||||
"换手率": f"{performance.turnover_rate*100:.1f}%" if performance.turnover_rate else "N/A"
|
||||
})
|
||||
|
||||
if results:
|
||||
st.session_state.evaluation_results = results
|
||||
st.session_state.evaluation_status = 'completed'
|
||||
st.session_state.progress = 100
|
||||
|
||||
except Exception as e:
|
||||
st.session_state.evaluation_status = 'error'
|
||||
st.session_state.evaluation_error = str(e)
|
||||
|
||||
# 显示进度
|
||||
if st.session_state.evaluation_status == 'running':
|
||||
st.info(f"正在评估因子: {st.session_state.current_factor}")
|
||||
st.progress(st.session_state.progress / 100)
|
||||
elif st.session_state.evaluation_status == 'completed':
|
||||
st.success("因子评估完成!")
|
||||
elif st.session_state.evaluation_status == 'error':
|
||||
st.error(f"评估失败: {st.session_state.evaluation_error}")
|
||||
|
||||
# 开始评估按钮
|
||||
if st.button("开始评估", disabled=not selected_factors or st.session_state.evaluation_status == 'running'):
|
||||
# 重置状态
|
||||
st.session_state.evaluation_results = None
|
||||
st.session_state.evaluation_status = 'running'
|
||||
st.session_state.progress = 0
|
||||
|
||||
# 启动异步线程
|
||||
thread = threading.Thread(target=run_evaluation_async)
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
st.session_state.evaluation_thread = thread
|
||||
|
||||
# 强制重新运行以显示进度
|
||||
st.rerun()
|
||||
|
||||
# 显示结果
|
||||
if st.session_state.evaluation_results:
|
||||
results = st.session_state.evaluation_results
|
||||
|
||||
st.markdown("##### 因子评估结果")
|
||||
result_df = pd.DataFrame(results)
|
||||
st.dataframe(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user