From db0afe9c2da4290cfa8d82b616e5d2e70b4b330b Mon Sep 17 00:00:00 2001 From: sam Date: Wed, 8 Oct 2025 21:37:40 +0800 Subject: [PATCH] refactor factor calculation to synchronous mode with progress tracking --- app/features/factors.py | 17 ++-- app/ui/views/factor_calculation.py | 122 ++++++++--------------------- app/ui/views/stock_eval.py | 21 +---- 3 files changed, 45 insertions(+), 115 deletions(-) diff --git a/app/features/factors.py b/app/features/factors.py index e8d804b..a340fbd 100644 --- a/app/features/factors.py +++ b/app/features/factors.py @@ -332,13 +332,14 @@ def _compute_batch_factors( # 批次化数据可用性检查 available_codes = _check_batch_data_availability(broker, ts_codes, trade_date, specs) - # 更新UI进度状态 - 开始处理批次(在异步线程中不直接访问factor_progress) - # if total_securities > 0: - # factor_progress.update_progress( - # current_securities=processed_securities, - # current_batch=batch_index + 1, - # message=f"开始处理批次 {batch_index + 1}/{total_batches}" - # ) + # 更新UI进度状态 - 开始处理批次 + if total_securities > 0: + from app.ui.progress_state import factor_progress + factor_progress.update_progress( + current_securities=processed_securities, + current_batch=batch_index + 1, + message=f"开始处理批次 {batch_index + 1}/{total_batches}" + ) for i, ts_code in enumerate(ts_codes): try: @@ -381,6 +382,7 @@ def _compute_batch_factors( if total_securities > 0: current_progress = processed_securities + i + 1 progress_percentage = (current_progress / total_securities) * 100 + from app.ui.progress_state import factor_progress factor_progress.update_progress( current_securities=current_progress, current_batch=batch_index + 1, @@ -399,6 +401,7 @@ def _compute_batch_factors( if total_securities > 0: final_progress = processed_securities + len(ts_codes) progress_percentage = (final_progress / total_securities) * 100 + from app.ui.progress_state import factor_progress factor_progress.update_progress( current_securities=final_progress, current_batch=batch_index + 1, diff --git a/app/ui/views/factor_calculation.py b/app/ui/views/factor_calculation.py index 5f08e8c..0c30482 100644 --- a/app/ui/views/factor_calculation.py +++ b/app/ui/views/factor_calculation.py @@ -1,8 +1,6 @@ """因子计算页面。""" from datetime import datetime, timedelta from typing import List, Optional -import threading -import time import streamlit as st @@ -155,12 +153,9 @@ def render_factor_calculation() -> None: help="如果勾选,将跳过数据库中已存在的因子计算结果" ) - # 5. 异步计算函数 - def run_factor_calculation_async(): - """异步执行因子计算""" - # 在异步线程中避免直接访问st.session_state - # 使用全局变量或文件来传递进度信息 - + # 5. 同步计算函数 + def run_factor_calculation_sync(): + """同步执行因子计算""" # 计算参数 total_stocks = len(universe) if universe else len(_get_all_stocks()) total_batches = len(selected_factors) @@ -169,6 +164,13 @@ def render_factor_calculation() -> None: # 执行因子计算 results = [] for i, factor in enumerate(selected_factors): + # 更新批次进度 + factor_progress.update_progress( + current_securities=0, + current_batch=i+1, + message=f"正在计算因子: {factor.name}" + ) + # 计算单个交易日的因子 current_date = start_date while current_date <= end_date: @@ -188,55 +190,31 @@ def render_factor_calculation() -> None: print(f"ERROR: {error_msg}") current_date += timedelta(days=1) - - # 短暂暂停 - time.sleep(0.1) - # 计算完成,通过文件或全局变量传递结果 - # 这里使用简单的文件方式传递结果 - import json - import tempfile - import os + # 计算完成 + factor_progress.complete_calculation(f"因子计算完成!共计算 {len(results)} 条因子记录") - # 创建临时文件存储结果 - temp_dir = tempfile.gettempdir() - result_file = os.path.join(temp_dir, f"factor_calculation_{threading.get_ident()}.json") - - result_data = { + return { 'success': True, - 'results': [r.to_dict() if hasattr(r, 'to_dict') else str(r) for r in results], + 'results': results, 'factors': [f.name for f in selected_factors], 'date_range': f"{start_date} 至 {end_date}", 'stock_count': len(set(r.ts_code for r in results)) if results else 0, 'message': f"因子计算完成!共计算 {len(results)} 条因子记录" } - - with open(result_file, 'w', encoding='utf-8') as f: - json.dump(result_data, f, ensure_ascii=False, indent=2) except Exception as e: # 计算失败 - import json - import tempfile - import os - - temp_dir = tempfile.gettempdir() - result_file = os.path.join(temp_dir, f"factor_calculation_{threading.get_ident()}.json") - - error_data = { + factor_progress.error_occurred(f"因子计算失败: {str(e)}") + return { 'success': False, 'error': str(e), 'message': f"因子计算失败: {str(e)}" } - - with open(result_file, 'w', encoding='utf-8') as f: - json.dump(error_data, f, ensure_ascii=False, indent=2) # 6. 开始计算按钮 if st.button("开始计算因子", disabled=not selected_factors): # 重置状态 - if 'factor_calculation_thread' in st.session_state: - st.session_state.factor_calculation_thread = None if 'factor_calculation_results' in st.session_state: st.session_state.factor_calculation_results = None if 'factor_calculation_error' in st.session_state: @@ -249,18 +227,21 @@ def render_factor_calculation() -> None: total_batches=len(selected_factors) ) - # 启动异步线程 - thread = threading.Thread(target=run_factor_calculation_async) - thread.daemon = True - thread.start() - st.session_state.factor_calculation_thread = thread - st.session_state.factor_calculation_thread_id = thread.ident + # 直接调用同步计算函数 + result = run_factor_calculation_sync() - # 显示计算中状态 - st.info("因子计算已开始,请查看侧边栏进度显示...") - - # 强制重新运行以显示进度 - st.rerun() + # 处理计算结果 + if result['success']: + st.session_state.factor_calculation_results = { + 'results': result['results'], + 'factors': result['factors'], + 'date_range': result['date_range'], + 'stock_count': result['stock_count'] + } + st.success("✅ 因子计算完成!") + else: + st.session_state.factor_calculation_error = result['error'] + st.error(f"❌ 因子计算失败: {result['error']}") # 7. 显示计算结果 if 'factor_calculation_results' in st.session_state and st.session_state.factor_calculation_results: @@ -298,48 +279,7 @@ def render_factor_calculation() -> None: else: st.info("没有找到因子计算结果") - # 8. 检查异步线程结果 - if 'factor_calculation_thread_id' in st.session_state: - import json - import tempfile - import os - - thread_id = st.session_state.factor_calculation_thread_id - temp_dir = tempfile.gettempdir() - result_file = os.path.join(temp_dir, f"factor_calculation_{thread_id}.json") - - # 检查结果文件是否存在 - if os.path.exists(result_file): - try: - with open(result_file, 'r', encoding='utf-8') as f: - result_data = json.load(f) - - # 处理结果 - if result_data['success']: - # 计算成功 - factor_progress.complete_calculation(result_data['message']) - st.session_state.factor_calculation_results = { - 'results': result_data['results'], - 'factors': result_data['factors'], - 'date_range': result_data['date_range'], - 'stock_count': result_data['stock_count'] - } - else: - # 计算失败 - factor_progress.error_occurred(result_data['message']) - st.session_state.factor_calculation_error = result_data['error'] - - # 清理临时文件 - os.remove(result_file) - - # 清除线程状态 - st.session_state.factor_calculation_thread_id = None - - # 强制重新运行以显示结果 - st.rerun() - - except Exception as e: - st.error(f"处理计算结果时出错: {str(e)}") + # 8. 移除异步线程检查逻辑(已改为同步模式) # 9. 显示错误信息 if 'factor_calculation_error' in st.session_state and st.session_state.factor_calculation_error: diff --git a/app/ui/views/stock_eval.py b/app/ui/views/stock_eval.py index 24f9784..e681bea 100644 --- a/app/ui/views/stock_eval.py +++ b/app/ui/views/stock_eval.py @@ -1,8 +1,6 @@ """股票筛选与评估视图。""" from datetime import datetime, timedelta from typing import Dict, List, Optional, Tuple -import threading -import time import numpy as np import pandas as pd @@ -140,8 +138,6 @@ def render_stock_evaluation() -> None: # 4. 评估结果 # 初始化会话状态 - if 'evaluation_thread' not in st.session_state: - st.session_state.evaluation_thread = None if 'evaluation_results' not in st.session_state: st.session_state.evaluation_results = None if 'evaluation_status' not in st.session_state: @@ -151,8 +147,8 @@ def render_stock_evaluation() -> None: if 'progress' not in st.session_state: st.session_state.progress = 0 - # 异步评估函数 - def run_evaluation_async(): + # 同步评估函数 + def run_evaluation_sync(): try: st.session_state.evaluation_status = 'running' results = [] @@ -161,9 +157,6 @@ def render_stock_evaluation() -> None: st.session_state.current_factor = factor_name st.session_state.progress = (i / len(selected_factors)) * 100 - # 模拟进度更新(实际计算中进度会在evaluate_factor内部更新) - time.sleep(0.1) # 让UI有机会更新 - performance = evaluate_factor( factor_name, start_date, @@ -203,14 +196,8 @@ def render_stock_evaluation() -> None: st.session_state.evaluation_status = 'running' st.session_state.progress = 0 - # 启动异步线程 - thread = threading.Thread(target=run_evaluation_async) - thread.daemon = True - thread.start() - st.session_state.evaluation_thread = thread - - # 强制重新运行以显示进度 - st.rerun() + # 直接调用同步评估函数 + run_evaluation_sync() # 显示结果 if st.session_state.evaluation_results: