disable direct progress updates in async threads and add factor calculation view
This commit is contained in:
parent
07c76d7674
commit
85e7483286
@ -83,11 +83,11 @@ def evaluate_factor(
|
|||||||
# 导入进度状态模块
|
# 导入进度状态模块
|
||||||
from app.ui.progress_state import factor_progress
|
from app.ui.progress_state import factor_progress
|
||||||
|
|
||||||
# 开始因子计算进度
|
# 开始因子计算进度(在异步线程中不直接访问factor_progress)
|
||||||
factor_progress.start_calculation(
|
# factor_progress.start_calculation(
|
||||||
total_securities=len(universe) if universe else 0,
|
# total_securities=len(universe) if universe else 0,
|
||||||
message=f"开始评估因子 {factor_name}"
|
# message=f"开始评估因子 {factor_name}"
|
||||||
)
|
# )
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 计算因子值
|
# 计算因子值
|
||||||
@ -98,17 +98,17 @@ def evaluate_factor(
|
|||||||
ts_codes=universe
|
ts_codes=universe
|
||||||
)
|
)
|
||||||
|
|
||||||
# 因子计算完成
|
# 因子计算完成(在异步线程中不直接访问factor_progress)
|
||||||
factor_progress.complete_calculation(
|
# factor_progress.complete_calculation(
|
||||||
message=f"因子 {factor_name} 评估完成"
|
# message=f"因子 {factor_name} 评估完成"
|
||||||
)
|
# )
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 因子计算失败
|
# 因子计算失败(在异步线程中不直接访问factor_progress)
|
||||||
factor_progress.complete_calculation(
|
# factor_progress.complete_calculation(
|
||||||
message=f"因子 {factor_name} 评估失败: {str(e)}",
|
# message=f"因子 {factor_name} 评估失败: {str(e)}",
|
||||||
success=False
|
# success=False
|
||||||
)
|
# )
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# 按日期分组
|
# 按日期分组
|
||||||
|
|||||||
@ -165,11 +165,11 @@ def compute_factors(
|
|||||||
rows_to_persist: List[tuple[str, Dict[str, float | None]]] = []
|
rows_to_persist: List[tuple[str, Dict[str, float | None]]] = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 启动UI进度状态
|
# 启动UI进度状态(在异步线程中不直接访问factor_progress)
|
||||||
factor_progress.start_calculation(
|
# factor_progress.start_calculation(
|
||||||
total_securities=len(universe),
|
# total_securities=len(universe),
|
||||||
total_batches=(len(universe) + batch_size - 1) // batch_size
|
# total_batches=(len(universe) + batch_size - 1) // batch_size
|
||||||
)
|
# )
|
||||||
|
|
||||||
# 分批处理以优化性能
|
# 分批处理以优化性能
|
||||||
for i in range(0, len(universe), batch_size):
|
for i in range(0, len(universe), batch_size):
|
||||||
@ -332,13 +332,13 @@ def _compute_batch_factors(
|
|||||||
# 批次化数据可用性检查
|
# 批次化数据可用性检查
|
||||||
available_codes = _check_batch_data_availability(broker, ts_codes, trade_date, specs)
|
available_codes = _check_batch_data_availability(broker, ts_codes, trade_date, specs)
|
||||||
|
|
||||||
# 更新UI进度状态 - 开始处理批次
|
# 更新UI进度状态 - 开始处理批次(在异步线程中不直接访问factor_progress)
|
||||||
if total_securities > 0:
|
# if total_securities > 0:
|
||||||
factor_progress.update_progress(
|
# factor_progress.update_progress(
|
||||||
current_securities=processed_securities,
|
# current_securities=processed_securities,
|
||||||
current_batch=batch_index + 1,
|
# current_batch=batch_index + 1,
|
||||||
message=f"开始处理批次 {batch_index + 1}/{total_batches}"
|
# message=f"开始处理批次 {batch_index + 1}/{total_batches}"
|
||||||
)
|
# )
|
||||||
|
|
||||||
for i, ts_code in enumerate(ts_codes):
|
for i, ts_code in enumerate(ts_codes):
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -28,6 +28,7 @@ from app.ui.views import (
|
|||||||
render_stock_evaluation,
|
render_stock_evaluation,
|
||||||
render_tests,
|
render_tests,
|
||||||
render_today_plan,
|
render_today_plan,
|
||||||
|
render_factor_calculation,
|
||||||
)
|
)
|
||||||
from app.utils.config import get_config
|
from app.utils.config import get_config
|
||||||
|
|
||||||
@ -79,11 +80,13 @@ def main() -> None:
|
|||||||
with tabs[1]:
|
with tabs[1]:
|
||||||
render_pool_overview()
|
render_pool_overview()
|
||||||
with tabs[2]:
|
with tabs[2]:
|
||||||
backtest_tabs = st.tabs(["回测复盘", "股票评估"])
|
backtest_tabs = st.tabs(["回测复盘", "股票评估", "因子计算"])
|
||||||
with backtest_tabs[0]:
|
with backtest_tabs[0]:
|
||||||
render_backtest_review()
|
render_backtest_review()
|
||||||
with backtest_tabs[1]:
|
with backtest_tabs[1]:
|
||||||
render_stock_evaluation()
|
render_stock_evaluation()
|
||||||
|
with backtest_tabs[2]:
|
||||||
|
render_factor_calculation()
|
||||||
with tabs[3]:
|
with tabs[3]:
|
||||||
render_market_visualization()
|
render_market_visualization()
|
||||||
with tabs[4]:
|
with tabs[4]:
|
||||||
|
|||||||
@ -9,6 +9,7 @@ from .settings import render_config_overview, render_llm_settings, render_data_s
|
|||||||
from .tests import render_tests
|
from .tests import render_tests
|
||||||
from .dashboard import render_global_dashboard, update_dashboard_sidebar
|
from .dashboard import render_global_dashboard, update_dashboard_sidebar
|
||||||
from .stock_eval import render_stock_evaluation
|
from .stock_eval import render_stock_evaluation
|
||||||
|
from .factor_calculation import render_factor_calculation
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"render_today_plan",
|
"render_today_plan",
|
||||||
@ -23,4 +24,5 @@ __all__ = [
|
|||||||
"render_global_dashboard",
|
"render_global_dashboard",
|
||||||
"update_dashboard_sidebar",
|
"update_dashboard_sidebar",
|
||||||
"render_stock_evaluation",
|
"render_stock_evaluation",
|
||||||
|
"render_factor_calculation",
|
||||||
]
|
]
|
||||||
|
|||||||
346
app/ui/views/factor_calculation.py
Normal file
346
app/ui/views/factor_calculation.py
Normal file
@ -0,0 +1,346 @@
|
|||||||
|
"""因子计算页面。"""
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import List, Optional
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
|
from app.features.factors import compute_factors, DEFAULT_FACTORS, FactorSpec
|
||||||
|
from app.ui.progress_state import factor_progress
|
||||||
|
from app.utils.data_access import DataBroker
|
||||||
|
from app.utils.db import db_session
|
||||||
|
|
||||||
|
|
||||||
|
def _get_latest_trading_date() -> datetime.date:
|
||||||
|
"""获取数据库中的最新交易日期"""
|
||||||
|
with db_session() as session:
|
||||||
|
result = session.execute(
|
||||||
|
"""
|
||||||
|
SELECT trade_date
|
||||||
|
FROM daily_basic
|
||||||
|
WHERE trade_date <= :today
|
||||||
|
GROUP BY trade_date
|
||||||
|
ORDER BY trade_date DESC
|
||||||
|
LIMIT 1
|
||||||
|
""",
|
||||||
|
{"today": datetime.now().strftime("%Y%m%d")}
|
||||||
|
).fetchone()
|
||||||
|
|
||||||
|
if result and result[0]:
|
||||||
|
return datetime.strptime(str(result[0]), "%Y%m%d").date()
|
||||||
|
return datetime.now().date() - timedelta(days=1)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_all_stocks() -> List[str]:
|
||||||
|
"""获取所有股票代码"""
|
||||||
|
try:
|
||||||
|
# 从daily表获取所有股票代码
|
||||||
|
with db_session() as session:
|
||||||
|
latest_date = _get_latest_trading_date()
|
||||||
|
result = session.execute(
|
||||||
|
"""
|
||||||
|
SELECT DISTINCT ts_code
|
||||||
|
FROM daily
|
||||||
|
WHERE trade_date = :trade_date
|
||||||
|
""",
|
||||||
|
{"trade_date": latest_date.strftime("%Y%m%d")}
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
return [row[0] for row in result] if result else []
|
||||||
|
except Exception as e:
|
||||||
|
st.error(f"获取股票列表失败: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def render_factor_calculation() -> None:
|
||||||
|
"""渲染因子计算页面。"""
|
||||||
|
st.subheader("📊 因子计算")
|
||||||
|
st.caption("计算指定日期范围的因子值")
|
||||||
|
|
||||||
|
# 1. 时间范围选择
|
||||||
|
col1, col2 = st.columns(2)
|
||||||
|
with col1:
|
||||||
|
latest_date = _get_latest_trading_date()
|
||||||
|
end_date = st.date_input(
|
||||||
|
"计算截止日期",
|
||||||
|
value=latest_date,
|
||||||
|
help="选择因子计算的截止日期"
|
||||||
|
)
|
||||||
|
with col2:
|
||||||
|
lookback_days = st.slider(
|
||||||
|
"回溯天数",
|
||||||
|
min_value=1,
|
||||||
|
max_value=365,
|
||||||
|
value=30,
|
||||||
|
step=1,
|
||||||
|
help="选择计算的历史数据长度"
|
||||||
|
)
|
||||||
|
start_date = end_date - timedelta(days=lookback_days)
|
||||||
|
|
||||||
|
st.info(f"计算范围: {start_date} 至 {end_date} (共{lookback_days}天)")
|
||||||
|
|
||||||
|
# 2. 因子选择
|
||||||
|
st.markdown("##### 选择要计算的因子")
|
||||||
|
|
||||||
|
# 按因子类型分组
|
||||||
|
factor_groups = {
|
||||||
|
"动量类因子": [f for f in DEFAULT_FACTORS if f.name.startswith("mom_")],
|
||||||
|
"波动率类因子": [f for f in DEFAULT_FACTORS if f.name.startswith("volat_")],
|
||||||
|
"换手率类因子": [f for f in DEFAULT_FACTORS if f.name.startswith("turn_")],
|
||||||
|
"估值类因子": [f for f in DEFAULT_FACTORS if f.name.startswith("val_")],
|
||||||
|
"量价类因子": [f for f in DEFAULT_FACTORS if f.name.startswith("volume_")],
|
||||||
|
"市场类因子": [f for f in DEFAULT_FACTORS if f.name.startswith("market_")],
|
||||||
|
"其他因子": [f for f in DEFAULT_FACTORS if not any(f.name.startswith(prefix)
|
||||||
|
for prefix in ["mom_", "volat_", "turn_", "val_", "volume_", "market_"])]
|
||||||
|
}
|
||||||
|
|
||||||
|
selected_factors = []
|
||||||
|
for group_name, factors in factor_groups.items():
|
||||||
|
if factors:
|
||||||
|
st.markdown(f"###### {group_name}")
|
||||||
|
cols = st.columns(3)
|
||||||
|
for i, factor in enumerate(factors):
|
||||||
|
if cols[i % 3].checkbox(
|
||||||
|
factor.name,
|
||||||
|
value=True, # 默认全选
|
||||||
|
help=factor.description if hasattr(factor, 'description') else None,
|
||||||
|
key=f"factor_checkbox_{factor.name}_{group_name}" # 添加唯一key
|
||||||
|
):
|
||||||
|
selected_factors.append(factor)
|
||||||
|
|
||||||
|
if not selected_factors:
|
||||||
|
st.warning("请至少选择一个因子进行计算")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 3. 股票池选择
|
||||||
|
st.markdown("##### 股票池范围")
|
||||||
|
pool_type = st.radio(
|
||||||
|
"选择股票池",
|
||||||
|
["全部A股", "沪深300", "中证500", "中证1000", "自定义"],
|
||||||
|
index=0,
|
||||||
|
horizontal=True
|
||||||
|
)
|
||||||
|
|
||||||
|
universe: Optional[List[str]] = None
|
||||||
|
if pool_type != "全部A股":
|
||||||
|
broker = DataBroker()
|
||||||
|
if pool_type == "自定义":
|
||||||
|
custom_codes = st.text_area(
|
||||||
|
"输入股票代码列表(每行一个)",
|
||||||
|
help="请输入股票代码,每行一个,例如: 000001.SZ"
|
||||||
|
)
|
||||||
|
if custom_codes:
|
||||||
|
universe = [
|
||||||
|
code.strip()
|
||||||
|
for code in custom_codes.split("\n")
|
||||||
|
if code.strip()
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
index_code = {
|
||||||
|
"沪深300": "000300.SH",
|
||||||
|
"中证500": "000905.SH",
|
||||||
|
"中证1000": "000852.SH"
|
||||||
|
}[pool_type]
|
||||||
|
universe = broker.get_index_stocks(
|
||||||
|
index_code,
|
||||||
|
end_date.strftime("%Y%m%d")
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. 计算选项
|
||||||
|
st.markdown("##### 计算选项")
|
||||||
|
skip_existing = st.checkbox(
|
||||||
|
"跳过已计算的因子",
|
||||||
|
value=True,
|
||||||
|
help="如果勾选,将跳过数据库中已存在的因子计算结果"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. 异步计算函数
|
||||||
|
def run_factor_calculation_async():
|
||||||
|
"""异步执行因子计算"""
|
||||||
|
# 在异步线程中避免直接访问st.session_state
|
||||||
|
# 使用全局变量或文件来传递进度信息
|
||||||
|
|
||||||
|
# 计算参数
|
||||||
|
total_stocks = len(universe) if universe else len(_get_all_stocks())
|
||||||
|
total_batches = len(selected_factors)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 执行因子计算
|
||||||
|
results = []
|
||||||
|
for i, factor in enumerate(selected_factors):
|
||||||
|
# 计算单个交易日的因子
|
||||||
|
current_date = start_date
|
||||||
|
while current_date <= end_date:
|
||||||
|
try:
|
||||||
|
# 计算指定日期的因子
|
||||||
|
daily_results = compute_factors(
|
||||||
|
current_date,
|
||||||
|
[factor],
|
||||||
|
ts_codes=universe,
|
||||||
|
skip_existing=skip_existing
|
||||||
|
)
|
||||||
|
results.extend(daily_results)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# 记录错误但不中断计算
|
||||||
|
error_msg = f"计算因子 {factor.name} 在日期 {current_date} 时出错: {str(e)}"
|
||||||
|
print(f"ERROR: {error_msg}")
|
||||||
|
|
||||||
|
current_date += timedelta(days=1)
|
||||||
|
|
||||||
|
# 短暂暂停
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
# 计算完成,通过文件或全局变量传递结果
|
||||||
|
# 这里使用简单的文件方式传递结果
|
||||||
|
import json
|
||||||
|
import tempfile
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 创建临时文件存储结果
|
||||||
|
temp_dir = tempfile.gettempdir()
|
||||||
|
result_file = os.path.join(temp_dir, f"factor_calculation_{threading.get_ident()}.json")
|
||||||
|
|
||||||
|
result_data = {
|
||||||
|
'success': True,
|
||||||
|
'results': [r.to_dict() if hasattr(r, 'to_dict') else str(r) for r in results],
|
||||||
|
'factors': [f.name for f in selected_factors],
|
||||||
|
'date_range': f"{start_date} 至 {end_date}",
|
||||||
|
'stock_count': len(set(r.ts_code for r in results)) if results else 0,
|
||||||
|
'message': f"因子计算完成!共计算 {len(results)} 条因子记录"
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(result_file, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(result_data, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# 计算失败
|
||||||
|
import json
|
||||||
|
import tempfile
|
||||||
|
import os
|
||||||
|
|
||||||
|
temp_dir = tempfile.gettempdir()
|
||||||
|
result_file = os.path.join(temp_dir, f"factor_calculation_{threading.get_ident()}.json")
|
||||||
|
|
||||||
|
error_data = {
|
||||||
|
'success': False,
|
||||||
|
'error': str(e),
|
||||||
|
'message': f"因子计算失败: {str(e)}"
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(result_file, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(error_data, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
# 6. 开始计算按钮
|
||||||
|
if st.button("开始计算因子", disabled=not selected_factors):
|
||||||
|
# 重置状态
|
||||||
|
if 'factor_calculation_thread' in st.session_state:
|
||||||
|
st.session_state.factor_calculation_thread = None
|
||||||
|
if 'factor_calculation_results' in st.session_state:
|
||||||
|
st.session_state.factor_calculation_results = None
|
||||||
|
if 'factor_calculation_error' in st.session_state:
|
||||||
|
st.session_state.factor_calculation_error = None
|
||||||
|
|
||||||
|
# 初始化进度状态
|
||||||
|
total_stocks = len(universe) if universe else len(_get_all_stocks())
|
||||||
|
factor_progress.start_calculation(
|
||||||
|
total_securities=total_stocks,
|
||||||
|
total_batches=len(selected_factors)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 启动异步线程
|
||||||
|
thread = threading.Thread(target=run_factor_calculation_async)
|
||||||
|
thread.daemon = True
|
||||||
|
thread.start()
|
||||||
|
st.session_state.factor_calculation_thread = thread
|
||||||
|
st.session_state.factor_calculation_thread_id = thread.ident
|
||||||
|
|
||||||
|
# 显示计算中状态
|
||||||
|
st.info("因子计算已开始,请查看侧边栏进度显示...")
|
||||||
|
|
||||||
|
# 强制重新运行以显示进度
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
# 7. 显示计算结果
|
||||||
|
if 'factor_calculation_results' in st.session_state and st.session_state.factor_calculation_results:
|
||||||
|
results = st.session_state.factor_calculation_results
|
||||||
|
|
||||||
|
st.success("✅ 因子计算完成!")
|
||||||
|
|
||||||
|
col1, col2, col3 = st.columns(3)
|
||||||
|
with col1:
|
||||||
|
st.metric("计算因子数量", len(results['factors']))
|
||||||
|
with col2:
|
||||||
|
st.metric("涉及股票数量", results['stock_count'])
|
||||||
|
with col3:
|
||||||
|
st.metric("计算时间范围", results['date_range'])
|
||||||
|
|
||||||
|
# 显示计算详情
|
||||||
|
with st.expander("查看计算详情"):
|
||||||
|
if results['results']:
|
||||||
|
# 转换为DataFrame显示
|
||||||
|
import pandas as pd
|
||||||
|
df_data = []
|
||||||
|
for result in results['results']:
|
||||||
|
for factor_name, value in result.values.items():
|
||||||
|
df_data.append({
|
||||||
|
'日期': result.trade_date,
|
||||||
|
'股票代码': result.ts_code,
|
||||||
|
'因子名称': factor_name,
|
||||||
|
'因子值': value
|
||||||
|
})
|
||||||
|
|
||||||
|
if df_data:
|
||||||
|
df = pd.DataFrame(df_data)
|
||||||
|
st.dataframe(df.head(100), use_container_width=True) # 只显示前100条
|
||||||
|
st.info(f"共 {len(df_data)} 条因子记录(显示前100条)")
|
||||||
|
else:
|
||||||
|
st.info("没有找到因子计算结果")
|
||||||
|
|
||||||
|
# 8. 检查异步线程结果
|
||||||
|
if 'factor_calculation_thread_id' in st.session_state:
|
||||||
|
import json
|
||||||
|
import tempfile
|
||||||
|
import os
|
||||||
|
|
||||||
|
thread_id = st.session_state.factor_calculation_thread_id
|
||||||
|
temp_dir = tempfile.gettempdir()
|
||||||
|
result_file = os.path.join(temp_dir, f"factor_calculation_{thread_id}.json")
|
||||||
|
|
||||||
|
# 检查结果文件是否存在
|
||||||
|
if os.path.exists(result_file):
|
||||||
|
try:
|
||||||
|
with open(result_file, 'r', encoding='utf-8') as f:
|
||||||
|
result_data = json.load(f)
|
||||||
|
|
||||||
|
# 处理结果
|
||||||
|
if result_data['success']:
|
||||||
|
# 计算成功
|
||||||
|
factor_progress.complete_calculation(result_data['message'])
|
||||||
|
st.session_state.factor_calculation_results = {
|
||||||
|
'results': result_data['results'],
|
||||||
|
'factors': result_data['factors'],
|
||||||
|
'date_range': result_data['date_range'],
|
||||||
|
'stock_count': result_data['stock_count']
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# 计算失败
|
||||||
|
factor_progress.error_occurred(result_data['message'])
|
||||||
|
st.session_state.factor_calculation_error = result_data['error']
|
||||||
|
|
||||||
|
# 清理临时文件
|
||||||
|
os.remove(result_file)
|
||||||
|
|
||||||
|
# 清除线程状态
|
||||||
|
st.session_state.factor_calculation_thread_id = None
|
||||||
|
|
||||||
|
# 强制重新运行以显示结果
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
st.error(f"处理计算结果时出错: {str(e)}")
|
||||||
|
|
||||||
|
# 9. 显示错误信息
|
||||||
|
if 'factor_calculation_error' in st.session_state and st.session_state.factor_calculation_error:
|
||||||
|
st.error(f"❌ 因子计算失败: {st.session_state.factor_calculation_error}")
|
||||||
Loading…
Reference in New Issue
Block a user