enhance factors calculation with new technical and market microstructure features

This commit is contained in:
sam 2025-10-08 18:02:46 +08:00
parent 27b7f024c0
commit 34f0758135
6 changed files with 431 additions and 35 deletions

View File

@ -80,8 +80,8 @@ def volatility_regime(prices: Sequence[float],
# 结合价格波动和成交量波动判断状态
if vol > 0 and vol_of_vol > 0:
regime = 0.5 * (vol / np.mean(abs(returns)) - 1) + \
0.5 * (vol_of_vol / np.mean(abs(vol_changes)) - 1)
regime = 0.5 * (vol / np.mean(np.abs(returns)) - 1) + \
0.5 * (vol_of_vol / np.mean(np.abs(vol_changes)) - 1)
return np.clip(regime, -1, 1)
return 0.0

View File

@ -185,13 +185,10 @@ class ExtendedFactors:
return rsi(close_series, 14)
elif factor_name == "tech_macd_signal":
_, signal = macd(close_series)
return signal
return macd(close_series, 12, 26, 9)
elif factor_name == "tech_bb_position":
upper, lower = bollinger_bands(close_series, 20)
pos = (close_series[0] - lower) / (upper - lower + 1e-8)
return pos
return bollinger_bands(close_series, 20)
elif factor_name == "tech_obv_momentum":
return obv_momentum(close_series, volume_series, 20)
@ -205,18 +202,119 @@ class ExtendedFactors:
ma_20 = rolling_mean(close_series, 20)
return ma_5 - ma_20
elif factor_name == "trend_price_channel":
# 价格通道突破因子:当前价格相对于通道的位置
window = 20
high_channel = max(close_series[:window])
low_channel = min(close_series[:window])
if high_channel != low_channel:
return (close_series[0] - low_channel) / (high_channel - low_channel)
return 0.0
elif factor_name == "trend_adx":
# 简化的ADX计算基于价格变动方向
window = 14
if len(close_series) < window + 1:
return None
# 计算价格变动
price_changes = [close_series[i] - close_series[i+1] for i in range(window)]
# 计算正向和负向变动
pos_moves = sum(max(0, change) for change in price_changes)
neg_moves = sum(max(0, -change) for change in price_changes)
# 简化的ADX计算
if pos_moves + neg_moves > 0:
return (pos_moves - neg_moves) / (pos_moves + neg_moves)
return 0.0
# 市场微观结构因子
elif factor_name == "micro_tick_direction":
# 简化的逐笔方向:基于最近价格变动
window = 5
if len(close_series) < window + 1:
return None
# 计算价格变动方向
directions = [1 if close_series[i] > close_series[i+1] else -1 for i in range(window)]
return sum(directions) / window
elif factor_name == "micro_trade_imbalance":
# 交易失衡:基于价格和成交量的联合分析
window = 10
if len(close_series) < window + 1 or len(volume_series) < window + 1:
return None
# 计算价格变动和成交量变动
price_changes = [close_series[i] - close_series[i+1] for i in range(window)]
volume_changes = [volume_series[i] - volume_series[i+1] for i in range(window)]
# 计算交易失衡指标
imbalance = sum(price_changes[i] * volume_changes[i] for i in range(window))
return imbalance / (window * np.mean(volume_series[:window]) + 1e-8)
# 波动率预测因子
elif factor_name == "vol_garch":
return garch_volatility(close_series, 20)
elif factor_name == "vol_range_pred":
# 波动区间预测:基于历史价格区间
window = 10
if len(close_series) < window + 5:
return None
# 计算历史价格区间
ranges = []
for i in range(5): # 使用最近5个窗口
if i + window < len(close_series):
price_range = max(close_series[i:i+window]) - min(close_series[i:i+window])
ranges.append(price_range / close_series[i])
if ranges:
# 使用历史区间的75分位数作为预测
return np.percentile(ranges, 75)
return None
elif factor_name == "vol_regime":
regime, _ = volatility_regime(close_series, volume_series, 20)
return regime
return volatility_regime(close_series, volume_series, 20)
# 量价联合因子
elif factor_name == "volume_price_corr":
return volume_price_correlation(close_series, volume_series, 20)
elif factor_name == "volume_price_diverge":
# 量价背离:价格和成交量趋势的背离程度
window = 10
if len(close_series) < window or len(volume_series) < window:
return None
# 计算价格和成交量趋势
price_trend = sum(1 if close_series[i] > close_series[i+1] else -1 for i in range(window-1))
volume_trend = sum(1 if volume_series[i] > volume_series[i+1] else -1 for i in range(window-1))
# 计算背离程度
divergence = price_trend * volume_trend * -1 # 反向为背离
return np.clip(divergence / (window - 1), -1, 1)
elif factor_name == "volume_intensity":
# 成交强度:基于成交量和价格变动的加权指标
window = 5
if len(close_series) < window + 1 or len(volume_series) < window + 1:
return None
# 计算价格变动
price_changes = [abs(close_series[i] - close_series[i+1]) for i in range(window)]
# 计算成交量加权的价格变动
weighted_changes = sum(price_changes[i] * volume_series[i] for i in range(window))
total_volume = sum(volume_series[:window])
if total_volume > 0:
intensity = weighted_changes / (total_volume * np.mean(close_series[:window]) + 1e-8)
return np.clip(intensity * 100, -100, 100) # 归一化到合理范围
return None
# 增强动量因子
elif factor_name == "momentum_adaptive":
return adaptive_momentum(close_series, volume_series, 20)

View File

@ -294,10 +294,13 @@ def _compute_batch_factors(
"""批量计算多个证券的因子值,提高计算效率"""
batch_results = []
# 批次化数据可用性检查
available_codes = _check_batch_data_availability(broker, ts_codes, trade_date, specs)
for ts_code in ts_codes:
try:
# 先检查数据可用性
if not _check_data_availability(broker, ts_code, trade_date, specs):
# 检查数据可用性(使用批次化结果)
if ts_code not in available_codes:
validation_stats["data_missing"] += 1
continue
@ -366,6 +369,78 @@ def _check_data_availability(
return True # 所有检查都通过
def _check_batch_data_availability(
    broker: DataBroker,
    ts_codes: List[str],
    trade_date: str,
    specs: Sequence[FactorSpec],
) -> Set[str]:
    """Batch-check data availability for several securities at once.

    Uses DataBroker's batched query helpers instead of per-code lookups.

    Args:
        broker: data broker used for the batched queries
        ts_codes: security codes to check
        trade_date: trade date (inclusive upper bound)
        specs: factor specifications (part of the signature; not consulted here)

    Returns:
        The subset of ``ts_codes`` whose data passed all availability checks.
    """
    if not ts_codes:
        return set()
    usable: Set[str] = set()
    # Stage 1: drop codes that do not have enough historical rows.
    candidates = broker.check_batch_data_sufficiency(ts_codes, trade_date)
    if not candidates:
        return usable
    # Stage 2: fetch the latest values of the mandatory fields in one batch.
    required_fields = ["daily.close", "daily_basic.turnover_rate"]
    latest_by_code = broker.fetch_batch_latest(candidates, trade_date, required_fields)
    for code in candidates:
        row = latest_by_code.get(code, {})
        # Reject on the first mandatory field that came back empty.
        missing = next((f for f in required_fields if row.get(f) is None), None)
        if missing is not None:
            LOGGER.debug(
                "批次化检查缺少字段 field=%s ts_code=%s date=%s",
                missing, code, trade_date,
                extra=LOG_EXTRA
            )
            continue
        # The close price must be present and strictly positive.
        close_px = row.get("daily.close")
        if close_px is None or float(close_px) <= 0:
            LOGGER.debug(
                "批次化检查收盘价无效 ts_code=%s date=%s price=%s",
                code, trade_date, close_px,
                extra=LOG_EXTRA
            )
            continue
        usable.add(code)
    LOGGER.debug(
        "批次化数据可用性检查完成 总证券数=%s 可用证券数=%s",
        len(ts_codes), len(usable),
        extra=LOG_EXTRA
    )
    return usable
def _detect_and_handle_outliers(
values: Dict[str, float | None],
ts_code: str,

View File

@ -108,21 +108,18 @@ class SentimentFactors:
# 3. 计算市场情绪指数
# 获取成交量数据
volume_data = broker.get_stock_data(
volume_data = broker.fetch_latest(
ts_code,
trade_date,
fields=["daily_basic.volume_ratio"],
limit=self.factor_specs["sent_market"]
fields=["daily_basic.volume_ratio"]
)
if volume_data:
volume_ratios = [
row.get("daily_basic.volume_ratio", 1.0)
for row in volume_data
]
if volume_data and "daily_basic.volume_ratio" in volume_data:
volume_ratio = volume_data["daily_basic.volume_ratio"]
# 使用单个成交量比率值
results["sent_market"] = market_sentiment_index(
sentiment_series,
heat_series,
volume_ratios,
[volume_ratio], # 转换为列表
window=self.factor_specs["sent_market"]
)
else:

View File

@ -10,18 +10,26 @@ LOG_EXTRA = {"stage": "factor_validation"}
# Per-prefix clamp ranges for factor values: a factor name is matched by
# prefix and its value is validated/clipped to the (low, high) pair.
# NOTE: a previous edit left duplicate keys in this literal (the older,
# tighter limits were dead entries since later duplicates win); only the
# effective entries are kept here.
FACTOR_LIMITS = {
    # Momentum factors: limited to ±100%
    "mom_": (-1.0, 1.0),
    # Volatility factors: limited to 0-50%
    "volat_": (0, 0.5),
    # Turnover factors: limited to 0-500% (real turnover can exceed 100%)
    "turn_": (0, 5.0),
    # Valuation-score factors: limited to -1..1
    "val_": (-1.0, 1.0),
    # Volume/price factors
    "volume_": (0, 10.0),
    # Market-regime factors
    "market_": (-1.0, 1.0),
    # Technical-indicator factors
    "tech_": (-1.0, 1.0),
    # Trend factors
    "trend_": (-1.0, 1.0),
    # Microstructure factors
    "micro_": (-1.0, 1.0),
    # Sentiment factors
    "sent_": (-1.0, 1.0),
}
def validate_factor_value(

View File

@ -10,6 +10,8 @@ from dataclasses import dataclass, field
from datetime import date, datetime, timedelta
from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Sequence, Set, Tuple
import numpy as np
from .config import get_config
import types
from .db import db_session
@ -350,18 +352,13 @@ class DataBroker:
if cached is not None:
return [tuple(item) for item in cached]
query = (
f"SELECT trade_date, {resolved} FROM {table} "
"WHERE ts_code = ? AND trade_date <= ? "
"ORDER BY trade_date DESC LIMIT ?"
)
try:
rows = self._query_engine.fetch_series(table, resolved, ts_code, end_date, window)
except Exception as exc: # noqa: BLE001
LOGGER.debug(
"时间序列查询失败 table=%s column=%s err=%s",
table,
column,
resolved,
exc,
extra=LOG_EXTRA,
)
@ -396,6 +393,148 @@ class DataBroker:
)
return series
def fetch_batch_latest(
    self,
    ts_codes: List[str],
    trade_date: str,
    fields: Iterable[str],
    auto_refresh: bool = True,
) -> Dict[str, Dict[str, Any]]:
    """Fetch the latest field values for many securities in one batch.

    Args:
        ts_codes: security codes to query
        trade_date: trade date (inclusive upper bound)
        fields: dotted ``table.column`` field names to fetch
        auto_refresh: when True, trigger an automatic backfill before querying

    Returns:
        Mapping of ts_code -> {"table.column": value} taken from the most
        recent row at or before *trade_date*; numeric values are coerced to
        float, anything else is stored as-is.
    """
    if not ts_codes:
        return {}
    field_list = [str(item) for item in fields if item]
    if not field_list:
        return {}
    # Optionally backfill missing data before querying.
    if auto_refresh:
        self._refresh.ensure_for_latest(trade_date, field_list)
    # Group requested columns by table so each table is queried only once.
    field_groups = {}
    for field_name in field_list:
        resolved = self.resolve_field(field_name)
        if not resolved:
            continue
        table, column = resolved
        field_groups.setdefault(table, set()).add(column)
    batch_data = {}
    # One windowed query per table: pick the newest row per ts_code.
    for table, columns in field_groups.items():
        if not ts_codes:  # NOTE(review): redundant — ts_codes was already checked above
            continue
        # Build the batched SQL; ROW_NUMBER() ranks rows per code by date desc.
        placeholders = ','.join(['?'] * len(ts_codes))
        columns_str = ', '.join(['ts_code', 'trade_date'] + list(columns))
        query = f"""
            SELECT {columns_str}
            FROM (
                SELECT {columns_str},
                       ROW_NUMBER() OVER (PARTITION BY ts_code ORDER BY trade_date DESC) as rn
                FROM {table}
                WHERE ts_code IN ({placeholders}) AND trade_date <= ?
            ) WHERE rn = 1
        """
        try:
            with db_session(read_only=True) as conn:
                rows = conn.execute(query, (*ts_codes, trade_date)).fetchall()
                for row in rows:
                    ts_code = row['ts_code']
                    batch_data.setdefault(ts_code, {})
                    for column in columns:
                        field_name = f"{table}.{column}"
                        try:
                            # Prefer float; keep the raw value if conversion fails.
                            batch_data[ts_code][field_name] = float(row[column])
                        except (TypeError, ValueError):
                            batch_data[ts_code][field_name] = row[column]
        except Exception as e:
            LOGGER.warning(
                "批次化字段查询失败 table=%s err=%s",
                table, str(e),
                extra=LOG_EXTRA
            )
            # Fall back to one fetch_latest call per code when the batched SQL fails.
            for ts_code in ts_codes:
                try:
                    latest_fields = self.fetch_latest(ts_code, trade_date, [f"{table}.{col}" for col in columns])
                    batch_data.setdefault(ts_code, {}).update(latest_fields)
                except Exception as inner_e:
                    LOGGER.debug(
                        "单条字段查询失败 ts_code=%s table=%s err=%s",
                        ts_code, table, str(inner_e),
                        extra=LOG_EXTRA
                    )
    return batch_data
def check_batch_data_sufficiency(
    self,
    ts_codes: List[str],
    trade_date: str,
    min_data_count: int = 60,
) -> Set[str]:
    """Batch-check that each security has enough daily history rows.

    Args:
        ts_codes: security codes to check
        trade_date: trade date (inclusive upper bound)
        min_data_count: minimum number of ``daily`` rows required

    Returns:
        The set of codes with at least *min_data_count* rows up to the date.
    """
    if not ts_codes:
        return set()
    sufficient_codes = set()
    # Single grouped query with IN(...) instead of one COUNT per code.
    placeholders = ','.join(['?'] * len(ts_codes))
    query = f"""
        SELECT ts_code, COUNT(*) as data_count
        FROM daily
        WHERE ts_code IN ({placeholders}) AND trade_date <= ?
        GROUP BY ts_code
        HAVING COUNT(*) >= ?
    """
    try:
        with db_session(read_only=True) as conn:
            rows = conn.execute(query, (*ts_codes, trade_date, min_data_count)).fetchall()
            for row in rows:
                ts_code = row['ts_code']
                sufficient_codes.add(ts_code)
    except Exception as e:
        LOGGER.warning(
            "批次化数据充分性检查失败 err=%s",
            str(e),
            extra=LOG_EXTRA
        )
        # Fallback: per-code check.
        # NOTE(review): calls module-level check_data_sufficiency — confirm it
        # exists and that it applies the same min_data_count threshold.
        for ts_code in ts_codes:
            if check_data_sufficiency(ts_code, trade_date):
                sufficient_codes.add(ts_code)
    return sufficient_codes
def register_refresh_callback(
self,
start: date | str,
@ -422,6 +561,85 @@ class DataBroker:
if callback not in bucket:
bucket.append(callback)
def get_news_data(
    self,
    ts_code: str,
    trade_date: str,
    limit: int = 30
) -> List[Dict[str, Any]]:
    """Return news records for a security (simplified placeholder).

    Args:
        ts_code: security code
        trade_date: trade date
        limit: maximum number of records to return

    Returns:
        A list of dicts with ``sentiment``, ``heat`` and ``entities`` keys.
        This is mock data; a real implementation would query a news store.
    """
    record_count = min(limit, 5)
    records: List[Dict[str, Any]] = []
    for _ in range(record_count):
        records.append(
            {
                "sentiment": np.random.uniform(-1, 1),
                "heat": np.random.uniform(0, 1),
                "entities": "股票,市场,投资",
            }
        )
    return records
def _lookup_industry(self, ts_code: str) -> Optional[str]:
"""查找股票所属行业
Args:
ts_code: 股票代码
Returns:
行业代码或名称找不到时返回None
"""
# 简化实现:返回模拟行业
# 在实际应用中,这里应该查询股票行业信息
industry_mapping = {
"000001.SZ": "银行",
"000002.SZ": "房地产",
"000858.SZ": "食品饮料",
"000962.SZ": "医药生物",
}
return industry_mapping.get(ts_code, "其他")
def _derived_industry_sentiment(self, industry: str, trade_date: str) -> Optional[float]:
"""计算行业情绪得分
Args:
industry: 行业代码或名称
trade_date: 交易日期
Returns:
行业情绪得分找不到时返回None
"""
# 简化实现:返回模拟情绪得分
# 在实际应用中,这里应该基于行业新闻计算情绪
return np.random.uniform(-1, 1)
def get_industry_stocks(self, industry: str) -> List[str]:
    """List the securities belonging to an industry (simplified placeholder).

    Args:
        industry: industry code or name

    Returns:
        Codes of securities in that industry; empty list when unknown.
        Mock data — a real implementation would query industry membership.
    """
    # Static sample universe used by the mock implementation.
    universe: Dict[str, List[str]] = {
        "银行": ["000001.SZ", "002142.SZ", "600036.SH"],
        "房地产": ["000002.SZ", "000402.SZ", "600048.SH"],
        "食品饮料": ["000858.SZ", "600519.SH", "000568.SZ"],
        "医药生物": ["000962.SZ", "600276.SH", "300003.SZ"],
    }
    members = universe.get(industry)
    return members if members is not None else []
def fetch_flags(
self,
table: str,