refactor factor calculation to synchronous mode with progress tracking

This commit is contained in:
sam 2025-10-08 21:37:40 +08:00
parent 974fc90fc3
commit db0afe9c2d
3 changed files with 45 additions and 115 deletions

View File

@ -332,13 +332,14 @@ def _compute_batch_factors(
# 批次化数据可用性检查
available_codes = _check_batch_data_availability(broker, ts_codes, trade_date, specs)
# 更新UI进度状态 - 开始处理批次(在异步线程中不直接访问 factor_progress)
# if total_securities > 0:
# factor_progress.update_progress(
# current_securities=processed_securities,
# current_batch=batch_index + 1,
# message=f"开始处理批次 {batch_index + 1}/{total_batches}"
# )
# 更新UI进度状态 - 开始处理批次
if total_securities > 0:
from app.ui.progress_state import factor_progress
factor_progress.update_progress(
current_securities=processed_securities,
current_batch=batch_index + 1,
message=f"开始处理批次 {batch_index + 1}/{total_batches}"
)
for i, ts_code in enumerate(ts_codes):
try:
@ -381,6 +382,7 @@ def _compute_batch_factors(
if total_securities > 0:
current_progress = processed_securities + i + 1
progress_percentage = (current_progress / total_securities) * 100
from app.ui.progress_state import factor_progress
factor_progress.update_progress(
current_securities=current_progress,
current_batch=batch_index + 1,
@ -399,6 +401,7 @@ def _compute_batch_factors(
if total_securities > 0:
final_progress = processed_securities + len(ts_codes)
progress_percentage = (final_progress / total_securities) * 100
from app.ui.progress_state import factor_progress
factor_progress.update_progress(
current_securities=final_progress,
current_batch=batch_index + 1,

View File

@ -1,8 +1,6 @@
"""因子计算页面。"""
from datetime import datetime, timedelta
from typing import List, Optional
import threading
import time
import streamlit as st
@ -155,12 +153,9 @@ def render_factor_calculation() -> None:
help="如果勾选,将跳过数据库中已存在的因子计算结果"
)
# 5. 异步计算函数
def run_factor_calculation_async():
"""异步执行因子计算"""
# 在异步线程中避免直接访问st.session_state
# 使用全局变量或文件来传递进度信息
# 5. 同步计算函数
def run_factor_calculation_sync():
"""同步执行因子计算"""
# 计算参数
total_stocks = len(universe) if universe else len(_get_all_stocks())
total_batches = len(selected_factors)
@ -169,6 +164,13 @@ def render_factor_calculation() -> None:
# 执行因子计算
results = []
for i, factor in enumerate(selected_factors):
# 更新批次进度
factor_progress.update_progress(
current_securities=0,
current_batch=i+1,
message=f"正在计算因子: {factor.name}"
)
# 计算单个交易日的因子
current_date = start_date
while current_date <= end_date:
@ -188,55 +190,31 @@ def render_factor_calculation() -> None:
print(f"ERROR: {error_msg}")
current_date += timedelta(days=1)
# 短暂暂停
time.sleep(0.1)
# 计算完成,通过文件或全局变量传递结果
# 这里使用简单的文件方式传递结果
import json
import tempfile
import os
# 计算完成
factor_progress.complete_calculation(f"因子计算完成!共计算 {len(results)} 条因子记录")
# 创建临时文件存储结果
temp_dir = tempfile.gettempdir()
result_file = os.path.join(temp_dir, f"factor_calculation_{threading.get_ident()}.json")
result_data = {
return {
'success': True,
'results': [r.to_dict() if hasattr(r, 'to_dict') else str(r) for r in results],
'results': results,
'factors': [f.name for f in selected_factors],
'date_range': f"{start_date}{end_date}",
'stock_count': len(set(r.ts_code for r in results)) if results else 0,
'message': f"因子计算完成!共计算 {len(results)} 条因子记录"
}
with open(result_file, 'w', encoding='utf-8') as f:
json.dump(result_data, f, ensure_ascii=False, indent=2)
except Exception as e:
# 计算失败
import json
import tempfile
import os
temp_dir = tempfile.gettempdir()
result_file = os.path.join(temp_dir, f"factor_calculation_{threading.get_ident()}.json")
error_data = {
factor_progress.error_occurred(f"因子计算失败: {str(e)}")
return {
'success': False,
'error': str(e),
'message': f"因子计算失败: {str(e)}"
}
with open(result_file, 'w', encoding='utf-8') as f:
json.dump(error_data, f, ensure_ascii=False, indent=2)
# 6. 开始计算按钮
if st.button("开始计算因子", disabled=not selected_factors):
# 重置状态
if 'factor_calculation_thread' in st.session_state:
st.session_state.factor_calculation_thread = None
if 'factor_calculation_results' in st.session_state:
st.session_state.factor_calculation_results = None
if 'factor_calculation_error' in st.session_state:
@ -249,18 +227,21 @@ def render_factor_calculation() -> None:
total_batches=len(selected_factors)
)
# 启动异步线程
thread = threading.Thread(target=run_factor_calculation_async)
thread.daemon = True
thread.start()
st.session_state.factor_calculation_thread = thread
st.session_state.factor_calculation_thread_id = thread.ident
# 直接调用同步计算函数
result = run_factor_calculation_sync()
# 显示计算中状态
st.info("因子计算已开始,请查看侧边栏进度显示...")
# 强制重新运行以显示进度
st.rerun()
# 处理计算结果
if result['success']:
st.session_state.factor_calculation_results = {
'results': result['results'],
'factors': result['factors'],
'date_range': result['date_range'],
'stock_count': result['stock_count']
}
st.success("✅ 因子计算完成!")
else:
st.session_state.factor_calculation_error = result['error']
st.error(f"❌ 因子计算失败: {result['error']}")
# 7. 显示计算结果
if 'factor_calculation_results' in st.session_state and st.session_state.factor_calculation_results:
@ -298,48 +279,7 @@ def render_factor_calculation() -> None:
else:
st.info("没有找到因子计算结果")
# 8. 检查异步线程结果
if 'factor_calculation_thread_id' in st.session_state:
import json
import tempfile
import os
thread_id = st.session_state.factor_calculation_thread_id
temp_dir = tempfile.gettempdir()
result_file = os.path.join(temp_dir, f"factor_calculation_{thread_id}.json")
# 检查结果文件是否存在
if os.path.exists(result_file):
try:
with open(result_file, 'r', encoding='utf-8') as f:
result_data = json.load(f)
# 处理结果
if result_data['success']:
# 计算成功
factor_progress.complete_calculation(result_data['message'])
st.session_state.factor_calculation_results = {
'results': result_data['results'],
'factors': result_data['factors'],
'date_range': result_data['date_range'],
'stock_count': result_data['stock_count']
}
else:
# 计算失败
factor_progress.error_occurred(result_data['message'])
st.session_state.factor_calculation_error = result_data['error']
# 清理临时文件
os.remove(result_file)
# 清除线程状态
st.session_state.factor_calculation_thread_id = None
# 强制重新运行以显示结果
st.rerun()
except Exception as e:
st.error(f"处理计算结果时出错: {str(e)}")
# 8. 移除异步线程检查逻辑(已改为同步模式)
# 9. 显示错误信息
if 'factor_calculation_error' in st.session_state and st.session_state.factor_calculation_error:

View File

@ -1,8 +1,6 @@
"""股票筛选与评估视图。"""
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple
import threading
import time
import numpy as np
import pandas as pd
@ -140,8 +138,6 @@ def render_stock_evaluation() -> None:
# 4. 评估结果
# 初始化会话状态
if 'evaluation_thread' not in st.session_state:
st.session_state.evaluation_thread = None
if 'evaluation_results' not in st.session_state:
st.session_state.evaluation_results = None
if 'evaluation_status' not in st.session_state:
@ -151,8 +147,8 @@ def render_stock_evaluation() -> None:
if 'progress' not in st.session_state:
st.session_state.progress = 0
# 异步评估函数
def run_evaluation_async():
# 同步评估函数
def run_evaluation_sync():
try:
st.session_state.evaluation_status = 'running'
results = []
@ -161,9 +157,6 @@ def render_stock_evaluation() -> None:
st.session_state.current_factor = factor_name
st.session_state.progress = (i / len(selected_factors)) * 100
# 模拟进度更新(实际计算中,进度会在 evaluate_factor 内部更新)
time.sleep(0.1) # 让UI有机会更新
performance = evaluate_factor(
factor_name,
start_date,
@ -203,14 +196,8 @@ def render_stock_evaluation() -> None:
st.session_state.evaluation_status = 'running'
st.session_state.progress = 0
# 启动异步线程
thread = threading.Thread(target=run_evaluation_async)
thread.daemon = True
thread.start()
st.session_state.evaluation_thread = thread
# 强制重新运行以显示进度
st.rerun()
# 直接调用同步评估函数
run_evaluation_sync()
# 显示结果
if st.session_state.evaluation_results: