diff --git a/app/features/evaluation.py b/app/features/evaluation.py index 485743a..0a46121 100644 --- a/app/features/evaluation.py +++ b/app/features/evaluation.py @@ -83,11 +83,11 @@ def evaluate_factor( # 导入进度状态模块 from app.ui.progress_state import factor_progress - # 开始因子计算进度 - factor_progress.start_calculation( - total_securities=len(universe) if universe else 0, - message=f"开始评估因子 {factor_name}" - ) + # 开始因子计算进度(在异步线程中不直接访问factor_progress) + # factor_progress.start_calculation( + # total_securities=len(universe) if universe else 0, + # message=f"开始评估因子 {factor_name}" + # ) try: # 计算因子值 @@ -98,17 +98,17 @@ def evaluate_factor( ts_codes=universe ) - # 因子计算完成 - factor_progress.complete_calculation( - message=f"因子 {factor_name} 评估完成" - ) + # 因子计算完成(在异步线程中不直接访问factor_progress) + # factor_progress.complete_calculation( + # message=f"因子 {factor_name} 评估完成" + # ) except Exception as e: - # 因子计算失败 - factor_progress.complete_calculation( - message=f"因子 {factor_name} 评估失败: {str(e)}", - success=False - ) + # 因子计算失败(在异步线程中不直接访问factor_progress) + # factor_progress.complete_calculation( + # message=f"因子 {factor_name} 评估失败: {str(e)}", + # success=False + # ) raise # 按日期分组 diff --git a/app/features/factors.py b/app/features/factors.py index b3a5c79..8695bdc 100644 --- a/app/features/factors.py +++ b/app/features/factors.py @@ -165,11 +165,11 @@ def compute_factors( rows_to_persist: List[tuple[str, Dict[str, float | None]]] = [] try: - # 启动UI进度状态 - factor_progress.start_calculation( - total_securities=len(universe), - total_batches=(len(universe) + batch_size - 1) // batch_size - ) + # 启动UI进度状态(在异步线程中不直接访问factor_progress) + # factor_progress.start_calculation( + # total_securities=len(universe), + # total_batches=(len(universe) + batch_size - 1) // batch_size + # ) # 分批处理以优化性能 for i in range(0, len(universe), batch_size): @@ -332,13 +332,13 @@ def _compute_batch_factors( # 批次化数据可用性检查 available_codes = _check_batch_data_availability(broker, ts_codes, trade_date, specs) - # 更新UI进度状态 - 开始处理批次 - if total_securities > 0: - factor_progress.update_progress( - current_securities=processed_securities, - current_batch=batch_index + 1, - message=f"开始处理批次 {batch_index + 1}/{total_batches}" - ) + # 更新UI进度状态 - 开始处理批次(在异步线程中不直接访问factor_progress) + # if total_securities > 0: + # factor_progress.update_progress( + # current_securities=processed_securities, + # current_batch=batch_index + 1, + # message=f"开始处理批次 {batch_index + 1}/{total_batches}" + # ) for i, ts_code in enumerate(ts_codes): try: diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py index 1cf7bd8..b5cd783 100644 --- a/app/ui/streamlit_app.py +++ b/app/ui/streamlit_app.py @@ -28,6 +28,7 @@ from app.ui.views import ( render_stock_evaluation, render_tests, render_today_plan, + render_factor_calculation, ) from app.utils.config import get_config @@ -79,11 +80,13 @@ def main() -> None: with tabs[1]: render_pool_overview() with tabs[2]: - backtest_tabs = st.tabs(["回测复盘", "股票评估"]) + backtest_tabs = st.tabs(["回测复盘", "股票评估", "因子计算"]) with backtest_tabs[0]: render_backtest_review() with backtest_tabs[1]: render_stock_evaluation() + with backtest_tabs[2]: + render_factor_calculation() with tabs[3]: render_market_visualization() with tabs[4]: diff --git a/app/ui/views/__init__.py b/app/ui/views/__init__.py index df3a5d0..4a87de9 100644 --- a/app/ui/views/__init__.py +++ b/app/ui/views/__init__.py @@ -9,6 +9,7 @@ from .settings import render_config_overview, render_llm_settings, render_data_s from .tests import render_tests from .dashboard import render_global_dashboard, update_dashboard_sidebar from .stock_eval import render_stock_evaluation +from .factor_calculation import render_factor_calculation __all__ = [ "render_today_plan", @@ -23,4 +24,5 @@ __all__ = [ "render_global_dashboard", "update_dashboard_sidebar", "render_stock_evaluation", + "render_factor_calculation", ] diff --git a/app/ui/views/factor_calculation.py b/app/ui/views/factor_calculation.py new file mode 100644 index 0000000..5f08e8c --- /dev/null +++ b/app/ui/views/factor_calculation.py @@ -0,0 +1,346 @@ +"""因子计算页面。""" +from datetime import datetime, timedelta +from typing import List, Optional +import threading +import time + +import streamlit as st + +from app.features.factors import compute_factors, DEFAULT_FACTORS, FactorSpec +from app.ui.progress_state import factor_progress +from app.utils.data_access import DataBroker +from app.utils.db import db_session + + +def _get_latest_trading_date() -> datetime.date: + """获取数据库中的最新交易日期""" + with db_session() as session: + result = session.execute( + """ + SELECT trade_date + FROM daily_basic + WHERE trade_date <= :today + GROUP BY trade_date + ORDER BY trade_date DESC + LIMIT 1 + """, + {"today": datetime.now().strftime("%Y%m%d")} + ).fetchone() + + if result and result[0]: + return datetime.strptime(str(result[0]), "%Y%m%d").date() + return datetime.now().date() - timedelta(days=1) + + +def _get_all_stocks() -> List[str]: + """获取所有股票代码""" + try: + # 从daily表获取所有股票代码 + with db_session() as session: + latest_date = _get_latest_trading_date() + result = session.execute( + """ + SELECT DISTINCT ts_code + FROM daily + WHERE trade_date = :trade_date + """, + {"trade_date": latest_date.strftime("%Y%m%d")} + ).fetchall() + + return [row[0] for row in result] if result else [] + except Exception as e: + st.error(f"获取股票列表失败: {str(e)}") + return [] + + +def render_factor_calculation() -> None: + """渲染因子计算页面。""" + st.subheader("📊 因子计算") + st.caption("计算指定日期范围的因子值") + + # 1. 时间范围选择 + col1, col2 = st.columns(2) + with col1: + latest_date = _get_latest_trading_date() + end_date = st.date_input( + "计算截止日期", + value=latest_date, + help="选择因子计算的截止日期" + ) + with col2: + lookback_days = st.slider( + "回溯天数", + min_value=1, + max_value=365, + value=30, + step=1, + help="选择计算的历史数据长度" + ) + start_date = end_date - timedelta(days=lookback_days) + + st.info(f"计算范围: {start_date} 至 {end_date} (共{lookback_days}天)") + + # 2. 因子选择 + st.markdown("##### 选择要计算的因子") + + # 按因子类型分组 + factor_groups = { + "动量类因子": [f for f in DEFAULT_FACTORS if f.name.startswith("mom_")], + "波动率类因子": [f for f in DEFAULT_FACTORS if f.name.startswith("volat_")], + "换手率类因子": [f for f in DEFAULT_FACTORS if f.name.startswith("turn_")], + "估值类因子": [f for f in DEFAULT_FACTORS if f.name.startswith("val_")], + "量价类因子": [f for f in DEFAULT_FACTORS if f.name.startswith("volume_")], + "市场类因子": [f for f in DEFAULT_FACTORS if f.name.startswith("market_")], + "其他因子": [f for f in DEFAULT_FACTORS if not any(f.name.startswith(prefix) + for prefix in ["mom_", "volat_", "turn_", "val_", "volume_", "market_"])] + } + + selected_factors = [] + for group_name, factors in factor_groups.items(): + if factors: + st.markdown(f"###### {group_name}") + cols = st.columns(3) + for i, factor in enumerate(factors): + if cols[i % 3].checkbox( + factor.name, + value=True, # 默认全选 + help=factor.description if hasattr(factor, 'description') else None, + key=f"factor_checkbox_{factor.name}_{group_name}" # 添加唯一key + ): + selected_factors.append(factor) + + if not selected_factors: + st.warning("请至少选择一个因子进行计算") + return + + # 3. 股票池选择 + st.markdown("##### 股票池范围") + pool_type = st.radio( + "选择股票池", + ["全部A股", "沪深300", "中证500", "中证1000", "自定义"], + index=0, + horizontal=True + ) + + universe: Optional[List[str]] = None + if pool_type != "全部A股": + broker = DataBroker() + if pool_type == "自定义": + custom_codes = st.text_area( + "输入股票代码列表(每行一个)", + help="请输入股票代码,每行一个,例如: 000001.SZ" + ) + if custom_codes: + universe = [ + code.strip() + for code in custom_codes.split("\n") + if code.strip() + ] + else: + index_code = { + "沪深300": "000300.SH", + "中证500": "000905.SH", + "中证1000": "000852.SH" + }[pool_type] + universe = broker.get_index_stocks( + index_code, + end_date.strftime("%Y%m%d") + ) + + # 4. 计算选项 + st.markdown("##### 计算选项") + skip_existing = st.checkbox( + "跳过已计算的因子", + value=True, + help="如果勾选,将跳过数据库中已存在的因子计算结果" + ) + + # 5. 异步计算函数 + def run_factor_calculation_async(): + """异步执行因子计算""" + # 在异步线程中避免直接访问st.session_state + # 使用全局变量或文件来传递进度信息 + + # 计算参数 + total_stocks = len(universe) if universe else len(_get_all_stocks()) + total_batches = len(selected_factors) + + try: + # 执行因子计算 + results = [] + for i, factor in enumerate(selected_factors): + # 计算单个交易日的因子 + current_date = start_date + while current_date <= end_date: + try: + # 计算指定日期的因子 + daily_results = compute_factors( + current_date, + [factor], + ts_codes=universe, + skip_existing=skip_existing + ) + results.extend(daily_results) + + except Exception as e: + # 记录错误但不中断计算 + error_msg = f"计算因子 {factor.name} 在日期 {current_date} 时出错: {str(e)}" + print(f"ERROR: {error_msg}") + + current_date += timedelta(days=1) + + # 短暂暂停 + time.sleep(0.1) + + # 计算完成,通过文件或全局变量传递结果 + # 这里使用简单的文件方式传递结果 + import json + import tempfile + import os + + # 创建临时文件存储结果 + temp_dir = tempfile.gettempdir() + result_file = os.path.join(temp_dir, f"factor_calculation_{threading.get_ident()}.json") + + result_data = { + 'success': True, + 'results': [r.to_dict() if hasattr(r, 'to_dict') else str(r) for r in results], + 'factors': [f.name for f in selected_factors], + 'date_range': f"{start_date} 至 {end_date}", + 'stock_count': len(set(r.ts_code for r in results)) if results else 0, + 'message': f"因子计算完成!共计算 {len(results)} 条因子记录" + } + + with open(result_file, 'w', encoding='utf-8') as f: + json.dump(result_data, f, ensure_ascii=False, indent=2) + + except Exception as e: + # 计算失败 + import json + import tempfile + import os + + temp_dir = tempfile.gettempdir() + result_file = os.path.join(temp_dir, f"factor_calculation_{threading.get_ident()}.json") + + error_data = { + 'success': False, + 'error': str(e), + 'message': f"因子计算失败: {str(e)}" + } + + with open(result_file, 'w', encoding='utf-8') as f: + json.dump(error_data, f, ensure_ascii=False, indent=2) + + # 6. 开始计算按钮 + if st.button("开始计算因子", disabled=not selected_factors): + # 重置状态 + if 'factor_calculation_thread' in st.session_state: + st.session_state.factor_calculation_thread = None + if 'factor_calculation_results' in st.session_state: + st.session_state.factor_calculation_results = None + if 'factor_calculation_error' in st.session_state: + st.session_state.factor_calculation_error = None + + # 初始化进度状态 + total_stocks = len(universe) if universe else len(_get_all_stocks()) + factor_progress.start_calculation( + total_securities=total_stocks, + total_batches=len(selected_factors) + ) + + # 启动异步线程 + thread = threading.Thread(target=run_factor_calculation_async) + thread.daemon = True + thread.start() + st.session_state.factor_calculation_thread = thread + st.session_state.factor_calculation_thread_id = thread.ident + + # 显示计算中状态 + st.info("因子计算已开始,请查看侧边栏进度显示...") + + # 强制重新运行以显示进度 + st.rerun() + + # 7. 显示计算结果 + if 'factor_calculation_results' in st.session_state and st.session_state.factor_calculation_results: + results = st.session_state.factor_calculation_results + + st.success("✅ 因子计算完成!") + + col1, col2, col3 = st.columns(3) + with col1: + st.metric("计算因子数量", len(results['factors'])) + with col2: + st.metric("涉及股票数量", results['stock_count']) + with col3: + st.metric("计算时间范围", results['date_range']) + + # 显示计算详情 + with st.expander("查看计算详情"): + if results['results']: + # 转换为DataFrame显示 + import pandas as pd + df_data = [] + for result in results['results']: + for factor_name, value in result.values.items(): + df_data.append({ + '日期': result.trade_date, + '股票代码': result.ts_code, + '因子名称': factor_name, + '因子值': value + }) + + if df_data: + df = pd.DataFrame(df_data) + st.dataframe(df.head(100), use_container_width=True) # 只显示前100条 + st.info(f"共 {len(df_data)} 条因子记录(显示前100条)") + else: + st.info("没有找到因子计算结果") + + # 8. 检查异步线程结果 + if 'factor_calculation_thread_id' in st.session_state: + import json + import tempfile + import os + + thread_id = st.session_state.factor_calculation_thread_id + temp_dir = tempfile.gettempdir() + result_file = os.path.join(temp_dir, f"factor_calculation_{thread_id}.json") + + # 检查结果文件是否存在 + if os.path.exists(result_file): + try: + with open(result_file, 'r', encoding='utf-8') as f: + result_data = json.load(f) + + # 处理结果 + if result_data['success']: + # 计算成功 + factor_progress.complete_calculation(result_data['message']) + st.session_state.factor_calculation_results = { + 'results': result_data['results'], + 'factors': result_data['factors'], + 'date_range': result_data['date_range'], + 'stock_count': result_data['stock_count'] + } + else: + # 计算失败 + factor_progress.error_occurred(result_data['message']) + st.session_state.factor_calculation_error = result_data['error'] + + # 清理临时文件 + os.remove(result_file) + + # 清除线程状态 + st.session_state.factor_calculation_thread_id = None + + # 强制重新运行以显示结果 + st.rerun() + + except Exception as e: + st.error(f"处理计算结果时出错: {str(e)}") + + # 9. 显示错误信息 + if 'factor_calculation_error' in st.session_state and st.session_state.factor_calculation_error: + st.error(f"❌ 因子计算失败: {st.session_state.factor_calculation_error}") \ No newline at end of file