refactor factor calculation and improve progress tracking
This commit is contained in:
parent
44adc836fa
commit
43c70f3f7f
@ -151,5 +151,5 @@ Streamlit `自检测试` 页签提供:
|
|||||||
TODO
|
TODO
|
||||||
1. 在选股时,因子都已经提前算好,不需要再计算了,直接用就行。
|
1. 在选股时,因子都已经提前算好,不需要再计算了,直接用就行。
|
||||||
2. 因子计算的公式再确认下
|
2. 因子计算的公式再确认下
|
||||||
3. 审查整个项目的代码逻辑,从main.py开始,逐字逐句检查。如一些重复的检查可以去掉;未实现的功能请标记TODO,并给出实现思路;错误的、低效率的调用请修正;代码结构性的问题请指出。
|
3. 审查整个项目的代码逻辑,从app/ui/streamlit_app.py开始,逐字逐句检查。如一些重复的安全检查可以去掉;明显果实的临时性代码请删除掉;未实现的功能请标记TODO,并给出实现思路;错误的、低效率的调用请修正;代码结构性的问题请指出并尝试修正;复杂不清晰的代码结构请尝试重构;
|
||||||
4. 梳理整个项目的所有业务逻辑。针对每个业务,从业务实现角度评估代码功能是否存在问题,是否需要优化,是否需要重构。
|
4. 梳理整个项目的所有业务逻辑。针对每个业务,从业务实现角度评估代码功能是否存在问题,是否需要优化,是否需要重构。
|
||||||
|
|||||||
@ -9,7 +9,8 @@ from app.features.factors import (
|
|||||||
DEFAULT_FACTORS,
|
DEFAULT_FACTORS,
|
||||||
FactorResult,
|
FactorResult,
|
||||||
FactorSpec,
|
FactorSpec,
|
||||||
compute_factor_range
|
compute_factor_range,
|
||||||
|
lookup_factor_spec,
|
||||||
)
|
)
|
||||||
from app.utils.data_access import DataBroker
|
from app.utils.data_access import DataBroker
|
||||||
from app.utils.logging import get_logger
|
from app.utils.logging import get_logger
|
||||||
@ -90,14 +91,14 @@ def evaluate_factor(
|
|||||||
# )
|
# )
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 计算因子值
|
spec = lookup_factor_spec(factor_name) or FactorSpec(factor_name, 0)
|
||||||
# 设置 skip_existing=False,确保即使因子已存在也会重新计算
|
|
||||||
factor_results = compute_factor_range(
|
factor_results = compute_factor_range(
|
||||||
start_date,
|
start_date,
|
||||||
end_date,
|
end_date,
|
||||||
factors=[FactorSpec(factor_name, 0)],
|
factors=[spec],
|
||||||
ts_codes=universe,
|
ts_codes=universe,
|
||||||
skip_existing=False
|
skip_existing=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 因子计算完成(在异步线程中不直接访问factor_progress)
|
# 因子计算完成(在异步线程中不直接访问factor_progress)
|
||||||
|
|||||||
@ -90,6 +90,17 @@ DEFAULT_FACTORS: List[FactorSpec] = [
|
|||||||
FactorSpec("risk_penalty", 0), # 风险惩罚因子
|
FactorSpec("risk_penalty", 0), # 风险惩罚因子
|
||||||
]
|
]
|
||||||
|
|
||||||
|
_FACTOR_SPEC_MAP: Dict[str, FactorSpec] = {spec.name: spec for spec in DEFAULT_FACTORS}
|
||||||
|
|
||||||
|
|
||||||
|
def lookup_factor_spec(name: str) -> Optional[FactorSpec]:
|
||||||
|
"""Return a copy of the registered ``FactorSpec`` for ``name`` if available."""
|
||||||
|
|
||||||
|
base = _FACTOR_SPEC_MAP.get(name)
|
||||||
|
if base is None:
|
||||||
|
return None
|
||||||
|
return FactorSpec(name=base.name, window=base.window)
|
||||||
|
|
||||||
|
|
||||||
def compute_factors(
|
def compute_factors(
|
||||||
trade_date: date,
|
trade_date: date,
|
||||||
@ -304,30 +315,33 @@ def _existing_factor_codes_with_factors(trade_date: str, factor_names: List[str]
|
|||||||
if not factor_names:
|
if not factor_names:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
# 构建检查条件
|
valid_names = [
|
||||||
conditions = []
|
name
|
||||||
for name in factor_names:
|
for name in factor_names
|
||||||
conditions.append(f"json_extract(factors, '$.{name}') IS NOT NULL")
|
if isinstance(name, str) and _IDENTIFIER_RE.match(name)
|
||||||
condition_str = " AND ".join(conditions)
|
]
|
||||||
|
if not valid_names:
|
||||||
# 构建SQL查询
|
return {}
|
||||||
query = """
|
|
||||||
SELECT ts_code
|
|
||||||
FROM factors
|
|
||||||
WHERE trade_date = ?
|
|
||||||
AND """ + condition_str + """
|
|
||||||
GROUP BY ts_code
|
|
||||||
"""
|
|
||||||
|
|
||||||
with db_session(read_only=True) as conn:
|
with db_session(read_only=True) as conn:
|
||||||
|
columns = {
|
||||||
|
row["name"]
|
||||||
|
for row in conn.execute("PRAGMA table_info(factors)").fetchall()
|
||||||
|
}
|
||||||
|
selected = [name for name in valid_names if name in columns]
|
||||||
|
if not selected:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
predicates = " AND ".join(f"{col} IS NOT NULL" for col in selected)
|
||||||
|
query = (
|
||||||
|
"SELECT ts_code FROM factors "
|
||||||
|
"WHERE trade_date = ? AND "
|
||||||
|
f"{predicates} "
|
||||||
|
"GROUP BY ts_code"
|
||||||
|
)
|
||||||
rows = conn.execute(query, (trade_date,)).fetchall()
|
rows = conn.execute(query, (trade_date,)).fetchall()
|
||||||
|
|
||||||
# 返回结果
|
return {row["ts_code"]: True for row in rows if row and row["ts_code"]}
|
||||||
result = {}
|
|
||||||
for row in rows:
|
|
||||||
result[row["ts_code"]] = True
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def _list_trade_dates(
|
def _list_trade_dates(
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
|
import time
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
|
|
||||||
@ -25,7 +26,7 @@ class FactorProgressState:
|
|||||||
'status': 'idle', # idle, running, completed, error
|
'status': 'idle', # idle, running, completed, error
|
||||||
'message': '',
|
'message': '',
|
||||||
'start_time': None,
|
'start_time': None,
|
||||||
'elapsed_time': 0
|
'elapsed_time': 0.0,
|
||||||
}
|
}
|
||||||
|
|
||||||
def start_calculation(self, total_securities: int, total_batches: int) -> None:
|
def start_calculation(self, total_securities: int, total_batches: int) -> None:
|
||||||
@ -35,16 +36,17 @@ class FactorProgressState:
|
|||||||
total_securities: 总证券数量
|
total_securities: 总证券数量
|
||||||
total_batches: 总批次数
|
total_batches: 总批次数
|
||||||
"""
|
"""
|
||||||
|
now = time.time()
|
||||||
st.session_state.factor_progress.update({
|
st.session_state.factor_progress.update({
|
||||||
'current': 0,
|
'current': 0,
|
||||||
'total': total_securities,
|
'total': max(total_securities, 0),
|
||||||
'percentage': 0.0,
|
'percentage': 0.0,
|
||||||
'current_batch': 0,
|
'current_batch': 0,
|
||||||
'total_batches': total_batches,
|
'total_batches': max(total_batches, 0),
|
||||||
'status': 'running',
|
'status': 'running',
|
||||||
'message': '开始因子计算...',
|
'message': '开始因子计算...',
|
||||||
'start_time': st.session_state.get('factor_progress', {}).get('start_time'),
|
'start_time': now,
|
||||||
'elapsed_time': 0
|
'elapsed_time': 0.0,
|
||||||
})
|
})
|
||||||
|
|
||||||
def update_progress(self, current_securities: int, current_batch: int,
|
def update_progress(self, current_securities: int, current_batch: int,
|
||||||
@ -59,18 +61,26 @@ class FactorProgressState:
|
|||||||
progress = st.session_state.factor_progress
|
progress = st.session_state.factor_progress
|
||||||
|
|
||||||
# 计算百分比
|
# 计算百分比
|
||||||
if progress['total'] > 0:
|
total = progress.get('total', 0) or 0
|
||||||
percentage = (current_securities / progress['total']) * 100
|
if total > 0:
|
||||||
|
percentage = (current_securities / total) * 100
|
||||||
else:
|
else:
|
||||||
percentage = 0.0
|
percentage = 0.0
|
||||||
|
|
||||||
|
start_time = progress.get('start_time')
|
||||||
|
if isinstance(start_time, (int, float)):
|
||||||
|
elapsed = max(0.0, time.time() - start_time)
|
||||||
|
else:
|
||||||
|
elapsed = 0.0
|
||||||
|
|
||||||
# 更新状态
|
# 更新状态
|
||||||
progress.update({
|
progress.update({
|
||||||
'current': current_securities,
|
'current': current_securities,
|
||||||
'current_batch': current_batch,
|
'current_batch': current_batch,
|
||||||
'percentage': percentage,
|
'percentage': percentage,
|
||||||
'message': message or f'处理批次 {current_batch}/{progress["total_batches"]}',
|
'message': message or f'处理批次 {current_batch}/{progress["total_batches"] or 1}',
|
||||||
'status': 'running'
|
'status': 'running',
|
||||||
|
'elapsed_time': elapsed,
|
||||||
})
|
})
|
||||||
|
|
||||||
def complete_calculation(self, message: str = '因子计算完成') -> None:
|
def complete_calculation(self, message: str = '因子计算完成') -> None:
|
||||||
@ -80,11 +90,17 @@ class FactorProgressState:
|
|||||||
message: 完成消息
|
message: 完成消息
|
||||||
"""
|
"""
|
||||||
progress = st.session_state.factor_progress
|
progress = st.session_state.factor_progress
|
||||||
|
start_time = progress.get('start_time')
|
||||||
|
if isinstance(start_time, (int, float)):
|
||||||
|
elapsed = max(0.0, time.time() - start_time)
|
||||||
|
else:
|
||||||
|
elapsed = progress.get('elapsed_time', 0.0) or 0.0
|
||||||
progress.update({
|
progress.update({
|
||||||
'current': progress['total'],
|
'current': progress.get('total', 0),
|
||||||
'percentage': 100.0,
|
'percentage': 100.0 if progress.get('total', 0) else progress.get('percentage', 0.0),
|
||||||
'status': 'completed',
|
'status': 'completed',
|
||||||
'message': message
|
'message': message,
|
||||||
|
'elapsed_time': elapsed,
|
||||||
})
|
})
|
||||||
|
|
||||||
def error_occurred(self, error_message: str) -> None:
|
def error_occurred(self, error_message: str) -> None:
|
||||||
@ -93,9 +109,16 @@ class FactorProgressState:
|
|||||||
Args:
|
Args:
|
||||||
error_message: 错误消息
|
error_message: 错误消息
|
||||||
"""
|
"""
|
||||||
st.session_state.factor_progress.update({
|
progress = st.session_state.factor_progress
|
||||||
|
start_time = progress.get('start_time')
|
||||||
|
if isinstance(start_time, (int, float)):
|
||||||
|
elapsed = max(0.0, time.time() - start_time)
|
||||||
|
else:
|
||||||
|
elapsed = progress.get('elapsed_time', 0.0) or 0.0
|
||||||
|
progress.update({
|
||||||
'status': 'error',
|
'status': 'error',
|
||||||
'message': f'错误: {error_message}'
|
'message': f'错误: {error_message}',
|
||||||
|
'elapsed_time': elapsed,
|
||||||
})
|
})
|
||||||
|
|
||||||
def get_progress_info(self) -> Dict[str, Any]:
|
def get_progress_info(self) -> Dict[str, Any]:
|
||||||
@ -118,7 +141,7 @@ class FactorProgressState:
|
|||||||
'status': 'idle',
|
'status': 'idle',
|
||||||
'message': '',
|
'message': '',
|
||||||
'start_time': None,
|
'start_time': None,
|
||||||
'elapsed_time': 0
|
'elapsed_time': 0.0,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -199,4 +222,4 @@ def is_factor_calculation_running() -> bool:
|
|||||||
Returns:
|
Returns:
|
||||||
是否正在进行因子计算
|
是否正在进行因子计算
|
||||||
"""
|
"""
|
||||||
return factor_progress.get_progress_info()['status'] == 'running'
|
return factor_progress.get_progress_info()['status'] == 'running'
|
||||||
|
|||||||
@ -2,16 +2,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import uuid
|
|
||||||
from dataclasses import asdict
|
|
||||||
from datetime import date, datetime
|
from datetime import date, datetime
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import plotly.express as px
|
import plotly.express as px
|
||||||
import requests
|
|
||||||
from requests.exceptions import RequestException
|
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
from app.agents.base import AgentContext
|
from app.agents.base import AgentContext
|
||||||
|
|||||||
@ -1,19 +1,20 @@
|
|||||||
"""因子计算页面。"""
|
"""因子计算页面。"""
|
||||||
from datetime import datetime, timedelta
|
from datetime import date, datetime, timedelta
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Sequence
|
||||||
|
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
from app.features.factors import compute_factors, DEFAULT_FACTORS, FactorSpec
|
from app.features.factors import DEFAULT_FACTORS, FactorSpec, compute_factor_range
|
||||||
from app.ui.progress_state import factor_progress
|
from app.ui.progress_state import factor_progress
|
||||||
|
from app.ui.shared import LOGGER, LOG_EXTRA
|
||||||
from app.utils.data_access import DataBroker
|
from app.utils.data_access import DataBroker
|
||||||
from app.utils.db import db_session
|
from app.utils.db import db_session
|
||||||
|
|
||||||
|
|
||||||
def _get_latest_trading_date() -> datetime.date:
|
def _get_latest_trading_date() -> datetime.date:
|
||||||
"""获取数据库中的最新交易日期"""
|
"""获取数据库中的最新交易日期"""
|
||||||
with db_session() as session:
|
with db_session(read_only=True) as conn:
|
||||||
result = session.execute(
|
result = conn.execute(
|
||||||
"""
|
"""
|
||||||
SELECT trade_date
|
SELECT trade_date
|
||||||
FROM daily_basic
|
FROM daily_basic
|
||||||
@ -34,9 +35,9 @@ def _get_all_stocks() -> List[str]:
|
|||||||
"""获取所有股票代码"""
|
"""获取所有股票代码"""
|
||||||
try:
|
try:
|
||||||
# 从daily表获取所有股票代码
|
# 从daily表获取所有股票代码
|
||||||
with db_session() as session:
|
with db_session(read_only=True) as conn:
|
||||||
latest_date = _get_latest_trading_date()
|
latest_date = _get_latest_trading_date()
|
||||||
result = session.execute(
|
result = conn.execute(
|
||||||
"""
|
"""
|
||||||
SELECT DISTINCT ts_code
|
SELECT DISTINCT ts_code
|
||||||
FROM daily
|
FROM daily
|
||||||
@ -45,12 +46,88 @@ def _get_all_stocks() -> List[str]:
|
|||||||
{"trade_date": latest_date.strftime("%Y%m%d")}
|
{"trade_date": latest_date.strftime("%Y%m%d")}
|
||||||
).fetchall()
|
).fetchall()
|
||||||
|
|
||||||
return [row[0] for row in result] if result else []
|
return [row["ts_code"] for row in result if row and row["ts_code"]] if result else []
|
||||||
except Exception as e:
|
except Exception as exc:
|
||||||
st.error(f"获取股票列表失败: {str(e)}")
|
LOGGER.exception("获取股票列表失败", extra={**LOG_EXTRA, "error": str(exc)})
|
||||||
|
st.error(f"获取股票列表失败: {exc}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_universe(universe: Optional[Sequence[str]]) -> List[str]:
|
||||||
|
"""去重并规范股票代码格式。"""
|
||||||
|
if not universe:
|
||||||
|
return []
|
||||||
|
seen: dict[str, None] = {}
|
||||||
|
for code in universe:
|
||||||
|
normalized = code.strip().upper()
|
||||||
|
if normalized and normalized not in seen:
|
||||||
|
seen[normalized] = None
|
||||||
|
return list(seen.keys())
|
||||||
|
|
||||||
|
|
||||||
|
def _get_trade_dates_between(
|
||||||
|
start: date,
|
||||||
|
end: date,
|
||||||
|
universe: Optional[Sequence[str]] = None,
|
||||||
|
) -> List[date]:
|
||||||
|
"""获取区间内存在行情数据的交易日期列表。"""
|
||||||
|
|
||||||
|
if end < start:
|
||||||
|
return []
|
||||||
|
|
||||||
|
start_str = start.strftime("%Y%m%d")
|
||||||
|
end_str = end.strftime("%Y%m%d")
|
||||||
|
params: List[str] = [start_str, end_str]
|
||||||
|
query = (
|
||||||
|
"SELECT DISTINCT trade_date FROM daily "
|
||||||
|
"WHERE trade_date BETWEEN ? AND ?"
|
||||||
|
)
|
||||||
|
scoped_universe = _normalize_universe(universe)
|
||||||
|
if scoped_universe:
|
||||||
|
placeholders = ", ".join("?" for _ in scoped_universe)
|
||||||
|
query += f" AND ts_code IN ({placeholders})"
|
||||||
|
params.extend(scoped_universe)
|
||||||
|
query += " ORDER BY trade_date"
|
||||||
|
|
||||||
|
with db_session(read_only=True) as conn:
|
||||||
|
rows = conn.execute(query, params).fetchall()
|
||||||
|
|
||||||
|
return [
|
||||||
|
datetime.strptime(str(row["trade_date"]), "%Y%m%d").date()
|
||||||
|
for row in rows
|
||||||
|
if row and row["trade_date"]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _estimate_total_workload(
|
||||||
|
trade_dates: Sequence[date],
|
||||||
|
universe: Optional[Sequence[str]],
|
||||||
|
) -> int:
|
||||||
|
"""估算本次计算需要处理的证券数量,用于驱动进度条。"""
|
||||||
|
|
||||||
|
trade_days = list(trade_dates)
|
||||||
|
if not trade_days:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
scoped_universe = _normalize_universe(universe)
|
||||||
|
if scoped_universe:
|
||||||
|
return len(scoped_universe) * len(trade_days)
|
||||||
|
|
||||||
|
start_str = min(trade_days).strftime("%Y%m%d")
|
||||||
|
end_str = max(trade_days).strftime("%Y%m%d")
|
||||||
|
with db_session(read_only=True) as conn:
|
||||||
|
row = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT COUNT(DISTINCT ts_code) AS cnt
|
||||||
|
FROM daily
|
||||||
|
WHERE trade_date BETWEEN ? AND ?
|
||||||
|
""",
|
||||||
|
(start_str, end_str),
|
||||||
|
).fetchone()
|
||||||
|
universe_size = int(row["cnt"]) if row and row["cnt"] is not None else 0
|
||||||
|
return universe_size * len(trade_days)
|
||||||
|
|
||||||
|
|
||||||
def render_factor_calculation() -> None:
|
def render_factor_calculation() -> None:
|
||||||
"""渲染因子计算页面。"""
|
"""渲染因子计算页面。"""
|
||||||
st.subheader("📊 因子计算")
|
st.subheader("📊 因子计算")
|
||||||
@ -153,97 +230,55 @@ def render_factor_calculation() -> None:
|
|||||||
help="如果勾选,将跳过数据库中已存在的因子计算结果"
|
help="如果勾选,将跳过数据库中已存在的因子计算结果"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 5. 同步计算函数
|
# 5. 开始计算按钮
|
||||||
def run_factor_calculation_sync():
|
|
||||||
"""同步执行因子计算"""
|
|
||||||
# 计算参数
|
|
||||||
total_stocks = len(universe) if universe else len(_get_all_stocks())
|
|
||||||
total_batches = len(selected_factors)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 执行因子计算
|
|
||||||
results = []
|
|
||||||
for i, factor in enumerate(selected_factors):
|
|
||||||
# 更新批次进度
|
|
||||||
factor_progress.update_progress(
|
|
||||||
current_securities=0,
|
|
||||||
current_batch=i+1,
|
|
||||||
message=f"正在计算因子: {factor.name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 计算单个交易日的因子
|
|
||||||
current_date = start_date
|
|
||||||
while current_date <= end_date:
|
|
||||||
try:
|
|
||||||
# 计算指定日期的因子
|
|
||||||
daily_results = compute_factors(
|
|
||||||
current_date,
|
|
||||||
[factor],
|
|
||||||
ts_codes=universe,
|
|
||||||
skip_existing=skip_existing
|
|
||||||
)
|
|
||||||
results.extend(daily_results)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# 记录错误但不中断计算
|
|
||||||
error_msg = f"计算因子 {factor.name} 在日期 {current_date} 时出错: {str(e)}"
|
|
||||||
print(f"ERROR: {error_msg}")
|
|
||||||
|
|
||||||
current_date += timedelta(days=1)
|
|
||||||
|
|
||||||
# 计算完成
|
|
||||||
factor_progress.complete_calculation(f"因子计算完成!共计算 {len(results)} 条因子记录")
|
|
||||||
|
|
||||||
return {
|
|
||||||
'success': True,
|
|
||||||
'results': results,
|
|
||||||
'factors': [f.name for f in selected_factors],
|
|
||||||
'date_range': f"{start_date} 至 {end_date}",
|
|
||||||
'stock_count': len(set(r.ts_code for r in results)) if results else 0,
|
|
||||||
'message': f"因子计算完成!共计算 {len(results)} 条因子记录"
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# 计算失败
|
|
||||||
factor_progress.error_occurred(f"因子计算失败: {str(e)}")
|
|
||||||
return {
|
|
||||||
'success': False,
|
|
||||||
'error': str(e),
|
|
||||||
'message': f"因子计算失败: {str(e)}"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 6. 开始计算按钮
|
|
||||||
if st.button("开始计算因子", disabled=not selected_factors):
|
if st.button("开始计算因子", disabled=not selected_factors):
|
||||||
# 重置状态
|
# 重置状态
|
||||||
if 'factor_calculation_results' in st.session_state:
|
st.session_state.pop('factor_calculation_results', None)
|
||||||
st.session_state.factor_calculation_results = None
|
st.session_state.pop('factor_calculation_error', None)
|
||||||
if 'factor_calculation_error' in st.session_state:
|
factor_progress.reset()
|
||||||
st.session_state.factor_calculation_error = None
|
|
||||||
|
scoped_universe = _normalize_universe(universe) or None
|
||||||
# 初始化进度状态
|
trade_dates = _get_trade_dates_between(start_date, end_date, scoped_universe)
|
||||||
total_stocks = len(universe) if universe else len(_get_all_stocks())
|
if not trade_dates:
|
||||||
|
st.warning("所选时间窗口内无可用交易日数据,请先执行数据同步。")
|
||||||
|
return
|
||||||
|
|
||||||
|
total_workload = _estimate_total_workload(trade_dates, scoped_universe)
|
||||||
factor_progress.start_calculation(
|
factor_progress.start_calculation(
|
||||||
total_securities=total_stocks,
|
total_securities=max(total_workload, 1),
|
||||||
total_batches=len(selected_factors)
|
total_batches=len(trade_dates),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 直接调用同步计算函数
|
with st.spinner("正在计算因子..."):
|
||||||
result = run_factor_calculation_sync()
|
try:
|
||||||
|
results = compute_factor_range(
|
||||||
# 处理计算结果
|
start=min(trade_dates),
|
||||||
if result['success']:
|
end=max(trade_dates),
|
||||||
st.session_state.factor_calculation_results = {
|
factors=selected_factors,
|
||||||
'results': result['results'],
|
ts_codes=scoped_universe,
|
||||||
'factors': result['factors'],
|
skip_existing=skip_existing,
|
||||||
'date_range': result['date_range'],
|
)
|
||||||
'stock_count': result['stock_count']
|
except Exception as exc:
|
||||||
}
|
LOGGER.exception("因子计算失败", extra={**LOG_EXTRA, "error": str(exc)})
|
||||||
st.success("✅ 因子计算完成!")
|
factor_progress.error_occurred(str(exc))
|
||||||
else:
|
st.session_state.factor_calculation_error = str(exc)
|
||||||
st.session_state.factor_calculation_error = result['error']
|
st.error(f"❌ 因子计算失败: {exc}")
|
||||||
st.error(f"❌ 因子计算失败: {result['error']}")
|
else:
|
||||||
|
factor_progress.complete_calculation(
|
||||||
|
f"因子计算完成,共生成 {len(results)} 条因子记录"
|
||||||
|
)
|
||||||
|
factor_names = [spec.name for spec in selected_factors]
|
||||||
|
stock_count = len({item.ts_code for item in results}) if results else 0
|
||||||
|
st.session_state.factor_calculation_results = {
|
||||||
|
'results': results,
|
||||||
|
'factors': factor_names,
|
||||||
|
'date_range': f"{trade_dates[0]} 至 {trade_dates[-1]}",
|
||||||
|
'stock_count': stock_count,
|
||||||
|
'trade_days': len(trade_dates),
|
||||||
|
}
|
||||||
|
st.success("✅ 因子计算完成!")
|
||||||
|
|
||||||
# 7. 显示计算结果
|
# 6. 显示计算结果
|
||||||
if 'factor_calculation_results' in st.session_state and st.session_state.factor_calculation_results:
|
if 'factor_calculation_results' in st.session_state and st.session_state.factor_calculation_results:
|
||||||
results = st.session_state.factor_calculation_results
|
results = st.session_state.factor_calculation_results
|
||||||
|
|
||||||
@ -255,7 +290,8 @@ def render_factor_calculation() -> None:
|
|||||||
with col2:
|
with col2:
|
||||||
st.metric("涉及股票数量", results['stock_count'])
|
st.metric("涉及股票数量", results['stock_count'])
|
||||||
with col3:
|
with col3:
|
||||||
st.metric("计算时间范围", results['date_range'])
|
st.metric("交易日数量", results.get('trade_days', 0))
|
||||||
|
st.caption(f"时间范围:{results['date_range']}")
|
||||||
|
|
||||||
# 显示计算详情
|
# 显示计算详情
|
||||||
with st.expander("查看计算详情"):
|
with st.expander("查看计算详情"):
|
||||||
@ -279,8 +315,6 @@ def render_factor_calculation() -> None:
|
|||||||
else:
|
else:
|
||||||
st.info("没有找到因子计算结果")
|
st.info("没有找到因子计算结果")
|
||||||
|
|
||||||
# 8. 移除异步线程检查逻辑(已改为同步模式)
|
# 7. 显示错误信息
|
||||||
|
|
||||||
# 9. 显示错误信息
|
|
||||||
if 'factor_calculation_error' in st.session_state and st.session_state.factor_calculation_error:
|
if 'factor_calculation_error' in st.session_state and st.session_state.factor_calculation_error:
|
||||||
st.error(f"❌ 因子计算失败: {st.session_state.factor_calculation_error}")
|
st.error(f"❌ 因子计算失败: {st.session_state.factor_calculation_error}")
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
"""股票筛选与评估视图。"""
|
"""股票筛选与评估视图。"""
|
||||||
from datetime import datetime, timedelta
|
from datetime import date, datetime, timedelta
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional
|
||||||
|
import json
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@ -15,11 +16,10 @@ from app.utils.db import db_session
|
|||||||
from app.utils.logging import get_logger
|
from app.utils.logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
def _get_latest_trading_date() -> datetime.date:
|
def _get_latest_trading_date() -> date:
|
||||||
"""获取数据库中的最新交易日期"""
|
"""获取数据库中的最新交易日期"""
|
||||||
with db_session() as session:
|
with db_session(read_only=True) as conn:
|
||||||
# 获取当前日期的上一个有效交易日
|
result = conn.execute(
|
||||||
result = session.execute(
|
|
||||||
"""
|
"""
|
||||||
SELECT trade_date
|
SELECT trade_date
|
||||||
FROM daily_basic
|
FROM daily_basic
|
||||||
@ -35,6 +35,19 @@ def _get_latest_trading_date() -> datetime.date:
|
|||||||
return datetime.strptime(str(result[0]), "%Y%m%d").date()
|
return datetime.strptime(str(result[0]), "%Y%m%d").date()
|
||||||
return datetime.now().date() - timedelta(days=1) # 如果查询失败才返回昨天
|
return datetime.now().date() - timedelta(days=1) # 如果查询失败才返回昨天
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_universe(universe: Optional[List[str]]) -> List[str]:
|
||||||
|
"""标准化股票代码列表,去重并转为大写。"""
|
||||||
|
|
||||||
|
if not universe:
|
||||||
|
return []
|
||||||
|
normalized: Dict[str, None] = {}
|
||||||
|
for code in universe:
|
||||||
|
candidate = (code or "").strip().upper()
|
||||||
|
if candidate and candidate not in normalized:
|
||||||
|
normalized[candidate] = None
|
||||||
|
return list(normalized.keys())
|
||||||
|
|
||||||
def render_stock_evaluation() -> None:
|
def render_stock_evaluation() -> None:
|
||||||
"""渲染股票筛选与评估页面。"""
|
"""渲染股票筛选与评估页面。"""
|
||||||
LOGGER = get_logger(__name__)
|
LOGGER = get_logger(__name__)
|
||||||
@ -141,6 +154,9 @@ def render_stock_evaluation() -> None:
|
|||||||
index_code,
|
index_code,
|
||||||
end_date.strftime("%Y%m%d")
|
end_date.strftime("%Y%m%d")
|
||||||
)
|
)
|
||||||
|
universe = _normalize_universe(universe)
|
||||||
|
if universe == []:
|
||||||
|
universe = None
|
||||||
|
|
||||||
# 4. 评估结果
|
# 4. 评估结果
|
||||||
|
|
||||||
@ -167,11 +183,12 @@ def render_stock_evaluation() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
st.session_state.evaluation_status = 'running'
|
st.session_state.evaluation_status = 'running'
|
||||||
|
st.session_state.pop('evaluation_error', None)
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
for i, factor_name in enumerate(selected_factors):
|
for i, factor_name in enumerate(selected_factors):
|
||||||
st.session_state.current_factor = factor_name
|
st.session_state.current_factor = factor_name
|
||||||
st.session_state.progress = (i / len(selected_factors)) * 100
|
st.session_state.progress = ((i + 1) / len(selected_factors)) * 100
|
||||||
|
|
||||||
performance = evaluate_factor(
|
performance = evaluate_factor(
|
||||||
factor_name,
|
factor_name,
|
||||||
@ -181,11 +198,11 @@ def render_stock_evaluation() -> None:
|
|||||||
)
|
)
|
||||||
results.append({
|
results.append({
|
||||||
"因子": factor_name,
|
"因子": factor_name,
|
||||||
"IC均值": f"{performance.ic_mean:.4f}",
|
"IC均值": performance.ic_mean,
|
||||||
"RankIC均值": f"{performance.rank_ic_mean:.4f}",
|
"RankIC均值": performance.rank_ic_mean,
|
||||||
"IC信息比率": f"{performance.ic_ir:.4f}",
|
"IC信息比率": performance.ic_ir,
|
||||||
"夏普比率": f"{performance.sharpe_ratio:.4f}" if performance.sharpe_ratio else "N/A",
|
"夏普比率": performance.sharpe_ratio,
|
||||||
"换手率": f"{performance.turnover_rate*100:.1f}%" if performance.turnover_rate else "N/A"
|
"换手率": performance.turnover_rate,
|
||||||
})
|
})
|
||||||
|
|
||||||
st.session_state.evaluation_results = results
|
st.session_state.evaluation_results = results
|
||||||
@ -221,71 +238,89 @@ def render_stock_evaluation() -> None:
|
|||||||
|
|
||||||
st.markdown("##### 因子评估结果")
|
st.markdown("##### 因子评估结果")
|
||||||
result_df = pd.DataFrame(results)
|
result_df = pd.DataFrame(results)
|
||||||
st.dataframe(
|
if not result_df.empty:
|
||||||
result_df,
|
display_df = result_df.copy()
|
||||||
hide_index=True,
|
for col in ["IC均值", "RankIC均值", "IC信息比率"]:
|
||||||
width="stretch"
|
if col in display_df:
|
||||||
)
|
display_df[col] = display_df[col].map(lambda v: f"{v:.4f}")
|
||||||
|
if "夏普比率" in display_df:
|
||||||
|
display_df["夏普比率"] = display_df["夏普比率"].map(
|
||||||
|
lambda v: "N/A" if v is None else f"{v:.4f}"
|
||||||
|
)
|
||||||
|
if "换手率" in display_df:
|
||||||
|
display_df["换手率"] = display_df["换手率"].map(
|
||||||
|
lambda v: "N/A" if v is None else f"{v * 100:.1f}%"
|
||||||
|
)
|
||||||
|
st.dataframe(
|
||||||
|
display_df,
|
||||||
|
hide_index=True,
|
||||||
|
width="stretch"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
st.info("未产生任何因子评估结果。")
|
||||||
|
|
||||||
# 绘制IC均值分布
|
# 绘制IC均值分布
|
||||||
ic_means = [float(r["IC均值"]) for r in results]
|
ic_means = result_df["IC均值"].astype(float).tolist() if not result_df.empty else []
|
||||||
chart_df = pd.DataFrame({
|
chart_df = pd.DataFrame({
|
||||||
"因子": [r["因子"] for r in results],
|
"因子": [r["因子"] for r in results],
|
||||||
"IC均值": ic_means
|
"IC均值": ic_means
|
||||||
})
|
})
|
||||||
st.bar_chart(chart_df.set_index("因子"))
|
st.bar_chart(chart_df.set_index("因子"))
|
||||||
|
|
||||||
# 生成股票评分
|
if not ic_means:
|
||||||
|
st.info("暂无足够的 IC 数据,无法生成股票评分。")
|
||||||
|
return
|
||||||
|
|
||||||
with st.spinner("正在生成股票评分..."):
|
with st.spinner("正在生成股票评分..."):
|
||||||
# 使用IC均值作为权重,但如果IC均值全为零,则使用均匀分布
|
|
||||||
if all(mean == 0 for mean in ic_means):
|
if all(mean == 0 for mean in ic_means):
|
||||||
factor_weights = [1.0 / len(ic_means)] * len(ic_means)
|
factor_weights = [1.0 / len(ic_means)] * len(ic_means)
|
||||||
LOGGER.info("所有因子IC均值均为零,使用均匀权重", extra=LOG_EXTRA)
|
LOGGER.info("所有因子IC均值均为零,使用均匀权重", extra=LOG_EXTRA)
|
||||||
else:
|
else:
|
||||||
# 将IC均值归一化为权重
|
abs_sum = sum(abs(m) for m in ic_means) or 1.0
|
||||||
abs_sum = sum(abs(m) for m in ic_means)
|
|
||||||
factor_weights = [m / abs_sum for m in ic_means]
|
factor_weights = [m / abs_sum for m in ic_means]
|
||||||
LOGGER.info("使用IC均值作为权重: %s", factor_weights, extra=LOG_EXTRA)
|
LOGGER.info("使用IC均值作为权重: %s", factor_weights, extra=LOG_EXTRA)
|
||||||
|
|
||||||
scores = _calculate_stock_scores(
|
scores = _calculate_stock_scores(
|
||||||
universe,
|
universe,
|
||||||
selected_factors,
|
selected_factors,
|
||||||
end_date,
|
end_date,
|
||||||
factor_weights
|
factor_weights
|
||||||
)
|
)
|
||||||
|
|
||||||
if scores:
|
if scores:
|
||||||
st.markdown("##### 股票综合评分 (Top 20)")
|
st.markdown("##### 股票综合评分 (Top 20)")
|
||||||
score_df = pd.DataFrame(scores).sort_values(
|
score_df = pd.DataFrame(scores).sort_values(
|
||||||
"综合评分",
|
"综合评分",
|
||||||
ascending=False
|
ascending=False
|
||||||
).head(20)
|
)
|
||||||
st.dataframe(
|
top_df = score_df.head(20).reset_index(drop=True)
|
||||||
score_df,
|
display_scores = top_df.copy()
|
||||||
hide_index=True,
|
display_scores["综合评分"] = display_scores["综合评分"].map(lambda v: f"{v:.4f}")
|
||||||
width="stretch"
|
st.dataframe(
|
||||||
)
|
display_scores,
|
||||||
|
hide_index=True,
|
||||||
# 添加入池功能
|
width="stretch"
|
||||||
if st.button("将Top 20股票加入股票池"):
|
)
|
||||||
_add_to_stock_pool(
|
|
||||||
score_df["股票代码"].tolist(),
|
if st.button("将Top 20股票加入股票池"):
|
||||||
end_date
|
_add_to_stock_pool(top_df, end_date)
|
||||||
)
|
st.success("已成功将选中股票加入股票池!")
|
||||||
st.success("已成功将选中股票加入股票池!")
|
else:
|
||||||
|
st.info("无法根据当前因子权重生成有效的股票评分结果。")
|
||||||
|
|
||||||
|
|
||||||
def _calculate_stock_scores(
|
def _calculate_stock_scores(
|
||||||
universe: Optional[List[str]],
|
universe: Optional[List[str]],
|
||||||
factors: List[str],
|
factors: List[str],
|
||||||
eval_date: datetime.date,
|
eval_date: date,
|
||||||
factor_weights: List[float]
|
factor_weights: List[float]
|
||||||
) -> List[Dict[str, str]]:
|
) -> List[Dict[str, object]]:
|
||||||
"""计算股票的综合评分。"""
|
"""计算股票的综合评分。"""
|
||||||
LOGGER = get_logger(__name__)
|
LOGGER = get_logger(__name__)
|
||||||
LOG_EXTRA = {"stage": "stock_evaluation"}
|
LOG_EXTRA = {"stage": "stock_evaluation"}
|
||||||
|
|
||||||
broker = DataBroker()
|
broker = DataBroker()
|
||||||
|
trade_date_str = eval_date.strftime("%Y%m%d")
|
||||||
|
|
||||||
# 记录评估开始
|
# 记录评估开始
|
||||||
LOGGER.info(
|
LOGGER.info(
|
||||||
@ -297,7 +332,7 @@ def _calculate_stock_scores(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 标准化权重
|
# 标准化权重
|
||||||
weights = np.array(factor_weights)
|
weights = np.array(factor_weights, dtype=float)
|
||||||
abs_sum = np.sum(np.abs(weights))
|
abs_sum = np.sum(np.abs(weights))
|
||||||
if abs_sum > 0: # 避免除以零
|
if abs_sum > 0: # 避免除以零
|
||||||
weights = weights / abs_sum
|
weights = weights / abs_sum
|
||||||
@ -306,7 +341,10 @@ def _calculate_stock_scores(
|
|||||||
weights = np.ones_like(weights) / len(weights)
|
weights = np.ones_like(weights) / len(weights)
|
||||||
|
|
||||||
# 获取所有股票的因子值
|
# 获取所有股票的因子值
|
||||||
stocks = universe or broker.get_all_stocks(eval_date.strftime("%Y%m%d"))
|
stocks = universe or broker.get_all_stocks(trade_date_str)
|
||||||
|
if not stocks:
|
||||||
|
LOGGER.warning("股票列表为空,无法生成评分", extra=LOG_EXTRA)
|
||||||
|
return []
|
||||||
|
|
||||||
# 记录股票列表信息
|
# 记录股票列表信息
|
||||||
LOGGER.info(
|
LOGGER.info(
|
||||||
@ -320,42 +358,54 @@ def _calculate_stock_scores(
|
|||||||
|
|
||||||
evaluated_count = 0
|
evaluated_count = 0
|
||||||
skipped_count = 0
|
skipped_count = 0
|
||||||
|
factor_fields = [f"factors.{name}" for name in factors]
|
||||||
|
|
||||||
for ts_code in stocks:
|
for ts_code in stocks:
|
||||||
# 检查数据是否充分
|
if not check_data_sufficiency(ts_code, trade_date_str):
|
||||||
if not check_data_sufficiency(ts_code, eval_date.strftime("%Y%m%d")):
|
|
||||||
skipped_count += 1
|
skipped_count += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 获取股票信息
|
latest_payload = broker.fetch_latest(
|
||||||
info = broker.get_stock_info(ts_code)
|
ts_code,
|
||||||
|
trade_date_str,
|
||||||
|
factor_fields,
|
||||||
|
auto_refresh=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not latest_payload:
|
||||||
|
skipped_count += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
factor_values: List[float] = []
|
||||||
|
missing = False
|
||||||
|
for field in factor_fields:
|
||||||
|
value = latest_payload.get(field)
|
||||||
|
if value is None:
|
||||||
|
missing = True
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
factor_values.append(float(value))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
missing = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if missing or len(factor_values) != len(factors):
|
||||||
|
skipped_count += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
info = broker.get_stock_info(ts_code, trade_date_str)
|
||||||
if not info:
|
if not info:
|
||||||
skipped_count += 1
|
skipped_count += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 获取因子值
|
score = float(np.dot(factor_values, weights))
|
||||||
factor_values = []
|
|
||||||
for factor in factors:
|
|
||||||
value = broker.fetch_latest_factor(ts_code, factor, eval_date)
|
|
||||||
if value is None:
|
|
||||||
skipped_count += 1
|
|
||||||
break
|
|
||||||
factor_values.append(value)
|
|
||||||
|
|
||||||
# 检查是否所有因子值都已获取
|
|
||||||
if len(factor_values) != len(factors):
|
|
||||||
skipped_count += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 计算综合评分
|
|
||||||
score = np.dot(factor_values, weights)
|
|
||||||
evaluated_count += 1
|
evaluated_count += 1
|
||||||
|
|
||||||
results.append({
|
results.append({
|
||||||
"股票代码": ts_code,
|
"股票代码": ts_code,
|
||||||
"股票名称": info.get("name", ""),
|
"股票名称": info.get("name", ""),
|
||||||
"行业": info.get("industry", ""),
|
"行业": info.get("industry", ""),
|
||||||
"综合评分": f"{score:.4f}"
|
"综合评分": score,
|
||||||
})
|
})
|
||||||
|
|
||||||
# 记录评估完成信息
|
# 记录评估完成信息
|
||||||
@ -372,36 +422,51 @@ def _calculate_stock_scores(
|
|||||||
|
|
||||||
|
|
||||||
def _add_to_stock_pool(
|
def _add_to_stock_pool(
|
||||||
ts_codes: List[str],
|
score_df: pd.DataFrame,
|
||||||
eval_date: datetime.date
|
eval_date: date
|
||||||
) -> None:
|
) -> None:
|
||||||
"""将股票添加到股票池。"""
|
"""将股票评分结果写入投资池。"""
|
||||||
with db_session() as session:
|
|
||||||
# 删除已有记录
|
trade_date = eval_date.strftime("%Y%m%d")
|
||||||
session.execute(
|
payload: List[tuple] = []
|
||||||
"""
|
ranked_df = score_df.reset_index(drop=True)
|
||||||
DELETE FROM stock_pool
|
|
||||||
WHERE entry_date = :entry_date
|
for rank, row in ranked_df.iterrows():
|
||||||
""",
|
tags = json.dumps(["stock_evaluation", "top20"], ensure_ascii=False)
|
||||||
{"entry_date": eval_date}
|
metadata = json.dumps(
|
||||||
)
|
|
||||||
|
|
||||||
# 插入新记录
|
|
||||||
values = [
|
|
||||||
{
|
{
|
||||||
"ts_code": code,
|
"source": "stock_evaluation",
|
||||||
"entry_date": eval_date,
|
"rank": rank + 1,
|
||||||
"entry_reason": "factor_evaluation"
|
"score": float(row["综合评分"]),
|
||||||
}
|
},
|
||||||
for code in ts_codes
|
ensure_ascii=False,
|
||||||
]
|
|
||||||
|
|
||||||
session.execute(
|
|
||||||
"""
|
|
||||||
INSERT INTO stock_pool (ts_code, entry_date, entry_reason)
|
|
||||||
VALUES (:ts_code, :entry_date, :entry_reason)
|
|
||||||
""",
|
|
||||||
values
|
|
||||||
)
|
)
|
||||||
|
payload.append(
|
||||||
session.commit()
|
(
|
||||||
|
trade_date,
|
||||||
|
row["股票代码"],
|
||||||
|
float(row["综合评分"]),
|
||||||
|
"candidate",
|
||||||
|
"factor_evaluation_top20",
|
||||||
|
tags,
|
||||||
|
metadata,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with db_session() as conn:
|
||||||
|
conn.execute("DELETE FROM investment_pool WHERE trade_date = ?", (trade_date,))
|
||||||
|
if payload:
|
||||||
|
conn.executemany(
|
||||||
|
"""
|
||||||
|
INSERT INTO investment_pool (
|
||||||
|
trade_date,
|
||||||
|
ts_code,
|
||||||
|
score,
|
||||||
|
status,
|
||||||
|
rationale,
|
||||||
|
tags,
|
||||||
|
metadata
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
payload,
|
||||||
|
)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user