llm-quant/app/features/factors.py
2025-10-05 10:22:25 +08:00

647 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Feature engineering for signals and indicator computation."""
from __future__ import annotations
import re
from dataclasses import dataclass
from datetime import datetime, date, timezone
from typing import Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union
from app.core.indicators import momentum, rolling_mean, volatility
from app.data.schema import initialize_database
from app.utils.data_access import DataBroker
from app.utils.db import db_session
from app.utils.logging import get_logger
# 导入扩展因子模块
from app.features.extended_factors import compute_extended_factor_values
# Module-level logger; LOG_EXTRA tags every record with the pipeline stage.
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "factor_compute"}
# Safe SQL identifier pattern: factor names are vetted against this before
# being interpolated into ALTER TABLE statements (see _ensure_factor_columns).
_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
@dataclass
class FactorSpec:
    """Specification of a single factor column.

    ``window`` is the rolling lookback in trading days; a value of 0 marks
    snapshot/derived factors that do not use a rolling series.
    """
    # Column name in the ``factors`` table.
    name: str
    # Rolling window length in days (0 = snapshot factor).
    window: int
@dataclass
class FactorResult:
    """Computed factor values for one security on one trading day."""
    # Security identifier (e.g. ``000001.SZ``).
    ts_code: str
    # Trading day the values belong to.
    trade_date: date
    # Mapping of factor name to value; ``None`` when a factor could not be computed.
    values: Dict[str, float | None]
# Complete catalogue of base + extended factors computed per security/day.
# window > 0: rolling-series factor; window == 0: snapshot/derived factor.
DEFAULT_FACTORS: List[FactorSpec] = [
    # Basic momentum factors
    FactorSpec("mom_5", 5),
    FactorSpec("mom_20", 20),
    FactorSpec("mom_60", 60),
    # Volatility factor
    FactorSpec("volat_20", 20),
    # Turnover factors
    FactorSpec("turn_20", 20),
    FactorSpec("turn_5", 5),
    # Valuation factors
    FactorSpec("val_pe_score", 0),
    FactorSpec("val_pb_score", 0),
    # Volume-ratio factor
    FactorSpec("volume_ratio_score", 0),
    # Extended factors (computed by app.features.extended_factors)
    # Enhanced momentum factors
    FactorSpec("mom_10_30", 0),  # spread between 10- and 30-day momentum
    FactorSpec("mom_5_20_rank", 0),  # relative-rank momentum factor
    FactorSpec("mom_dynamic", 0),  # dynamic-window momentum factor
    # Volatility-related factors
    FactorSpec("volat_5", 5),  # short-term volatility
    FactorSpec("volat_ratio", 0),  # short/long volatility ratio
    # Extended turnover factors
    FactorSpec("turn_60", 60),  # long-term turnover
    FactorSpec("turn_rank", 0),  # relative turnover rank
    # Price / moving-average ratio factors
    FactorSpec("price_ma_10_ratio", 0),  # current price vs 10-day MA
    FactorSpec("price_ma_20_ratio", 0),  # current price vs 20-day MA
    FactorSpec("price_ma_60_ratio", 0),  # current price vs 60-day MA
    # Volume / moving-average ratio factors
    FactorSpec("volume_ma_5_ratio", 0),  # current volume vs 5-day MA
    FactorSpec("volume_ma_20_ratio", 0),  # current volume vs 20-day MA
    # Advanced valuation factors
    FactorSpec("val_ps_score", 0),  # PS valuation score
    FactorSpec("val_multiscore", 0),  # composite valuation score
    FactorSpec("val_dividend_score", 0),  # dividend-yield valuation score
    # Market-state factors
    FactorSpec("market_regime", 0),  # market regime factor
    FactorSpec("trend_strength", 0),  # trend-strength factor
]
def compute_factors(
    trade_date: date,
    factors: Iterable[FactorSpec] = DEFAULT_FACTORS,
    *,
    ts_codes: Optional[Sequence[str]] = None,
    skip_existing: bool = False,
    batch_size: int = 100,
) -> List[FactorResult]:
    """Calculate and persist factor values for the requested date.

    ``ts_codes`` can be supplied to restrict computation to a subset of the
    universe. When ``skip_existing`` is True, securities that already have an
    entry for ``trade_date`` will be ignored.

    Args:
        trade_date: Trading day to compute factors for.
        factors: Factor specifications to compute.
        ts_codes: Optional whitelist of security codes to restrict the universe.
        skip_existing: Skip securities that already have a factors row for the day.
        batch_size: Number of securities processed per batch (progress/perf tuning).

    Returns:
        List of per-security factor results; the same rows are persisted to the DB.
    """
    # Negative windows are invalid; window == 0 denotes snapshot factors.
    specs = [spec for spec in factors if spec.window >= 0]
    if not specs:
        return []
    initialize_database()
    trade_date_str = trade_date.strftime("%Y%m%d")
    # Make sure the factors table has a REAL column for every requested factor.
    _ensure_factor_columns(specs)
    # Normalise the optional whitelist: trimmed, upper-cased, non-empty entries.
    allowed = {code.strip().upper() for code in ts_codes or () if code.strip()}
    universe = _load_universe(trade_date_str, allowed if allowed else None)
    if not universe:
        LOGGER.info("无可用标的生成因子 trade_date=%s", trade_date_str, extra=LOG_EXTRA)
        return []
    if skip_existing:
        # Drop securities that already have a factors row for this day.
        existing = _existing_factor_codes(trade_date_str)
        universe = [code for code in universe if code not in existing]
        if not universe:
            LOGGER.debug(
                "目标交易日因子已存在 trade_date=%s universe_size=%s",
                trade_date_str,
                len(existing),
                extra=LOG_EXTRA,
            )
            return []
    LOGGER.info(
        "开始计算因子 universe_size=%s factors=%s trade_date=%s",
        len(universe),
        [spec.name for spec in specs],
        trade_date_str,
        extra=LOG_EXTRA,
    )
    # Validation counters, mutated in place by _compute_batch_factors.
    validation_stats = {
        "total": len(universe),
        "skipped": 0,
        "success": 0,
        "data_missing": 0,
        "outliers": 0
    }
    broker = DataBroker()
    results: List[FactorResult] = []
    rows_to_persist: List[tuple[str, Dict[str, float | None]]] = []
    # Process in batches to bound memory use and enable progress reporting.
    for i in range(0, len(universe), batch_size):
        batch = universe[i:i+batch_size]
        batch_results = _compute_batch_factors(broker, batch, trade_date_str, specs, validation_stats)
        for ts_code, values in batch_results:
            if values:
                results.append(FactorResult(ts_code=ts_code, trade_date=trade_date, values=values))
                rows_to_persist.append((ts_code, values))
        # Emit a progress line every 5 batches and on completion.
        processed = min(i + batch_size, len(universe))
        if processed % (batch_size * 5) == 0 or processed == len(universe):
            LOGGER.info(
                "因子计算进度: %s/%s (%.1f%%) 成功:%s 跳过:%s 数据缺失:%s 异常值:%s",
                processed, len(universe),
                (processed / len(universe)) * 100,
                validation_stats["success"],
                validation_stats["skipped"],
                validation_stats["data_missing"],
                validation_stats["outliers"],
                extra=LOG_EXTRA,
            )
    if rows_to_persist:
        _persist_factor_rows(trade_date_str, rows_to_persist, specs)
    LOGGER.info(
        "因子计算完成 总数量:%s 成功:%s 失败:%s",
        len(universe),
        validation_stats["success"],
        validation_stats["total"] - validation_stats["success"],
        extra=LOG_EXTRA,
    )
    return results
def compute_factor_range(
    start: date,
    end: date,
    *,
    factors: Iterable[FactorSpec] = DEFAULT_FACTORS,
    ts_codes: Optional[Sequence[str]] = None,
    skip_existing: bool = True,
) -> List[FactorResult]:
    """Compute factors for all trading days within ``[start, end]`` inclusive.

    Args:
        start: First day of the range.
        end: Last day of the range (must not precede ``start``).
        factors: Factor specifications to compute for each day.
        ts_codes: Optional whitelist of security codes.
        skip_existing: Skip securities already computed for a given day.

    Returns:
        Concatenated results across all trading days in the range.

    Raises:
        ValueError: If ``end`` is earlier than ``start``.
    """
    if end < start:
        raise ValueError("end date must not precede start date")
    initialize_database()
    # Normalise the optional universe filter: trimmed, upper-cased,
    # de-duplicated while preserving first-seen order.
    allowed: Optional[Tuple[str, ...]] = None
    if ts_codes:
        deduped = tuple(dict.fromkeys(code.strip().upper() for code in ts_codes if code.strip()))
        allowed = deduped or None
    range_start = start.strftime("%Y%m%d")
    range_end = end.strftime("%Y%m%d")
    aggregated: List[FactorResult] = []
    # Iterate only actual trading days present in the daily table.
    for day_str in _list_trade_dates(range_start, range_end, allowed):
        day = datetime.strptime(day_str, "%Y%m%d").date()
        aggregated.extend(
            compute_factors(
                day,
                factors,
                ts_codes=allowed,
                skip_existing=skip_existing,
            )
        )
    return aggregated
def _load_universe(trade_date: str, allowed: Optional[set[str]] = None) -> List[str]:
    """Return ts_codes with a daily bar on ``trade_date``, optionally filtered.

    The filter comparison is case-insensitive; ordering follows the SQL
    ``ORDER BY ts_code``.
    """
    with db_session(read_only=True) as conn:
        fetched = conn.execute(
            "SELECT ts_code FROM daily WHERE trade_date = ? ORDER BY ts_code",
            (trade_date,),
        ).fetchall()
    codes = [record["ts_code"] for record in fetched if record["ts_code"]]
    if not allowed:
        return codes
    wanted = {entry.upper() for entry in allowed}
    return [code for code in codes if code.upper() in wanted]
def _existing_factor_codes(trade_date: str) -> set[str]:
    """Return the set of ts_codes that already have a factors row for the day."""
    sql = "SELECT ts_code FROM factors WHERE trade_date = ?"
    with db_session(read_only=True) as conn:
        fetched = conn.execute(sql, (trade_date,)).fetchall()
    return {record["ts_code"] for record in fetched if record["ts_code"]}
def _list_trade_dates(
    start_date: str,
    end_date: str,
    allowed: Optional[Sequence[str]],
) -> List[str]:
    """Return distinct trading dates in ``[start_date, end_date]``, ascending.

    When ``allowed`` is given, only dates on which at least one of those
    codes has a daily bar are returned. Codes are bound as SQL parameters.
    """
    query = (
        "SELECT DISTINCT trade_date FROM daily "
        "WHERE trade_date BETWEEN ? AND ? "
    )
    params: List[str] = [start_date, end_date]
    if allowed:
        marks = ", ".join("?" for _ in allowed)
        query += f"AND ts_code IN ({marks}) "
        params.extend(allowed)
    query += "ORDER BY trade_date"
    with db_session(read_only=True) as conn:
        fetched = conn.execute(query, params).fetchall()
    return [record["trade_date"] for record in fetched if record["trade_date"]]
def _compute_batch_factors(
    broker: DataBroker,
    ts_codes: List[str],
    trade_date: str,
    specs: Sequence[FactorSpec],
    validation_stats: Dict[str, int],
) -> List[tuple[str, Dict[str, float | None]]]:
    """Compute factor values for a batch of securities.

    Mutates the ``validation_stats`` counters in place and returns
    ``(ts_code, values)`` pairs for securities that produced values.
    Failures are logged and counted as skipped; they never abort the batch.
    """
    collected: List[tuple[str, Dict[str, float | None]]] = []
    for code in ts_codes:
        try:
            # Cheap pre-flight check before doing the full computation.
            if not _check_data_availability(broker, code, trade_date, specs):
                validation_stats["data_missing"] += 1
                continue
            factor_values = _compute_security_factors(broker, code, trade_date, specs)
            if not factor_values:
                validation_stats["skipped"] += 1
                continue
            # Clamp extreme momentum/volatility values before collecting.
            cleaned = _detect_and_handle_outliers(factor_values, code)
            collected.append((code, cleaned))
            validation_stats["success"] += 1
        except Exception as exc:
            LOGGER.error(
                "计算因子失败 ts_code=%s err=%s",
                code,
                str(exc),
                extra=LOG_EXTRA,
            )
            validation_stats["skipped"] += 1
    return collected
def _check_data_availability(
    broker: DataBroker,
    ts_code: str,
    trade_date: str,
    specs: Sequence[FactorSpec],
) -> bool:
    """Check whether enough data exists to compute all requested factors.

    Requires a truthy latest close and a close-price history at least as
    long as the largest rolling window among ``specs`` (minimum one day).
    """
    # Longest rolling window across the requested factors, at least 1 day.
    required_days = max([1, *(spec.window for spec in specs)])
    # Latest snapshot fields; turnover is fetched alongside close in one call.
    snapshot = broker.fetch_latest(
        ts_code,
        trade_date,
        ["daily.close", "daily_basic.turnover_rate"],
    )
    # History length check for the rolling-window factors.
    history = broker.fetch_series("daily", "close", ts_code, trade_date, required_days)
    if not snapshot.get("daily.close"):
        return False
    return len(history) >= required_days
def _detect_and_handle_outliers(
values: Dict[str, float | None],
ts_code: str,
) -> Dict[str, float | None]:
"""检测并处理因子值中的异常值"""
result = values.copy()
outliers_found = False
# 动量因子异常值检测
for key in [k for k in values if k.startswith("mom_") and values[k] is not None]:
value = values[key]
# 异常值检测规则动量值绝对值大于3视为异常
if abs(value) > 3.0:
LOGGER.debug(
"检测到动量因子异常值 ts_code=%s factor=%s value=%.4f",
ts_code, key, value,
extra=LOG_EXTRA,
)
# 限制到合理范围
result[key] = min(3.0, max(-3.0, value))
outliers_found = True
# 波动率因子异常值检测
for key in [k for k in values if k.startswith("volat_") and values[k] is not None]:
value = values[key]
# 异常值检测规则波动率大于100%视为异常
if value > 1.0:
LOGGER.debug(
"检测到波动率因子异常值 ts_code=%s factor=%s value=%.4f",
ts_code, key, value,
extra=LOG_EXTRA,
)
# 限制到合理范围
result[key] = min(1.0, value)
outliers_found = True
if outliers_found:
LOGGER.debug(
"处理后因子值 ts_code=%s values=%s",
ts_code, {k: f"{v:.4f}" for k, v in result.items() if v is not None},
extra=LOG_EXTRA,
)
return result
def _compute_security_factors(
    broker: DataBroker,
    ts_code: str,
    trade_date: str,
    specs: Sequence[FactorSpec],
) -> Dict[str, float | None]:
    """Compute factor values for a single security on one trading day.

    Fetches the close/turnover/volume series and latest snapshot fields,
    evaluates each requested spec, then merges in extended-factor values.
    Returns an empty dict when no factor could be computed (so callers can
    count the security as skipped).
    """
    # Determine the largest window needed per series type.
    close_windows = [spec.window for spec in specs if _factor_prefix(spec.name) in {"mom", "volat"}]
    turnover_windows = [spec.window for spec in specs if _factor_prefix(spec.name) == "turn"]
    max_close_window = max(close_windows) if close_windows else 0
    max_turn_window = max(turnover_windows) if turnover_windows else 0
    # Fetch the required time series.
    close_series = _fetch_series_values(
        broker,
        "daily",
        "close",
        ts_code,
        trade_date,
        max_close_window,
    )
    # Without close prices nothing meaningful can be computed.
    if not close_series:
        LOGGER.debug("缺少收盘价数据 ts_code=%s", ts_code, extra=LOG_EXTRA)
        return {}
    turnover_series = _fetch_series_values(
        broker,
        "daily_basic",
        "turnover_rate",
        ts_code,
        trade_date,
        max_turn_window,
    )
    # Volume data for the extended-factor computations.
    volume_series = _fetch_series_values(
        broker,
        "daily",
        "vol",
        ts_code,
        trade_date,
        max_close_window,  # same window as the price series
    )
    # Latest snapshot fields for valuation / volume-ratio factors.
    latest_fields = broker.fetch_latest(
        ts_code,
        trade_date,
        [
            "daily_basic.pe",
            "daily_basic.pb",
            "daily_basic.ps",
            "daily_basic.volume_ratio",
            "daily.amount",
            "daily.vol",
            "daily_basic.dv_ratio",  # dividend yield, used by extended factors
        ],
    )
    # Evaluate each requested factor spec.
    results: Dict[str, float | None] = {}
    for spec in specs:
        prefix = _factor_prefix(spec.name)
        if prefix == "mom":
            if len(close_series) >= spec.window:
                results[spec.name] = momentum(close_series, spec.window)
            else:
                results[spec.name] = None
        elif prefix == "volat":
            # NOTE(review): volatility is computed whenever >= 2 points exist,
            # even with fewer than spec.window observations — unlike the
            # momentum branch. Confirm this partial-window behavior is intended.
            if len(close_series) >= 2:
                results[spec.name] = volatility(close_series, spec.window)
            else:
                results[spec.name] = None
        elif prefix == "turn":
            if len(turnover_series) >= spec.window:
                results[spec.name] = rolling_mean(turnover_series, spec.window)
            else:
                results[spec.name] = None
        elif spec.name == "val_pe_score":
            pe = latest_fields.get("daily_basic.pe")
            results[spec.name] = _valuation_score(pe, scale=12.0)
        elif spec.name == "val_pb_score":
            pb = latest_fields.get("daily_basic.pb")
            results[spec.name] = _valuation_score(pb, scale=2.5)
        elif spec.name == "volume_ratio_score":
            volume_ratio = latest_fields.get("daily_basic.volume_ratio")
            results[spec.name] = _volume_ratio_score(volume_ratio)
        else:
            # Unknown names are handled by the extended-factor module below
            # (or simply skipped); log at debug level for traceability.
            LOGGER.debug(
                "忽略未识别的因子 name=%s ts_code=%s",
                spec.name,
                ts_code,
                extra=LOG_EXTRA,
            )
    # Merge in the extended factor values (may overwrite same-named keys).
    extended_factors = compute_extended_factor_values(
        close_series, volume_series, turnover_series, latest_fields
    )
    results.update(extended_factors)
    # Treat an all-None result set as "nothing computed".
    if not any(v is not None for v in results.values()):
        return {}
    return results
def _persist_factor_rows(
    trade_date: str,
    rows: Sequence[tuple[str, Dict[str, float | None]]],
    specs: Sequence[FactorSpec],
) -> None:
    """Upsert computed factor values into the ``factors`` table in batches.

    Args:
        trade_date: Trading day in ``YYYYMMDD`` form.
        rows: ``(ts_code, {factor_name: value})`` pairs to persist.
        specs: Factor specs whose names define the columns written.
    """
    if not rows:
        return
    columns = sorted({spec.name for spec in specs})
    timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    # Build an UPSERT so re-runs for the same (ts_code, trade_date) overwrite.
    insert_columns = ["ts_code", "trade_date", "updated_at", *columns]
    placeholders = ", ".join(["?"] * len(insert_columns))
    update_clause = ", ".join(
        f"{column}=excluded.{column}" for column in ["updated_at", *columns]
    )
    sql = (
        f"INSERT INTO factors ({', '.join(insert_columns)}) "
        f"VALUES ({placeholders}) "
        f"ON CONFLICT(ts_code, trade_date) DO UPDATE SET {update_clause}"
    )
    batch_size = 500  # rows per executemany call; keeps parameter counts bounded
    batch_payloads = []
    for ts_code, values in rows:
        # Skip rows where every factor is None -- nothing useful to store.
        if not any(values.get(col) is not None for col in columns):
            continue
        payload = [ts_code, trade_date, timestamp]
        payload.extend(values.get(column) for column in columns)
        batch_payloads.append(payload)
    if not batch_payloads:
        LOGGER.debug("无可持久化的有效因子数据", extra=LOG_EXTRA)
        return
    total_inserted = 0
    with db_session() as conn:
        # Execute in chunks to stay under SQLite's parameter limits.
        for i in range(0, len(batch_payloads), batch_size):
            batch = batch_payloads[i:i + batch_size]
            try:
                conn.executemany(sql, batch)
                total_inserted += len(batch)
                processed = min(i + batch_size, len(batch_payloads))
                # BUGFIX: the old condition tested len(batch) % (batch_size*5),
                # which can never be 0 for a non-empty batch, so progress was
                # never logged. Use the cumulative processed count instead.
                if processed % (batch_size * 5) == 0 or processed == len(batch_payloads):
                    LOGGER.debug(
                        "因子数据持久化进度: %s/%s",
                        processed,
                        len(batch_payloads),
                        extra=LOG_EXTRA,
                    )
            except sqlite3.Error as e:
                # BUGFIX: sqlite3 was referenced here without being imported,
                # so any DB error surfaced as a NameError masking the real
                # failure; ``import sqlite3`` is now at the top of the module.
                LOGGER.error(
                    "因子数据持久化失败 批次=%s-%s err=%s",
                    i, min(i + batch_size, len(batch_payloads)),
                    str(e),
                    extra=LOG_EXTRA,
                )
    LOGGER.info(
        "因子数据持久化完成 写入记录数=%s 总记录数=%s",
        total_inserted,
        len(batch_payloads),
        extra=LOG_EXTRA,
    )
def _ensure_factor_columns(specs: Sequence[FactorSpec]) -> None:
    """Add a REAL column to ``factors`` for every spec name not yet present.

    Names must match the safe-identifier pattern before being interpolated
    into the ALTER TABLE statement; non-matching names are ignored.
    """
    wanted = {spec.name for spec in specs if _IDENTIFIER_RE.match(spec.name)}
    if not wanted:
        return
    with db_session() as conn:
        table_info = conn.execute("PRAGMA table_info(factors)").fetchall()
        present = {entry["name"] for entry in table_info}
        for missing in sorted(wanted - present):
            conn.execute(f"ALTER TABLE factors ADD COLUMN {missing} REAL")
def _fetch_series_values(
broker: DataBroker,
table: str,
column: str,
ts_code: str,
trade_date: str,
window: int,
) -> List[float]:
if window <= 0:
return []
series = broker.fetch_series(table, column, ts_code, trade_date, window)
values: List[float] = []
for _dt, raw in series:
try:
values.append(float(raw))
except (TypeError, ValueError):
continue
return values
def _factor_prefix(name: str) -> str:
return name.split("_", 1)[0] if name else ""
def _valuation_score(value: object, *, scale: float) -> float:
"""计算估值指标的标准化分数"""
try:
numeric = float(value)
except (TypeError, ValueError):
return 0.0
# 有效性检查
if numeric <= 0:
return 0.0
# 异常值处理:限制估值指标的上限
max_limit = scale * 10 # 设置十倍scale为上限
if numeric > max_limit:
numeric = max_limit
# 计算分数
score = scale / (scale + numeric)
return max(0.0, min(1.0, score))
def _volume_ratio_score(value: object) -> float:
"""计算量比指标的标准化分数"""
try:
numeric = float(value)
except (TypeError, ValueError):
return 0.0
# 有效性检查
if numeric < 0:
numeric = 0.0
# 异常值处理设置量比上限为20
if numeric > 20:
numeric = 20
return max(0.0, min(1.0, numeric / 10.0))