refactor factor calculation to synchronous mode with progress tracking
This commit is contained in:
parent
974fc90fc3
commit
db0afe9c2d
@ -332,13 +332,14 @@ def _compute_batch_factors(
|
||||
# 批次化数据可用性检查
|
||||
available_codes = _check_batch_data_availability(broker, ts_codes, trade_date, specs)
|
||||
|
||||
# 更新UI进度状态 - 开始处理批次(在异步线程中不直接访问factor_progress)
|
||||
# if total_securities > 0:
|
||||
# factor_progress.update_progress(
|
||||
# current_securities=processed_securities,
|
||||
# current_batch=batch_index + 1,
|
||||
# message=f"开始处理批次 {batch_index + 1}/{total_batches}"
|
||||
# )
|
||||
# 更新UI进度状态 - 开始处理批次
|
||||
if total_securities > 0:
|
||||
from app.ui.progress_state import factor_progress
|
||||
factor_progress.update_progress(
|
||||
current_securities=processed_securities,
|
||||
current_batch=batch_index + 1,
|
||||
message=f"开始处理批次 {batch_index + 1}/{total_batches}"
|
||||
)
|
||||
|
||||
for i, ts_code in enumerate(ts_codes):
|
||||
try:
|
||||
@ -381,6 +382,7 @@ def _compute_batch_factors(
|
||||
if total_securities > 0:
|
||||
current_progress = processed_securities + i + 1
|
||||
progress_percentage = (current_progress / total_securities) * 100
|
||||
from app.ui.progress_state import factor_progress
|
||||
factor_progress.update_progress(
|
||||
current_securities=current_progress,
|
||||
current_batch=batch_index + 1,
|
||||
@ -399,6 +401,7 @@ def _compute_batch_factors(
|
||||
if total_securities > 0:
|
||||
final_progress = processed_securities + len(ts_codes)
|
||||
progress_percentage = (final_progress / total_securities) * 100
|
||||
from app.ui.progress_state import factor_progress
|
||||
factor_progress.update_progress(
|
||||
current_securities=final_progress,
|
||||
current_batch=batch_index + 1,
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
"""因子计算页面。"""
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional
|
||||
import threading
|
||||
import time
|
||||
|
||||
import streamlit as st
|
||||
|
||||
@ -155,12 +153,9 @@ def render_factor_calculation() -> None:
|
||||
help="如果勾选,将跳过数据库中已存在的因子计算结果"
|
||||
)
|
||||
|
||||
# 5. 异步计算函数
|
||||
def run_factor_calculation_async():
|
||||
"""异步执行因子计算"""
|
||||
# 在异步线程中避免直接访问st.session_state
|
||||
# 使用全局变量或文件来传递进度信息
|
||||
|
||||
# 5. 同步计算函数
|
||||
def run_factor_calculation_sync():
|
||||
"""同步执行因子计算"""
|
||||
# 计算参数
|
||||
total_stocks = len(universe) if universe else len(_get_all_stocks())
|
||||
total_batches = len(selected_factors)
|
||||
@ -169,6 +164,13 @@ def render_factor_calculation() -> None:
|
||||
# 执行因子计算
|
||||
results = []
|
||||
for i, factor in enumerate(selected_factors):
|
||||
# 更新批次进度
|
||||
factor_progress.update_progress(
|
||||
current_securities=0,
|
||||
current_batch=i+1,
|
||||
message=f"正在计算因子: {factor.name}"
|
||||
)
|
||||
|
||||
# 计算单个交易日的因子
|
||||
current_date = start_date
|
||||
while current_date <= end_date:
|
||||
@ -188,55 +190,31 @@ def render_factor_calculation() -> None:
|
||||
print(f"ERROR: {error_msg}")
|
||||
|
||||
current_date += timedelta(days=1)
|
||||
|
||||
# 短暂暂停
|
||||
time.sleep(0.1)
|
||||
|
||||
# 计算完成,通过文件或全局变量传递结果
|
||||
# 这里使用简单的文件方式传递结果
|
||||
import json
|
||||
import tempfile
|
||||
import os
|
||||
# 计算完成
|
||||
factor_progress.complete_calculation(f"因子计算完成!共计算 {len(results)} 条因子记录")
|
||||
|
||||
# 创建临时文件存储结果
|
||||
temp_dir = tempfile.gettempdir()
|
||||
result_file = os.path.join(temp_dir, f"factor_calculation_{threading.get_ident()}.json")
|
||||
|
||||
result_data = {
|
||||
return {
|
||||
'success': True,
|
||||
'results': [r.to_dict() if hasattr(r, 'to_dict') else str(r) for r in results],
|
||||
'results': results,
|
||||
'factors': [f.name for f in selected_factors],
|
||||
'date_range': f"{start_date} 至 {end_date}",
|
||||
'stock_count': len(set(r.ts_code for r in results)) if results else 0,
|
||||
'message': f"因子计算完成!共计算 {len(results)} 条因子记录"
|
||||
}
|
||||
|
||||
with open(result_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(result_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
# 计算失败
|
||||
import json
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
temp_dir = tempfile.gettempdir()
|
||||
result_file = os.path.join(temp_dir, f"factor_calculation_{threading.get_ident()}.json")
|
||||
|
||||
error_data = {
|
||||
factor_progress.error_occurred(f"因子计算失败: {str(e)}")
|
||||
return {
|
||||
'success': False,
|
||||
'error': str(e),
|
||||
'message': f"因子计算失败: {str(e)}"
|
||||
}
|
||||
|
||||
with open(result_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(error_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 6. 开始计算按钮
|
||||
if st.button("开始计算因子", disabled=not selected_factors):
|
||||
# 重置状态
|
||||
if 'factor_calculation_thread' in st.session_state:
|
||||
st.session_state.factor_calculation_thread = None
|
||||
if 'factor_calculation_results' in st.session_state:
|
||||
st.session_state.factor_calculation_results = None
|
||||
if 'factor_calculation_error' in st.session_state:
|
||||
@ -249,18 +227,21 @@ def render_factor_calculation() -> None:
|
||||
total_batches=len(selected_factors)
|
||||
)
|
||||
|
||||
# 启动异步线程
|
||||
thread = threading.Thread(target=run_factor_calculation_async)
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
st.session_state.factor_calculation_thread = thread
|
||||
st.session_state.factor_calculation_thread_id = thread.ident
|
||||
# 直接调用同步计算函数
|
||||
result = run_factor_calculation_sync()
|
||||
|
||||
# 显示计算中状态
|
||||
st.info("因子计算已开始,请查看侧边栏进度显示...")
|
||||
|
||||
# 强制重新运行以显示进度
|
||||
st.rerun()
|
||||
# 处理计算结果
|
||||
if result['success']:
|
||||
st.session_state.factor_calculation_results = {
|
||||
'results': result['results'],
|
||||
'factors': result['factors'],
|
||||
'date_range': result['date_range'],
|
||||
'stock_count': result['stock_count']
|
||||
}
|
||||
st.success("✅ 因子计算完成!")
|
||||
else:
|
||||
st.session_state.factor_calculation_error = result['error']
|
||||
st.error(f"❌ 因子计算失败: {result['error']}")
|
||||
|
||||
# 7. 显示计算结果
|
||||
if 'factor_calculation_results' in st.session_state and st.session_state.factor_calculation_results:
|
||||
@ -298,48 +279,7 @@ def render_factor_calculation() -> None:
|
||||
else:
|
||||
st.info("没有找到因子计算结果")
|
||||
|
||||
# 8. 检查异步线程结果
|
||||
if 'factor_calculation_thread_id' in st.session_state:
|
||||
import json
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
thread_id = st.session_state.factor_calculation_thread_id
|
||||
temp_dir = tempfile.gettempdir()
|
||||
result_file = os.path.join(temp_dir, f"factor_calculation_{thread_id}.json")
|
||||
|
||||
# 检查结果文件是否存在
|
||||
if os.path.exists(result_file):
|
||||
try:
|
||||
with open(result_file, 'r', encoding='utf-8') as f:
|
||||
result_data = json.load(f)
|
||||
|
||||
# 处理结果
|
||||
if result_data['success']:
|
||||
# 计算成功
|
||||
factor_progress.complete_calculation(result_data['message'])
|
||||
st.session_state.factor_calculation_results = {
|
||||
'results': result_data['results'],
|
||||
'factors': result_data['factors'],
|
||||
'date_range': result_data['date_range'],
|
||||
'stock_count': result_data['stock_count']
|
||||
}
|
||||
else:
|
||||
# 计算失败
|
||||
factor_progress.error_occurred(result_data['message'])
|
||||
st.session_state.factor_calculation_error = result_data['error']
|
||||
|
||||
# 清理临时文件
|
||||
os.remove(result_file)
|
||||
|
||||
# 清除线程状态
|
||||
st.session_state.factor_calculation_thread_id = None
|
||||
|
||||
# 强制重新运行以显示结果
|
||||
st.rerun()
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"处理计算结果时出错: {str(e)}")
|
||||
# 8. 移除异步线程检查逻辑(已改为同步模式)
|
||||
|
||||
# 9. 显示错误信息
|
||||
if 'factor_calculation_error' in st.session_state and st.session_state.factor_calculation_error:
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
"""股票筛选与评估视图。"""
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import threading
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@ -140,8 +138,6 @@ def render_stock_evaluation() -> None:
|
||||
# 4. 评估结果
|
||||
|
||||
# 初始化会话状态
|
||||
if 'evaluation_thread' not in st.session_state:
|
||||
st.session_state.evaluation_thread = None
|
||||
if 'evaluation_results' not in st.session_state:
|
||||
st.session_state.evaluation_results = None
|
||||
if 'evaluation_status' not in st.session_state:
|
||||
@ -151,8 +147,8 @@ def render_stock_evaluation() -> None:
|
||||
if 'progress' not in st.session_state:
|
||||
st.session_state.progress = 0
|
||||
|
||||
# 异步评估函数
|
||||
def run_evaluation_async():
|
||||
# 同步评估函数
|
||||
def run_evaluation_sync():
|
||||
try:
|
||||
st.session_state.evaluation_status = 'running'
|
||||
results = []
|
||||
@ -161,9 +157,6 @@ def render_stock_evaluation() -> None:
|
||||
st.session_state.current_factor = factor_name
|
||||
st.session_state.progress = (i / len(selected_factors)) * 100
|
||||
|
||||
# 模拟进度更新(实际计算中进度会在evaluate_factor内部更新)
|
||||
time.sleep(0.1) # 让UI有机会更新
|
||||
|
||||
performance = evaluate_factor(
|
||||
factor_name,
|
||||
start_date,
|
||||
@ -203,14 +196,8 @@ def render_stock_evaluation() -> None:
|
||||
st.session_state.evaluation_status = 'running'
|
||||
st.session_state.progress = 0
|
||||
|
||||
# 启动异步线程
|
||||
thread = threading.Thread(target=run_evaluation_async)
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
st.session_state.evaluation_thread = thread
|
||||
|
||||
# 强制重新运行以显示进度
|
||||
st.rerun()
|
||||
# 直接调用同步评估函数
|
||||
run_evaluation_sync()
|
||||
|
||||
# 显示结果
|
||||
if st.session_state.evaluation_results:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user