refactor factor calculation to synchronous mode with progress tracking

This commit is contained in:
sam 2025-10-08 21:37:40 +08:00
parent 974fc90fc3
commit db0afe9c2d
3 changed files with 45 additions and 115 deletions

View File

@ -332,13 +332,14 @@ def _compute_batch_factors(
# 批次化数据可用性检查 # 批次化数据可用性检查
available_codes = _check_batch_data_availability(broker, ts_codes, trade_date, specs) available_codes = _check_batch_data_availability(broker, ts_codes, trade_date, specs)
# 更新UI进度状态 - 开始处理批次在异步线程中不直接访问factor_progress # 更新UI进度状态 - 开始处理批次
# if total_securities > 0: if total_securities > 0:
# factor_progress.update_progress( from app.ui.progress_state import factor_progress
# current_securities=processed_securities, factor_progress.update_progress(
# current_batch=batch_index + 1, current_securities=processed_securities,
# message=f"开始处理批次 {batch_index + 1}/{total_batches}" current_batch=batch_index + 1,
# ) message=f"开始处理批次 {batch_index + 1}/{total_batches}"
)
for i, ts_code in enumerate(ts_codes): for i, ts_code in enumerate(ts_codes):
try: try:
@ -381,6 +382,7 @@ def _compute_batch_factors(
if total_securities > 0: if total_securities > 0:
current_progress = processed_securities + i + 1 current_progress = processed_securities + i + 1
progress_percentage = (current_progress / total_securities) * 100 progress_percentage = (current_progress / total_securities) * 100
from app.ui.progress_state import factor_progress
factor_progress.update_progress( factor_progress.update_progress(
current_securities=current_progress, current_securities=current_progress,
current_batch=batch_index + 1, current_batch=batch_index + 1,
@ -399,6 +401,7 @@ def _compute_batch_factors(
if total_securities > 0: if total_securities > 0:
final_progress = processed_securities + len(ts_codes) final_progress = processed_securities + len(ts_codes)
progress_percentage = (final_progress / total_securities) * 100 progress_percentage = (final_progress / total_securities) * 100
from app.ui.progress_state import factor_progress
factor_progress.update_progress( factor_progress.update_progress(
current_securities=final_progress, current_securities=final_progress,
current_batch=batch_index + 1, current_batch=batch_index + 1,

View File

@ -1,8 +1,6 @@
"""因子计算页面。""" """因子计算页面。"""
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Optional from typing import List, Optional
import threading
import time
import streamlit as st import streamlit as st
@ -155,12 +153,9 @@ def render_factor_calculation() -> None:
help="如果勾选,将跳过数据库中已存在的因子计算结果" help="如果勾选,将跳过数据库中已存在的因子计算结果"
) )
# 5. 异步计算函数 # 5. 同步计算函数
def run_factor_calculation_async(): def run_factor_calculation_sync():
"""异步执行因子计算""" """同步执行因子计算"""
# 在异步线程中避免直接访问st.session_state
# 使用全局变量或文件来传递进度信息
# 计算参数 # 计算参数
total_stocks = len(universe) if universe else len(_get_all_stocks()) total_stocks = len(universe) if universe else len(_get_all_stocks())
total_batches = len(selected_factors) total_batches = len(selected_factors)
@ -169,6 +164,13 @@ def render_factor_calculation() -> None:
# 执行因子计算 # 执行因子计算
results = [] results = []
for i, factor in enumerate(selected_factors): for i, factor in enumerate(selected_factors):
# 更新批次进度
factor_progress.update_progress(
current_securities=0,
current_batch=i+1,
message=f"正在计算因子: {factor.name}"
)
# 计算单个交易日的因子 # 计算单个交易日的因子
current_date = start_date current_date = start_date
while current_date <= end_date: while current_date <= end_date:
@ -189,54 +191,30 @@ def render_factor_calculation() -> None:
current_date += timedelta(days=1) current_date += timedelta(days=1)
# 短暂暂停 # 计算完成
time.sleep(0.1) factor_progress.complete_calculation(f"因子计算完成!共计算 {len(results)} 条因子记录")
# 计算完成,通过文件或全局变量传递结果 return {
# 这里使用简单的文件方式传递结果
import json
import tempfile
import os
# 创建临时文件存储结果
temp_dir = tempfile.gettempdir()
result_file = os.path.join(temp_dir, f"factor_calculation_{threading.get_ident()}.json")
result_data = {
'success': True, 'success': True,
'results': [r.to_dict() if hasattr(r, 'to_dict') else str(r) for r in results], 'results': results,
'factors': [f.name for f in selected_factors], 'factors': [f.name for f in selected_factors],
'date_range': f"{start_date}{end_date}", 'date_range': f"{start_date}{end_date}",
'stock_count': len(set(r.ts_code for r in results)) if results else 0, 'stock_count': len(set(r.ts_code for r in results)) if results else 0,
'message': f"因子计算完成!共计算 {len(results)} 条因子记录" 'message': f"因子计算完成!共计算 {len(results)} 条因子记录"
} }
with open(result_file, 'w', encoding='utf-8') as f:
json.dump(result_data, f, ensure_ascii=False, indent=2)
except Exception as e: except Exception as e:
# 计算失败 # 计算失败
import json factor_progress.error_occurred(f"因子计算失败: {str(e)}")
import tempfile return {
import os
temp_dir = tempfile.gettempdir()
result_file = os.path.join(temp_dir, f"factor_calculation_{threading.get_ident()}.json")
error_data = {
'success': False, 'success': False,
'error': str(e), 'error': str(e),
'message': f"因子计算失败: {str(e)}" 'message': f"因子计算失败: {str(e)}"
} }
with open(result_file, 'w', encoding='utf-8') as f:
json.dump(error_data, f, ensure_ascii=False, indent=2)
# 6. 开始计算按钮 # 6. 开始计算按钮
if st.button("开始计算因子", disabled=not selected_factors): if st.button("开始计算因子", disabled=not selected_factors):
# 重置状态 # 重置状态
if 'factor_calculation_thread' in st.session_state:
st.session_state.factor_calculation_thread = None
if 'factor_calculation_results' in st.session_state: if 'factor_calculation_results' in st.session_state:
st.session_state.factor_calculation_results = None st.session_state.factor_calculation_results = None
if 'factor_calculation_error' in st.session_state: if 'factor_calculation_error' in st.session_state:
@ -249,18 +227,21 @@ def render_factor_calculation() -> None:
total_batches=len(selected_factors) total_batches=len(selected_factors)
) )
# 启动异步线程 # 直接调用同步计算函数
thread = threading.Thread(target=run_factor_calculation_async) result = run_factor_calculation_sync()
thread.daemon = True
thread.start()
st.session_state.factor_calculation_thread = thread
st.session_state.factor_calculation_thread_id = thread.ident
# 显示计算中状态 # 处理计算结果
st.info("因子计算已开始,请查看侧边栏进度显示...") if result['success']:
st.session_state.factor_calculation_results = {
# 强制重新运行以显示进度 'results': result['results'],
st.rerun() 'factors': result['factors'],
'date_range': result['date_range'],
'stock_count': result['stock_count']
}
st.success("✅ 因子计算完成!")
else:
st.session_state.factor_calculation_error = result['error']
st.error(f"❌ 因子计算失败: {result['error']}")
# 7. 显示计算结果 # 7. 显示计算结果
if 'factor_calculation_results' in st.session_state and st.session_state.factor_calculation_results: if 'factor_calculation_results' in st.session_state and st.session_state.factor_calculation_results:
@ -298,48 +279,7 @@ def render_factor_calculation() -> None:
else: else:
st.info("没有找到因子计算结果") st.info("没有找到因子计算结果")
# 8. 检查异步线程结果 # 8. 移除异步线程检查逻辑(已改为同步模式)
if 'factor_calculation_thread_id' in st.session_state:
import json
import tempfile
import os
thread_id = st.session_state.factor_calculation_thread_id
temp_dir = tempfile.gettempdir()
result_file = os.path.join(temp_dir, f"factor_calculation_{thread_id}.json")
# 检查结果文件是否存在
if os.path.exists(result_file):
try:
with open(result_file, 'r', encoding='utf-8') as f:
result_data = json.load(f)
# 处理结果
if result_data['success']:
# 计算成功
factor_progress.complete_calculation(result_data['message'])
st.session_state.factor_calculation_results = {
'results': result_data['results'],
'factors': result_data['factors'],
'date_range': result_data['date_range'],
'stock_count': result_data['stock_count']
}
else:
# 计算失败
factor_progress.error_occurred(result_data['message'])
st.session_state.factor_calculation_error = result_data['error']
# 清理临时文件
os.remove(result_file)
# 清除线程状态
st.session_state.factor_calculation_thread_id = None
# 强制重新运行以显示结果
st.rerun()
except Exception as e:
st.error(f"处理计算结果时出错: {str(e)}")
# 9. 显示错误信息 # 9. 显示错误信息
if 'factor_calculation_error' in st.session_state and st.session_state.factor_calculation_error: if 'factor_calculation_error' in st.session_state and st.session_state.factor_calculation_error:

View File

@ -1,8 +1,6 @@
"""股票筛选与评估视图。""" """股票筛选与评估视图。"""
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import threading
import time
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@ -140,8 +138,6 @@ def render_stock_evaluation() -> None:
# 4. 评估结果 # 4. 评估结果
# 初始化会话状态 # 初始化会话状态
if 'evaluation_thread' not in st.session_state:
st.session_state.evaluation_thread = None
if 'evaluation_results' not in st.session_state: if 'evaluation_results' not in st.session_state:
st.session_state.evaluation_results = None st.session_state.evaluation_results = None
if 'evaluation_status' not in st.session_state: if 'evaluation_status' not in st.session_state:
@ -151,8 +147,8 @@ def render_stock_evaluation() -> None:
if 'progress' not in st.session_state: if 'progress' not in st.session_state:
st.session_state.progress = 0 st.session_state.progress = 0
# 步评估函数 # 步评估函数
def run_evaluation_async(): def run_evaluation_sync():
try: try:
st.session_state.evaluation_status = 'running' st.session_state.evaluation_status = 'running'
results = [] results = []
@ -161,9 +157,6 @@ def render_stock_evaluation() -> None:
st.session_state.current_factor = factor_name st.session_state.current_factor = factor_name
st.session_state.progress = (i / len(selected_factors)) * 100 st.session_state.progress = (i / len(selected_factors)) * 100
# 模拟进度更新实际计算中进度会在evaluate_factor内部更新
time.sleep(0.1) # 让UI有机会更新
performance = evaluate_factor( performance = evaluate_factor(
factor_name, factor_name,
start_date, start_date,
@ -203,14 +196,8 @@ def render_stock_evaluation() -> None:
st.session_state.evaluation_status = 'running' st.session_state.evaluation_status = 'running'
st.session_state.progress = 0 st.session_state.progress = 0
# 启动异步线程 # 直接调用同步评估函数
thread = threading.Thread(target=run_evaluation_async) run_evaluation_sync()
thread.daemon = True
thread.start()
st.session_state.evaluation_thread = thread
# 强制重新运行以显示进度
st.rerun()
# 显示结果 # 显示结果
if st.session_state.evaluation_results: if st.session_state.evaluation_results: