enhance factor evaluation with async progress tracking and UI updates
This commit is contained in:
parent
4e7a56567b
commit
07c76d7674
@ -80,14 +80,37 @@ def evaluate_factor(
|
|||||||
"""
|
"""
|
||||||
performance = FactorPerformance(factor_name)
|
performance = FactorPerformance(factor_name)
|
||||||
|
|
||||||
# 计算因子值
|
# 导入进度状态模块
|
||||||
factor_results = compute_factor_range(
|
from app.ui.progress_state import factor_progress
|
||||||
start_date,
|
|
||||||
end_date,
|
# 开始因子计算进度
|
||||||
factors=[FactorSpec(factor_name, 0)],
|
factor_progress.start_calculation(
|
||||||
ts_codes=universe
|
total_securities=len(universe) if universe else 0,
|
||||||
|
message=f"开始评估因子 {factor_name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 计算因子值
|
||||||
|
factor_results = compute_factor_range(
|
||||||
|
start_date,
|
||||||
|
end_date,
|
||||||
|
factors=[FactorSpec(factor_name, 0)],
|
||||||
|
ts_codes=universe
|
||||||
|
)
|
||||||
|
|
||||||
|
# 因子计算完成
|
||||||
|
factor_progress.complete_calculation(
|
||||||
|
message=f"因子 {factor_name} 评估完成"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# 因子计算失败
|
||||||
|
factor_progress.complete_calculation(
|
||||||
|
message=f"因子 {factor_name} 评估失败: {str(e)}",
|
||||||
|
success=False
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
# 按日期分组
|
# 按日期分组
|
||||||
date_groups: Dict[date, List[FactorResult]] = {}
|
date_groups: Dict[date, List[FactorResult]] = {}
|
||||||
for result in factor_results:
|
for result in factor_results:
|
||||||
|
|||||||
@ -469,51 +469,45 @@ class ExtendedFactors:
|
|||||||
raise ValueError(f"因子 {factor_name} 没有对应的计算实现")
|
raise ValueError(f"因子 {factor_name} 没有对应的计算实现")
|
||||||
|
|
||||||
def compute_all_factors(self,
|
def compute_all_factors(self,
|
||||||
close_series: Sequence[float],
|
close_series: List[float],
|
||||||
volume_series: Sequence[float]) -> Dict[str, float]:
|
volume_series: List[float],
|
||||||
"""计算所有已注册的扩展因子值
|
ts_code: str,
|
||||||
|
trade_date: str) -> Dict[str, float | None]:
|
||||||
|
"""计算所有扩展因子
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
close_series: 收盘价序列,从新到旧排序
|
close_series: 收盘价序列
|
||||||
volume_series: 成交量序列,从新到旧排序
|
volume_series: 成交量序列
|
||||||
|
ts_code: 股票代码
|
||||||
|
trade_date: 交易日期
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, float]: 因子名称到因子值的映射字典,
|
因子名称到因子值的映射
|
||||||
只包含成功计算的因子值
|
|
||||||
|
|
||||||
Note:
|
|
||||||
该方法会尝试计算所有已注册的因子,失败的因子将被忽略。
|
|
||||||
如果所有因子计算都失败,将返回空字典。
|
|
||||||
"""
|
"""
|
||||||
results = {}
|
results = {}
|
||||||
success_count = 0
|
|
||||||
total_count = len(self.factor_specs)
|
|
||||||
|
|
||||||
for factor_name in self.factor_specs:
|
for factor_spec in EXTENDED_FACTORS:
|
||||||
value = self.compute_factor(factor_name, close_series, volume_series)
|
try:
|
||||||
if value is not None:
|
factor_name = factor_spec.name
|
||||||
# 验证因子值是否在合理范围内
|
factor_value = self.compute_factor(factor_name, close_series, volume_series)
|
||||||
validated_value = validate_factor_value(
|
|
||||||
factor_name, value, "unknown", "unknown"
|
# 验证因子值
|
||||||
|
if factor_value is not None:
|
||||||
|
# 使用真实的 ts_code 和 trade_date 进行验证
|
||||||
|
validate_factor_value(factor_name, factor_value, ts_code, trade_date)
|
||||||
|
|
||||||
|
results[factor_name] = factor_value
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.debug(
|
||||||
|
"因子计算失败 factor=%s ts_code=%s date=%s err=%s",
|
||||||
|
factor_spec.name,
|
||||||
|
ts_code,
|
||||||
|
trade_date,
|
||||||
|
str(e),
|
||||||
|
extra=LOG_EXTRA,
|
||||||
)
|
)
|
||||||
if validated_value is not None:
|
results[factor_spec.name] = None
|
||||||
results[factor_name] = validated_value
|
|
||||||
success_count += 1
|
|
||||||
else:
|
|
||||||
LOGGER.debug(
|
|
||||||
"因子值验证失败 factor=%s value=%f",
|
|
||||||
factor_name, value,
|
|
||||||
extra=LOG_EXTRA
|
|
||||||
)
|
|
||||||
|
|
||||||
# 关闭因子计算完成日志打印
|
|
||||||
# LOGGER.info(
|
|
||||||
# "因子计算完成 total=%d success=%d failed=%d",
|
|
||||||
# total_count,
|
|
||||||
# success_count,
|
|
||||||
# total_count - success_count,
|
|
||||||
# extra=LOG_EXTRA
|
|
||||||
# )
|
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|||||||
@ -332,14 +332,12 @@ def _compute_batch_factors(
|
|||||||
# 批次化数据可用性检查
|
# 批次化数据可用性检查
|
||||||
available_codes = _check_batch_data_availability(broker, ts_codes, trade_date, specs)
|
available_codes = _check_batch_data_availability(broker, ts_codes, trade_date, specs)
|
||||||
|
|
||||||
# 更新UI进度状态
|
# 更新UI进度状态 - 开始处理批次
|
||||||
if total_securities > 0:
|
if total_securities > 0:
|
||||||
current_progress = processed_securities + len(available_codes)
|
|
||||||
progress_percentage = (current_progress / total_securities) * 100
|
|
||||||
factor_progress.update_progress(
|
factor_progress.update_progress(
|
||||||
current_securities=current_progress,
|
current_securities=processed_securities,
|
||||||
current_batch=batch_index + 1,
|
current_batch=batch_index + 1,
|
||||||
message=f"处理批次 {batch_index + 1}/{total_batches} - 证券 {current_progress}/{total_securities} ({progress_percentage:.1f}%)"
|
message=f"开始处理批次 {batch_index + 1}/{total_batches}"
|
||||||
)
|
)
|
||||||
|
|
||||||
for i, ts_code in enumerate(ts_codes):
|
for i, ts_code in enumerate(ts_codes):
|
||||||
@ -379,8 +377,8 @@ def _compute_batch_factors(
|
|||||||
else:
|
else:
|
||||||
validation_stats["skipped"] += 1
|
validation_stats["skipped"] += 1
|
||||||
|
|
||||||
# 每处理10个证券更新一次进度
|
# 每处理1个证券更新一次进度,确保实时性
|
||||||
if (i + 1) % 10 == 0 and total_securities > 0:
|
if total_securities > 0:
|
||||||
current_progress = processed_securities + i + 1
|
current_progress = processed_securities + i + 1
|
||||||
progress_percentage = (current_progress / total_securities) * 100
|
progress_percentage = (current_progress / total_securities) * 100
|
||||||
factor_progress.update_progress(
|
factor_progress.update_progress(
|
||||||
@ -397,6 +395,16 @@ def _compute_batch_factors(
|
|||||||
)
|
)
|
||||||
validation_stats["skipped"] += 1
|
validation_stats["skipped"] += 1
|
||||||
|
|
||||||
|
# 批次处理完成,更新最终进度
|
||||||
|
if total_securities > 0:
|
||||||
|
final_progress = processed_securities + len(ts_codes)
|
||||||
|
progress_percentage = (final_progress / total_securities) * 100
|
||||||
|
factor_progress.update_progress(
|
||||||
|
current_securities=final_progress,
|
||||||
|
current_batch=batch_index + 1,
|
||||||
|
message=f"批次 {batch_index + 1}/{total_batches} 处理完成 - 证券 {final_progress}/{total_securities} ({progress_percentage:.1f}%)"
|
||||||
|
)
|
||||||
|
|
||||||
return batch_results
|
return batch_results
|
||||||
|
|
||||||
|
|
||||||
@ -676,7 +684,7 @@ def _compute_security_factors(
|
|||||||
|
|
||||||
# 计算扩展因子值
|
# 计算扩展因子值
|
||||||
calculator = ExtendedFactors()
|
calculator = ExtendedFactors()
|
||||||
extended_factors = calculator.compute_all_factors(close_series, volume_series)
|
extended_factors = calculator.compute_all_factors(close_series, volume_series, ts_code, trade_date)
|
||||||
results.update(extended_factors)
|
results.update(extended_factors)
|
||||||
|
|
||||||
# 计算情感因子
|
# 计算情感因子
|
||||||
|
|||||||
@ -10,7 +10,12 @@ class FactorProgressState:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""初始化进度状态"""
|
"""初始化进度状态"""
|
||||||
if 'factor_progress' not in st.session_state:
|
# 确保session_state中有factor_progress属性
|
||||||
|
self._ensure_initialized()
|
||||||
|
|
||||||
|
def _ensure_initialized(self) -> None:
|
||||||
|
"""确保进度状态已初始化"""
|
||||||
|
if not hasattr(st.session_state, 'factor_progress'):
|
||||||
st.session_state.factor_progress = {
|
st.session_state.factor_progress = {
|
||||||
'current': 0,
|
'current': 0,
|
||||||
'total': 0,
|
'total': 0,
|
||||||
@ -99,6 +104,7 @@ class FactorProgressState:
|
|||||||
Returns:
|
Returns:
|
||||||
进度信息字典
|
进度信息字典
|
||||||
"""
|
"""
|
||||||
|
self._ensure_initialized()
|
||||||
return st.session_state.factor_progress.copy()
|
return st.session_state.factor_progress.copy()
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
|
|||||||
@ -67,9 +67,6 @@ def main() -> None:
|
|||||||
|
|
||||||
render_global_dashboard()
|
render_global_dashboard()
|
||||||
|
|
||||||
# 显示因子计算进度
|
|
||||||
render_factor_progress()
|
|
||||||
|
|
||||||
tabs = st.tabs(["今日计划", "投资池/仓位", "回测与复盘", "行情可视化", "日志钻取", "数据与设置", "自检测试"])
|
tabs = st.tabs(["今日计划", "投资池/仓位", "回测与复盘", "行情可视化", "日志钻取", "数据与设置", "自检测试"])
|
||||||
LOGGER.debug(
|
LOGGER.debug(
|
||||||
"Tabs 初始化完成:%s",
|
"Tabs 初始化完成:%s",
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
"""股票筛选与评估视图。"""
|
"""股票筛选与评估视图。"""
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@ -136,10 +138,32 @@ def render_stock_evaluation() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 4. 评估结果
|
# 4. 评估结果
|
||||||
if st.button("开始评估", disabled=not selected_factors):
|
|
||||||
with st.spinner("正在评估因子表现..."):
|
# 初始化会话状态
|
||||||
|
if 'evaluation_thread' not in st.session_state:
|
||||||
|
st.session_state.evaluation_thread = None
|
||||||
|
if 'evaluation_results' not in st.session_state:
|
||||||
|
st.session_state.evaluation_results = None
|
||||||
|
if 'evaluation_status' not in st.session_state:
|
||||||
|
st.session_state.evaluation_status = 'idle' # idle, running, completed, error
|
||||||
|
if 'current_factor' not in st.session_state:
|
||||||
|
st.session_state.current_factor = ''
|
||||||
|
if 'progress' not in st.session_state:
|
||||||
|
st.session_state.progress = 0
|
||||||
|
|
||||||
|
# 异步评估函数
|
||||||
|
def run_evaluation_async():
|
||||||
|
try:
|
||||||
|
st.session_state.evaluation_status = 'running'
|
||||||
results = []
|
results = []
|
||||||
for factor_name in selected_factors:
|
|
||||||
|
for i, factor_name in enumerate(selected_factors):
|
||||||
|
st.session_state.current_factor = factor_name
|
||||||
|
st.session_state.progress = (i / len(selected_factors)) * 100
|
||||||
|
|
||||||
|
# 模拟进度更新(实际计算中进度会在evaluate_factor内部更新)
|
||||||
|
time.sleep(0.1) # 让UI有机会更新
|
||||||
|
|
||||||
performance = evaluate_factor(
|
performance = evaluate_factor(
|
||||||
factor_name,
|
factor_name,
|
||||||
start_date,
|
start_date,
|
||||||
@ -155,51 +179,87 @@ def render_stock_evaluation() -> None:
|
|||||||
"换手率": f"{performance.turnover_rate*100:.1f}%" if performance.turnover_rate else "N/A"
|
"换手率": f"{performance.turnover_rate*100:.1f}%" if performance.turnover_rate else "N/A"
|
||||||
})
|
})
|
||||||
|
|
||||||
if results:
|
st.session_state.evaluation_results = results
|
||||||
st.markdown("##### 因子评估结果")
|
st.session_state.evaluation_status = 'completed'
|
||||||
result_df = pd.DataFrame(results)
|
st.session_state.progress = 100
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
st.session_state.evaluation_status = 'error'
|
||||||
|
st.session_state.evaluation_error = str(e)
|
||||||
|
|
||||||
|
# 显示进度
|
||||||
|
if st.session_state.evaluation_status == 'running':
|
||||||
|
st.info(f"正在评估因子: {st.session_state.current_factor}")
|
||||||
|
st.progress(st.session_state.progress / 100)
|
||||||
|
elif st.session_state.evaluation_status == 'completed':
|
||||||
|
st.success("因子评估完成!")
|
||||||
|
elif st.session_state.evaluation_status == 'error':
|
||||||
|
st.error(f"评估失败: {st.session_state.evaluation_error}")
|
||||||
|
|
||||||
|
# 开始评估按钮
|
||||||
|
if st.button("开始评估", disabled=not selected_factors or st.session_state.evaluation_status == 'running'):
|
||||||
|
# 重置状态
|
||||||
|
st.session_state.evaluation_results = None
|
||||||
|
st.session_state.evaluation_status = 'running'
|
||||||
|
st.session_state.progress = 0
|
||||||
|
|
||||||
|
# 启动异步线程
|
||||||
|
thread = threading.Thread(target=run_evaluation_async)
|
||||||
|
thread.daemon = True
|
||||||
|
thread.start()
|
||||||
|
st.session_state.evaluation_thread = thread
|
||||||
|
|
||||||
|
# 强制重新运行以显示进度
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
# 显示结果
|
||||||
|
if st.session_state.evaluation_results:
|
||||||
|
results = st.session_state.evaluation_results
|
||||||
|
|
||||||
|
st.markdown("##### 因子评估结果")
|
||||||
|
result_df = pd.DataFrame(results)
|
||||||
|
st.dataframe(
|
||||||
|
result_df,
|
||||||
|
hide_index=True,
|
||||||
|
use_container_width=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 绘制IC均值分布
|
||||||
|
ic_means = [float(r["IC均值"]) for r in results]
|
||||||
|
chart_df = pd.DataFrame({
|
||||||
|
"因子": [r["因子"] for r in results],
|
||||||
|
"IC均值": ic_means
|
||||||
|
})
|
||||||
|
st.bar_chart(chart_df.set_index("因子"))
|
||||||
|
|
||||||
|
# 生成股票评分
|
||||||
|
with st.spinner("正在生成股票评分..."):
|
||||||
|
scores = _calculate_stock_scores(
|
||||||
|
universe,
|
||||||
|
selected_factors,
|
||||||
|
end_date,
|
||||||
|
ic_means
|
||||||
|
)
|
||||||
|
|
||||||
|
if scores:
|
||||||
|
st.markdown("##### 股票综合评分 (Top 20)")
|
||||||
|
score_df = pd.DataFrame(scores).sort_values(
|
||||||
|
"综合评分",
|
||||||
|
ascending=False
|
||||||
|
).head(20)
|
||||||
st.dataframe(
|
st.dataframe(
|
||||||
result_df,
|
score_df,
|
||||||
hide_index=True,
|
hide_index=True,
|
||||||
use_container_width=True
|
use_container_width=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# 绘制IC均值分布
|
# 添加入池功能
|
||||||
ic_means = [float(r["IC均值"]) for r in results]
|
if st.button("将Top 20股票加入股票池"):
|
||||||
chart_df = pd.DataFrame({
|
_add_to_stock_pool(
|
||||||
"因子": [r["因子"] for r in results],
|
score_df["股票代码"].tolist(),
|
||||||
"IC均值": ic_means
|
end_date
|
||||||
})
|
|
||||||
st.bar_chart(chart_df.set_index("因子"))
|
|
||||||
|
|
||||||
# 生成股票评分
|
|
||||||
with st.spinner("正在生成股票评分..."):
|
|
||||||
scores = _calculate_stock_scores(
|
|
||||||
universe,
|
|
||||||
selected_factors,
|
|
||||||
end_date,
|
|
||||||
ic_means
|
|
||||||
)
|
)
|
||||||
|
st.success("已成功将选中股票加入股票池!")
|
||||||
if scores:
|
|
||||||
st.markdown("##### 股票综合评分 (Top 20)")
|
|
||||||
score_df = pd.DataFrame(scores).sort_values(
|
|
||||||
"综合评分",
|
|
||||||
ascending=False
|
|
||||||
).head(20)
|
|
||||||
st.dataframe(
|
|
||||||
score_df,
|
|
||||||
hide_index=True,
|
|
||||||
use_container_width=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 添加入池功能
|
|
||||||
if st.button("将Top 20股票加入股票池"):
|
|
||||||
_add_to_stock_pool(
|
|
||||||
score_df["股票代码"].tolist(),
|
|
||||||
end_date
|
|
||||||
)
|
|
||||||
st.success("已成功将选中股票加入股票池!")
|
|
||||||
|
|
||||||
|
|
||||||
def _calculate_stock_scores(
|
def _calculate_stock_scores(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user