refactor factor calculation to synchronous mode with progress tracking
This commit is contained in:
parent
974fc90fc3
commit
db0afe9c2d
@ -332,13 +332,14 @@ def _compute_batch_factors(
|
|||||||
# 批次化数据可用性检查
|
# 批次化数据可用性检查
|
||||||
available_codes = _check_batch_data_availability(broker, ts_codes, trade_date, specs)
|
available_codes = _check_batch_data_availability(broker, ts_codes, trade_date, specs)
|
||||||
|
|
||||||
# 更新UI进度状态 - 开始处理批次(在异步线程中不直接访问factor_progress)
|
# 更新UI进度状态 - 开始处理批次
|
||||||
# if total_securities > 0:
|
if total_securities > 0:
|
||||||
# factor_progress.update_progress(
|
from app.ui.progress_state import factor_progress
|
||||||
# current_securities=processed_securities,
|
factor_progress.update_progress(
|
||||||
# current_batch=batch_index + 1,
|
current_securities=processed_securities,
|
||||||
# message=f"开始处理批次 {batch_index + 1}/{total_batches}"
|
current_batch=batch_index + 1,
|
||||||
# )
|
message=f"开始处理批次 {batch_index + 1}/{total_batches}"
|
||||||
|
)
|
||||||
|
|
||||||
for i, ts_code in enumerate(ts_codes):
|
for i, ts_code in enumerate(ts_codes):
|
||||||
try:
|
try:
|
||||||
@ -381,6 +382,7 @@ def _compute_batch_factors(
|
|||||||
if total_securities > 0:
|
if total_securities > 0:
|
||||||
current_progress = processed_securities + i + 1
|
current_progress = processed_securities + i + 1
|
||||||
progress_percentage = (current_progress / total_securities) * 100
|
progress_percentage = (current_progress / total_securities) * 100
|
||||||
|
from app.ui.progress_state import factor_progress
|
||||||
factor_progress.update_progress(
|
factor_progress.update_progress(
|
||||||
current_securities=current_progress,
|
current_securities=current_progress,
|
||||||
current_batch=batch_index + 1,
|
current_batch=batch_index + 1,
|
||||||
@ -399,6 +401,7 @@ def _compute_batch_factors(
|
|||||||
if total_securities > 0:
|
if total_securities > 0:
|
||||||
final_progress = processed_securities + len(ts_codes)
|
final_progress = processed_securities + len(ts_codes)
|
||||||
progress_percentage = (final_progress / total_securities) * 100
|
progress_percentage = (final_progress / total_securities) * 100
|
||||||
|
from app.ui.progress_state import factor_progress
|
||||||
factor_progress.update_progress(
|
factor_progress.update_progress(
|
||||||
current_securities=final_progress,
|
current_securities=final_progress,
|
||||||
current_batch=batch_index + 1,
|
current_batch=batch_index + 1,
|
||||||
|
|||||||
@ -1,8 +1,6 @@
|
|||||||
"""因子计算页面。"""
|
"""因子计算页面。"""
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
|
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
@ -155,12 +153,9 @@ def render_factor_calculation() -> None:
|
|||||||
help="如果勾选,将跳过数据库中已存在的因子计算结果"
|
help="如果勾选,将跳过数据库中已存在的因子计算结果"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 5. 异步计算函数
|
# 5. 同步计算函数
|
||||||
def run_factor_calculation_async():
|
def run_factor_calculation_sync():
|
||||||
"""异步执行因子计算"""
|
"""同步执行因子计算"""
|
||||||
# 在异步线程中避免直接访问st.session_state
|
|
||||||
# 使用全局变量或文件来传递进度信息
|
|
||||||
|
|
||||||
# 计算参数
|
# 计算参数
|
||||||
total_stocks = len(universe) if universe else len(_get_all_stocks())
|
total_stocks = len(universe) if universe else len(_get_all_stocks())
|
||||||
total_batches = len(selected_factors)
|
total_batches = len(selected_factors)
|
||||||
@ -169,6 +164,13 @@ def render_factor_calculation() -> None:
|
|||||||
# 执行因子计算
|
# 执行因子计算
|
||||||
results = []
|
results = []
|
||||||
for i, factor in enumerate(selected_factors):
|
for i, factor in enumerate(selected_factors):
|
||||||
|
# 更新批次进度
|
||||||
|
factor_progress.update_progress(
|
||||||
|
current_securities=0,
|
||||||
|
current_batch=i+1,
|
||||||
|
message=f"正在计算因子: {factor.name}"
|
||||||
|
)
|
||||||
|
|
||||||
# 计算单个交易日的因子
|
# 计算单个交易日的因子
|
||||||
current_date = start_date
|
current_date = start_date
|
||||||
while current_date <= end_date:
|
while current_date <= end_date:
|
||||||
@ -189,54 +191,30 @@ def render_factor_calculation() -> None:
|
|||||||
|
|
||||||
current_date += timedelta(days=1)
|
current_date += timedelta(days=1)
|
||||||
|
|
||||||
# 短暂暂停
|
# 计算完成
|
||||||
time.sleep(0.1)
|
factor_progress.complete_calculation(f"因子计算完成!共计算 {len(results)} 条因子记录")
|
||||||
|
|
||||||
# 计算完成,通过文件或全局变量传递结果
|
return {
|
||||||
# 这里使用简单的文件方式传递结果
|
|
||||||
import json
|
|
||||||
import tempfile
|
|
||||||
import os
|
|
||||||
|
|
||||||
# 创建临时文件存储结果
|
|
||||||
temp_dir = tempfile.gettempdir()
|
|
||||||
result_file = os.path.join(temp_dir, f"factor_calculation_{threading.get_ident()}.json")
|
|
||||||
|
|
||||||
result_data = {
|
|
||||||
'success': True,
|
'success': True,
|
||||||
'results': [r.to_dict() if hasattr(r, 'to_dict') else str(r) for r in results],
|
'results': results,
|
||||||
'factors': [f.name for f in selected_factors],
|
'factors': [f.name for f in selected_factors],
|
||||||
'date_range': f"{start_date} 至 {end_date}",
|
'date_range': f"{start_date} 至 {end_date}",
|
||||||
'stock_count': len(set(r.ts_code for r in results)) if results else 0,
|
'stock_count': len(set(r.ts_code for r in results)) if results else 0,
|
||||||
'message': f"因子计算完成!共计算 {len(results)} 条因子记录"
|
'message': f"因子计算完成!共计算 {len(results)} 条因子记录"
|
||||||
}
|
}
|
||||||
|
|
||||||
with open(result_file, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(result_data, f, ensure_ascii=False, indent=2)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 计算失败
|
# 计算失败
|
||||||
import json
|
factor_progress.error_occurred(f"因子计算失败: {str(e)}")
|
||||||
import tempfile
|
return {
|
||||||
import os
|
|
||||||
|
|
||||||
temp_dir = tempfile.gettempdir()
|
|
||||||
result_file = os.path.join(temp_dir, f"factor_calculation_{threading.get_ident()}.json")
|
|
||||||
|
|
||||||
error_data = {
|
|
||||||
'success': False,
|
'success': False,
|
||||||
'error': str(e),
|
'error': str(e),
|
||||||
'message': f"因子计算失败: {str(e)}"
|
'message': f"因子计算失败: {str(e)}"
|
||||||
}
|
}
|
||||||
|
|
||||||
with open(result_file, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(error_data, f, ensure_ascii=False, indent=2)
|
|
||||||
|
|
||||||
# 6. 开始计算按钮
|
# 6. 开始计算按钮
|
||||||
if st.button("开始计算因子", disabled=not selected_factors):
|
if st.button("开始计算因子", disabled=not selected_factors):
|
||||||
# 重置状态
|
# 重置状态
|
||||||
if 'factor_calculation_thread' in st.session_state:
|
|
||||||
st.session_state.factor_calculation_thread = None
|
|
||||||
if 'factor_calculation_results' in st.session_state:
|
if 'factor_calculation_results' in st.session_state:
|
||||||
st.session_state.factor_calculation_results = None
|
st.session_state.factor_calculation_results = None
|
||||||
if 'factor_calculation_error' in st.session_state:
|
if 'factor_calculation_error' in st.session_state:
|
||||||
@ -249,18 +227,21 @@ def render_factor_calculation() -> None:
|
|||||||
total_batches=len(selected_factors)
|
total_batches=len(selected_factors)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 启动异步线程
|
# 直接调用同步计算函数
|
||||||
thread = threading.Thread(target=run_factor_calculation_async)
|
result = run_factor_calculation_sync()
|
||||||
thread.daemon = True
|
|
||||||
thread.start()
|
|
||||||
st.session_state.factor_calculation_thread = thread
|
|
||||||
st.session_state.factor_calculation_thread_id = thread.ident
|
|
||||||
|
|
||||||
# 显示计算中状态
|
# 处理计算结果
|
||||||
st.info("因子计算已开始,请查看侧边栏进度显示...")
|
if result['success']:
|
||||||
|
st.session_state.factor_calculation_results = {
|
||||||
# 强制重新运行以显示进度
|
'results': result['results'],
|
||||||
st.rerun()
|
'factors': result['factors'],
|
||||||
|
'date_range': result['date_range'],
|
||||||
|
'stock_count': result['stock_count']
|
||||||
|
}
|
||||||
|
st.success("✅ 因子计算完成!")
|
||||||
|
else:
|
||||||
|
st.session_state.factor_calculation_error = result['error']
|
||||||
|
st.error(f"❌ 因子计算失败: {result['error']}")
|
||||||
|
|
||||||
# 7. 显示计算结果
|
# 7. 显示计算结果
|
||||||
if 'factor_calculation_results' in st.session_state and st.session_state.factor_calculation_results:
|
if 'factor_calculation_results' in st.session_state and st.session_state.factor_calculation_results:
|
||||||
@ -298,48 +279,7 @@ def render_factor_calculation() -> None:
|
|||||||
else:
|
else:
|
||||||
st.info("没有找到因子计算结果")
|
st.info("没有找到因子计算结果")
|
||||||
|
|
||||||
# 8. 检查异步线程结果
|
# 8. 移除异步线程检查逻辑(已改为同步模式)
|
||||||
if 'factor_calculation_thread_id' in st.session_state:
|
|
||||||
import json
|
|
||||||
import tempfile
|
|
||||||
import os
|
|
||||||
|
|
||||||
thread_id = st.session_state.factor_calculation_thread_id
|
|
||||||
temp_dir = tempfile.gettempdir()
|
|
||||||
result_file = os.path.join(temp_dir, f"factor_calculation_{thread_id}.json")
|
|
||||||
|
|
||||||
# 检查结果文件是否存在
|
|
||||||
if os.path.exists(result_file):
|
|
||||||
try:
|
|
||||||
with open(result_file, 'r', encoding='utf-8') as f:
|
|
||||||
result_data = json.load(f)
|
|
||||||
|
|
||||||
# 处理结果
|
|
||||||
if result_data['success']:
|
|
||||||
# 计算成功
|
|
||||||
factor_progress.complete_calculation(result_data['message'])
|
|
||||||
st.session_state.factor_calculation_results = {
|
|
||||||
'results': result_data['results'],
|
|
||||||
'factors': result_data['factors'],
|
|
||||||
'date_range': result_data['date_range'],
|
|
||||||
'stock_count': result_data['stock_count']
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
# 计算失败
|
|
||||||
factor_progress.error_occurred(result_data['message'])
|
|
||||||
st.session_state.factor_calculation_error = result_data['error']
|
|
||||||
|
|
||||||
# 清理临时文件
|
|
||||||
os.remove(result_file)
|
|
||||||
|
|
||||||
# 清除线程状态
|
|
||||||
st.session_state.factor_calculation_thread_id = None
|
|
||||||
|
|
||||||
# 强制重新运行以显示结果
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
st.error(f"处理计算结果时出错: {str(e)}")
|
|
||||||
|
|
||||||
# 9. 显示错误信息
|
# 9. 显示错误信息
|
||||||
if 'factor_calculation_error' in st.session_state and st.session_state.factor_calculation_error:
|
if 'factor_calculation_error' in st.session_state and st.session_state.factor_calculation_error:
|
||||||
|
|||||||
@ -1,8 +1,6 @@
|
|||||||
"""股票筛选与评估视图。"""
|
"""股票筛选与评估视图。"""
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@ -140,8 +138,6 @@ def render_stock_evaluation() -> None:
|
|||||||
# 4. 评估结果
|
# 4. 评估结果
|
||||||
|
|
||||||
# 初始化会话状态
|
# 初始化会话状态
|
||||||
if 'evaluation_thread' not in st.session_state:
|
|
||||||
st.session_state.evaluation_thread = None
|
|
||||||
if 'evaluation_results' not in st.session_state:
|
if 'evaluation_results' not in st.session_state:
|
||||||
st.session_state.evaluation_results = None
|
st.session_state.evaluation_results = None
|
||||||
if 'evaluation_status' not in st.session_state:
|
if 'evaluation_status' not in st.session_state:
|
||||||
@ -151,8 +147,8 @@ def render_stock_evaluation() -> None:
|
|||||||
if 'progress' not in st.session_state:
|
if 'progress' not in st.session_state:
|
||||||
st.session_state.progress = 0
|
st.session_state.progress = 0
|
||||||
|
|
||||||
# 异步评估函数
|
# 同步评估函数
|
||||||
def run_evaluation_async():
|
def run_evaluation_sync():
|
||||||
try:
|
try:
|
||||||
st.session_state.evaluation_status = 'running'
|
st.session_state.evaluation_status = 'running'
|
||||||
results = []
|
results = []
|
||||||
@ -161,9 +157,6 @@ def render_stock_evaluation() -> None:
|
|||||||
st.session_state.current_factor = factor_name
|
st.session_state.current_factor = factor_name
|
||||||
st.session_state.progress = (i / len(selected_factors)) * 100
|
st.session_state.progress = (i / len(selected_factors)) * 100
|
||||||
|
|
||||||
# 模拟进度更新(实际计算中进度会在evaluate_factor内部更新)
|
|
||||||
time.sleep(0.1) # 让UI有机会更新
|
|
||||||
|
|
||||||
performance = evaluate_factor(
|
performance = evaluate_factor(
|
||||||
factor_name,
|
factor_name,
|
||||||
start_date,
|
start_date,
|
||||||
@ -203,14 +196,8 @@ def render_stock_evaluation() -> None:
|
|||||||
st.session_state.evaluation_status = 'running'
|
st.session_state.evaluation_status = 'running'
|
||||||
st.session_state.progress = 0
|
st.session_state.progress = 0
|
||||||
|
|
||||||
# 启动异步线程
|
# 直接调用同步评估函数
|
||||||
thread = threading.Thread(target=run_evaluation_async)
|
run_evaluation_sync()
|
||||||
thread.daemon = True
|
|
||||||
thread.start()
|
|
||||||
st.session_state.evaluation_thread = thread
|
|
||||||
|
|
||||||
# 强制重新运行以显示进度
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
# 显示结果
|
# 显示结果
|
||||||
if st.session_state.evaluation_results:
|
if st.session_state.evaluation_results:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user