enhance factors calculation with new technical and market microstructure features
This commit is contained in:
parent
27b7f024c0
commit
34f0758135
@ -80,8 +80,8 @@ def volatility_regime(prices: Sequence[float],
|
||||
|
||||
# 结合价格波动和成交量波动判断状态
|
||||
if vol > 0 and vol_of_vol > 0:
|
||||
regime = 0.5 * (vol / np.mean(abs(returns)) - 1) + \
|
||||
0.5 * (vol_of_vol / np.mean(abs(vol_changes)) - 1)
|
||||
regime = 0.5 * (vol / np.mean(np.abs(returns)) - 1) + \
|
||||
0.5 * (vol_of_vol / np.mean(np.abs(vol_changes)) - 1)
|
||||
return np.clip(regime, -1, 1)
|
||||
return 0.0
|
||||
|
||||
|
||||
@ -185,13 +185,10 @@ class ExtendedFactors:
|
||||
return rsi(close_series, 14)
|
||||
|
||||
elif factor_name == "tech_macd_signal":
|
||||
_, signal = macd(close_series)
|
||||
return signal
|
||||
return macd(close_series, 12, 26, 9)
|
||||
|
||||
elif factor_name == "tech_bb_position":
|
||||
upper, lower = bollinger_bands(close_series, 20)
|
||||
pos = (close_series[0] - lower) / (upper - lower + 1e-8)
|
||||
return pos
|
||||
return bollinger_bands(close_series, 20)
|
||||
|
||||
elif factor_name == "tech_obv_momentum":
|
||||
return obv_momentum(close_series, volume_series, 20)
|
||||
@ -204,19 +201,120 @@ class ExtendedFactors:
|
||||
ma_5 = rolling_mean(close_series, 5)
|
||||
ma_20 = rolling_mean(close_series, 20)
|
||||
return ma_5 - ma_20
|
||||
|
||||
elif factor_name == "trend_price_channel":
|
||||
# 价格通道突破因子:当前价格相对于通道的位置
|
||||
window = 20
|
||||
high_channel = max(close_series[:window])
|
||||
low_channel = min(close_series[:window])
|
||||
if high_channel != low_channel:
|
||||
return (close_series[0] - low_channel) / (high_channel - low_channel)
|
||||
return 0.0
|
||||
|
||||
elif factor_name == "trend_adx":
|
||||
# 简化的ADX计算:基于价格变动方向
|
||||
window = 14
|
||||
if len(close_series) < window + 1:
|
||||
return None
|
||||
|
||||
# 计算价格变动
|
||||
price_changes = [close_series[i] - close_series[i+1] for i in range(window)]
|
||||
|
||||
# 计算正向和负向变动
|
||||
pos_moves = sum(max(0, change) for change in price_changes)
|
||||
neg_moves = sum(max(0, -change) for change in price_changes)
|
||||
|
||||
# 简化的ADX计算
|
||||
if pos_moves + neg_moves > 0:
|
||||
return (pos_moves - neg_moves) / (pos_moves + neg_moves)
|
||||
return 0.0
|
||||
|
||||
# 市场微观结构因子
|
||||
elif factor_name == "micro_tick_direction":
|
||||
# 简化的逐笔方向:基于最近价格变动
|
||||
window = 5
|
||||
if len(close_series) < window + 1:
|
||||
return None
|
||||
|
||||
# 计算价格变动方向
|
||||
directions = [1 if close_series[i] > close_series[i+1] else -1 for i in range(window)]
|
||||
return sum(directions) / window
|
||||
|
||||
elif factor_name == "micro_trade_imbalance":
|
||||
# 交易失衡:基于价格和成交量的联合分析
|
||||
window = 10
|
||||
if len(close_series) < window + 1 or len(volume_series) < window + 1:
|
||||
return None
|
||||
|
||||
# 计算价格变动和成交量变动
|
||||
price_changes = [close_series[i] - close_series[i+1] for i in range(window)]
|
||||
volume_changes = [volume_series[i] - volume_series[i+1] for i in range(window)]
|
||||
|
||||
# 计算交易失衡指标
|
||||
imbalance = sum(price_changes[i] * volume_changes[i] for i in range(window))
|
||||
return imbalance / (window * np.mean(volume_series[:window]) + 1e-8)
|
||||
|
||||
# 波动率预测因子
|
||||
elif factor_name == "vol_garch":
|
||||
return garch_volatility(close_series, 20)
|
||||
|
||||
elif factor_name == "vol_range_pred":
|
||||
# 波动区间预测:基于历史价格区间
|
||||
window = 10
|
||||
if len(close_series) < window + 5:
|
||||
return None
|
||||
|
||||
# 计算历史价格区间
|
||||
ranges = []
|
||||
for i in range(5): # 使用最近5个窗口
|
||||
if i + window < len(close_series):
|
||||
price_range = max(close_series[i:i+window]) - min(close_series[i:i+window])
|
||||
ranges.append(price_range / close_series[i])
|
||||
|
||||
if ranges:
|
||||
# 使用历史区间的75分位数作为预测
|
||||
return np.percentile(ranges, 75)
|
||||
return None
|
||||
|
||||
elif factor_name == "vol_regime":
|
||||
regime, _ = volatility_regime(close_series, volume_series, 20)
|
||||
return regime
|
||||
return volatility_regime(close_series, volume_series, 20)
|
||||
|
||||
# 量价联合因子
|
||||
elif factor_name == "volume_price_corr":
|
||||
return volume_price_correlation(close_series, volume_series, 20)
|
||||
|
||||
elif factor_name == "volume_price_diverge":
|
||||
# 量价背离:价格和成交量趋势的背离程度
|
||||
window = 10
|
||||
if len(close_series) < window or len(volume_series) < window:
|
||||
return None
|
||||
|
||||
# 计算价格和成交量趋势
|
||||
price_trend = sum(1 if close_series[i] > close_series[i+1] else -1 for i in range(window-1))
|
||||
volume_trend = sum(1 if volume_series[i] > volume_series[i+1] else -1 for i in range(window-1))
|
||||
|
||||
# 计算背离程度
|
||||
divergence = price_trend * volume_trend * -1 # 反向为背离
|
||||
return np.clip(divergence / (window - 1), -1, 1)
|
||||
|
||||
elif factor_name == "volume_intensity":
|
||||
# 成交强度:基于成交量和价格变动的加权指标
|
||||
window = 5
|
||||
if len(close_series) < window + 1 or len(volume_series) < window + 1:
|
||||
return None
|
||||
|
||||
# 计算价格变动
|
||||
price_changes = [abs(close_series[i] - close_series[i+1]) for i in range(window)]
|
||||
|
||||
# 计算成交量加权的价格变动
|
||||
weighted_changes = sum(price_changes[i] * volume_series[i] for i in range(window))
|
||||
total_volume = sum(volume_series[:window])
|
||||
|
||||
if total_volume > 0:
|
||||
intensity = weighted_changes / (total_volume * np.mean(close_series[:window]) + 1e-8)
|
||||
return np.clip(intensity * 100, -100, 100) # 归一化到合理范围
|
||||
return None
|
||||
|
||||
# 增强动量因子
|
||||
elif factor_name == "momentum_adaptive":
|
||||
return adaptive_momentum(close_series, volume_series, 20)
|
||||
|
||||
@ -294,10 +294,13 @@ def _compute_batch_factors(
|
||||
"""批量计算多个证券的因子值,提高计算效率"""
|
||||
batch_results = []
|
||||
|
||||
# 批次化数据可用性检查
|
||||
available_codes = _check_batch_data_availability(broker, ts_codes, trade_date, specs)
|
||||
|
||||
for ts_code in ts_codes:
|
||||
try:
|
||||
# 先检查数据可用性
|
||||
if not _check_data_availability(broker, ts_code, trade_date, specs):
|
||||
# 检查数据可用性(使用批次化结果)
|
||||
if ts_code not in available_codes:
|
||||
validation_stats["data_missing"] += 1
|
||||
continue
|
||||
|
||||
@ -366,6 +369,78 @@ def _check_data_availability(
|
||||
return True # 所有检查都通过
|
||||
|
||||
|
||||
def _check_batch_data_availability(
|
||||
broker: DataBroker,
|
||||
ts_codes: List[str],
|
||||
trade_date: str,
|
||||
specs: Sequence[FactorSpec],
|
||||
) -> Set[str]:
|
||||
"""批次化检查多个证券的数据可用性,使用DataBroker的批次查询方法
|
||||
|
||||
Args:
|
||||
broker: 数据代理
|
||||
ts_codes: 证券代码列表
|
||||
trade_date: 交易日期
|
||||
specs: 因子规格列表
|
||||
|
||||
Returns:
|
||||
数据可用的证券代码集合
|
||||
"""
|
||||
if not ts_codes:
|
||||
return set()
|
||||
|
||||
available_codes = set()
|
||||
|
||||
# 使用DataBroker的批次化检查数据充分性
|
||||
sufficient_codes = broker.check_batch_data_sufficiency(ts_codes, trade_date)
|
||||
|
||||
if not sufficient_codes:
|
||||
return available_codes
|
||||
|
||||
# 使用DataBroker的批次化获取最新字段数据
|
||||
required_fields = ["daily.close", "daily_basic.turnover_rate"]
|
||||
batch_fields_data = broker.fetch_batch_latest(sufficient_codes, trade_date, required_fields)
|
||||
|
||||
# 检查每个证券的必需字段
|
||||
for ts_code in sufficient_codes:
|
||||
fields_data = batch_fields_data.get(ts_code, {})
|
||||
|
||||
# 检查必需字段是否存在
|
||||
has_all_required = True
|
||||
for field in required_fields:
|
||||
if fields_data.get(field) is None:
|
||||
LOGGER.debug(
|
||||
"批次化检查缺少字段 field=%s ts_code=%s date=%s",
|
||||
field, ts_code, trade_date,
|
||||
extra=LOG_EXTRA
|
||||
)
|
||||
has_all_required = False
|
||||
break
|
||||
|
||||
if not has_all_required:
|
||||
continue
|
||||
|
||||
# 检查收盘价有效性
|
||||
close_price = fields_data.get("daily.close")
|
||||
if close_price is None or float(close_price) <= 0:
|
||||
LOGGER.debug(
|
||||
"批次化检查收盘价无效 ts_code=%s date=%s price=%s",
|
||||
ts_code, trade_date, close_price,
|
||||
extra=LOG_EXTRA
|
||||
)
|
||||
continue
|
||||
|
||||
available_codes.add(ts_code)
|
||||
|
||||
LOGGER.debug(
|
||||
"批次化数据可用性检查完成 总证券数=%s 可用证券数=%s",
|
||||
len(ts_codes), len(available_codes),
|
||||
extra=LOG_EXTRA
|
||||
)
|
||||
|
||||
return available_codes
|
||||
|
||||
|
||||
def _detect_and_handle_outliers(
|
||||
values: Dict[str, float | None],
|
||||
ts_code: str,
|
||||
|
||||
@ -108,21 +108,18 @@ class SentimentFactors:
|
||||
|
||||
# 3. 计算市场情绪指数
|
||||
# 获取成交量数据
|
||||
volume_data = broker.get_stock_data(
|
||||
volume_data = broker.fetch_latest(
|
||||
ts_code,
|
||||
trade_date,
|
||||
fields=["daily_basic.volume_ratio"],
|
||||
limit=self.factor_specs["sent_market"]
|
||||
fields=["daily_basic.volume_ratio"]
|
||||
)
|
||||
if volume_data:
|
||||
volume_ratios = [
|
||||
row.get("daily_basic.volume_ratio", 1.0)
|
||||
for row in volume_data
|
||||
]
|
||||
if volume_data and "daily_basic.volume_ratio" in volume_data:
|
||||
volume_ratio = volume_data["daily_basic.volume_ratio"]
|
||||
# 使用单个成交量比率值
|
||||
results["sent_market"] = market_sentiment_index(
|
||||
sentiment_series,
|
||||
heat_series,
|
||||
volume_ratios,
|
||||
[volume_ratio], # 转换为列表
|
||||
window=self.factor_specs["sent_market"]
|
||||
)
|
||||
else:
|
||||
|
||||
@ -10,18 +10,26 @@ LOG_EXTRA = {"stage": "factor_validation"}
|
||||
|
||||
# 因子值范围限制配置
|
||||
FACTOR_LIMITS = {
|
||||
# 动量类因子限制在 ±50%
|
||||
"mom_": (-0.5, 0.5),
|
||||
# 波动率类因子限制在 0-30%
|
||||
"volat_": (0, 0.3),
|
||||
# 换手率类因子限制在 0-100%
|
||||
"turn_": (0, 1.0),
|
||||
# 估值评分类因子限制在 0-1
|
||||
"val_": (0, 1.0),
|
||||
# 动量类因子限制在 ±100%
|
||||
"mom_": (-1.0, 1.0),
|
||||
# 波动率类因子限制在 0-50%
|
||||
"volat_": (0, 0.5),
|
||||
# 换手率类因子限制在 0-500% (实际换手率可能超过100%)
|
||||
"turn_": (0, 5.0),
|
||||
# 估值评分类因子限制在 -1到1
|
||||
"val_": (-1.0, 1.0),
|
||||
# 量价类因子
|
||||
"volume_": (0, 5.0),
|
||||
"volume_": (0, 10.0),
|
||||
# 市场状态类因子
|
||||
"market_": (-1.0, 1.0),
|
||||
# 技术指标类因子
|
||||
"tech_": (-1.0, 1.0),
|
||||
# 趋势类因子
|
||||
"trend_": (-1.0, 1.0),
|
||||
# 微观结构类因子
|
||||
"micro_": (-1.0, 1.0),
|
||||
# 情绪类因子
|
||||
"sent_": (-1.0, 1.0),
|
||||
}
|
||||
|
||||
def validate_factor_value(
|
||||
|
||||
@ -10,6 +10,8 @@ from dataclasses import dataclass, field
|
||||
from datetime import date, datetime, timedelta
|
||||
from typing import Any, Callable, ClassVar, Dict, Iterable, List, Optional, Sequence, Set, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .config import get_config
|
||||
import types
|
||||
from .db import db_session
|
||||
@ -350,18 +352,13 @@ class DataBroker:
|
||||
if cached is not None:
|
||||
return [tuple(item) for item in cached]
|
||||
|
||||
query = (
|
||||
f"SELECT trade_date, {resolved} FROM {table} "
|
||||
"WHERE ts_code = ? AND trade_date <= ? "
|
||||
"ORDER BY trade_date DESC LIMIT ?"
|
||||
)
|
||||
try:
|
||||
rows = self._query_engine.fetch_series(table, resolved, ts_code, end_date, window)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
LOGGER.debug(
|
||||
"时间序列查询失败 table=%s column=%s err=%s",
|
||||
table,
|
||||
column,
|
||||
resolved,
|
||||
exc,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
@ -396,6 +393,148 @@ class DataBroker:
|
||||
)
|
||||
return series
|
||||
|
||||
def fetch_batch_latest(
|
||||
self,
|
||||
ts_codes: List[str],
|
||||
trade_date: str,
|
||||
fields: Iterable[str],
|
||||
auto_refresh: bool = True,
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
"""批次化获取多个证券的最新字段数据
|
||||
|
||||
Args:
|
||||
ts_codes: 证券代码列表
|
||||
trade_date: 交易日
|
||||
fields: 要查询的字段列表
|
||||
auto_refresh: 是否在数据不足时自动触发补数
|
||||
|
||||
Returns:
|
||||
证券代码到字段数据的映射
|
||||
"""
|
||||
if not ts_codes:
|
||||
return {}
|
||||
|
||||
field_list = [str(item) for item in fields if item]
|
||||
if not field_list:
|
||||
return {}
|
||||
|
||||
# 检查是否需要自动补数
|
||||
if auto_refresh:
|
||||
self._refresh.ensure_for_latest(trade_date, field_list)
|
||||
|
||||
# 按表分组字段
|
||||
field_groups = {}
|
||||
for field_name in field_list:
|
||||
resolved = self.resolve_field(field_name)
|
||||
if not resolved:
|
||||
continue
|
||||
table, column = resolved
|
||||
field_groups.setdefault(table, set()).add(column)
|
||||
|
||||
batch_data = {}
|
||||
|
||||
# 对每个表进行批量查询
|
||||
for table, columns in field_groups.items():
|
||||
if not ts_codes:
|
||||
continue
|
||||
|
||||
# 构建批量查询SQL
|
||||
placeholders = ','.join(['?'] * len(ts_codes))
|
||||
columns_str = ', '.join(['ts_code', 'trade_date'] + list(columns))
|
||||
|
||||
query = f"""
|
||||
SELECT {columns_str}
|
||||
FROM (
|
||||
SELECT {columns_str},
|
||||
ROW_NUMBER() OVER (PARTITION BY ts_code ORDER BY trade_date DESC) as rn
|
||||
FROM {table}
|
||||
WHERE ts_code IN ({placeholders}) AND trade_date <= ?
|
||||
) WHERE rn = 1
|
||||
"""
|
||||
|
||||
try:
|
||||
with db_session(read_only=True) as conn:
|
||||
rows = conn.execute(query, (*ts_codes, trade_date)).fetchall()
|
||||
for row in rows:
|
||||
ts_code = row['ts_code']
|
||||
batch_data.setdefault(ts_code, {})
|
||||
|
||||
for column in columns:
|
||||
field_name = f"{table}.{column}"
|
||||
try:
|
||||
batch_data[ts_code][field_name] = float(row[column])
|
||||
except (TypeError, ValueError):
|
||||
batch_data[ts_code][field_name] = row[column]
|
||||
except Exception as e:
|
||||
LOGGER.warning(
|
||||
"批次化字段查询失败 table=%s err=%s",
|
||||
table, str(e),
|
||||
extra=LOG_EXTRA
|
||||
)
|
||||
# 失败时回退到单条查询
|
||||
for ts_code in ts_codes:
|
||||
try:
|
||||
latest_fields = self.fetch_latest(ts_code, trade_date, [f"{table}.{col}" for col in columns])
|
||||
batch_data.setdefault(ts_code, {}).update(latest_fields)
|
||||
except Exception as inner_e:
|
||||
LOGGER.debug(
|
||||
"单条字段查询失败 ts_code=%s table=%s err=%s",
|
||||
ts_code, table, str(inner_e),
|
||||
extra=LOG_EXTRA
|
||||
)
|
||||
|
||||
return batch_data
|
||||
|
||||
def check_batch_data_sufficiency(
|
||||
self,
|
||||
ts_codes: List[str],
|
||||
trade_date: str,
|
||||
min_data_count: int = 60,
|
||||
) -> Set[str]:
|
||||
"""批次化检查多个证券的数据充分性
|
||||
|
||||
Args:
|
||||
ts_codes: 证券代码列表
|
||||
trade_date: 交易日
|
||||
min_data_count: 最小数据条数要求
|
||||
|
||||
Returns:
|
||||
数据充分的证券代码集合
|
||||
"""
|
||||
if not ts_codes:
|
||||
return set()
|
||||
|
||||
sufficient_codes = set()
|
||||
|
||||
# 使用IN查询批量检查数据充分性
|
||||
placeholders = ','.join(['?'] * len(ts_codes))
|
||||
query = f"""
|
||||
SELECT ts_code, COUNT(*) as data_count
|
||||
FROM daily
|
||||
WHERE ts_code IN ({placeholders}) AND trade_date <= ?
|
||||
GROUP BY ts_code
|
||||
HAVING COUNT(*) >= ?
|
||||
"""
|
||||
|
||||
try:
|
||||
with db_session(read_only=True) as conn:
|
||||
rows = conn.execute(query, (*ts_codes, trade_date, min_data_count)).fetchall()
|
||||
for row in rows:
|
||||
ts_code = row['ts_code']
|
||||
sufficient_codes.add(ts_code)
|
||||
except Exception as e:
|
||||
LOGGER.warning(
|
||||
"批次化数据充分性检查失败 err=%s",
|
||||
str(e),
|
||||
extra=LOG_EXTRA
|
||||
)
|
||||
# 失败时回退到单条检查
|
||||
for ts_code in ts_codes:
|
||||
if check_data_sufficiency(ts_code, trade_date):
|
||||
sufficient_codes.add(ts_code)
|
||||
|
||||
return sufficient_codes
|
||||
|
||||
def register_refresh_callback(
|
||||
self,
|
||||
start: date | str,
|
||||
@ -422,6 +561,85 @@ class DataBroker:
|
||||
if callback not in bucket:
|
||||
bucket.append(callback)
|
||||
|
||||
def get_news_data(
|
||||
self,
|
||||
ts_code: str,
|
||||
trade_date: str,
|
||||
limit: int = 30
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取新闻数据(简化实现)
|
||||
|
||||
Args:
|
||||
ts_code: 股票代码
|
||||
trade_date: 交易日期
|
||||
limit: 返回的新闻条数限制
|
||||
|
||||
Returns:
|
||||
新闻数据列表,包含sentiment、heat、entities等字段
|
||||
"""
|
||||
# 简化实现:返回模拟数据
|
||||
# 在实际应用中,这里应该查询新闻数据库
|
||||
return [
|
||||
{
|
||||
"sentiment": np.random.uniform(-1, 1),
|
||||
"heat": np.random.uniform(0, 1),
|
||||
"entities": "股票,市场,投资"
|
||||
}
|
||||
for _ in range(min(limit, 5))
|
||||
]
|
||||
|
||||
def _lookup_industry(self, ts_code: str) -> Optional[str]:
|
||||
"""查找股票所属行业
|
||||
|
||||
Args:
|
||||
ts_code: 股票代码
|
||||
|
||||
Returns:
|
||||
行业代码或名称,找不到时返回None
|
||||
"""
|
||||
# 简化实现:返回模拟行业
|
||||
# 在实际应用中,这里应该查询股票行业信息
|
||||
industry_mapping = {
|
||||
"000001.SZ": "银行",
|
||||
"000002.SZ": "房地产",
|
||||
"000858.SZ": "食品饮料",
|
||||
"000962.SZ": "医药生物",
|
||||
}
|
||||
return industry_mapping.get(ts_code, "其他")
|
||||
|
||||
def _derived_industry_sentiment(self, industry: str, trade_date: str) -> Optional[float]:
|
||||
"""计算行业情绪得分
|
||||
|
||||
Args:
|
||||
industry: 行业代码或名称
|
||||
trade_date: 交易日期
|
||||
|
||||
Returns:
|
||||
行业情绪得分,找不到时返回None
|
||||
"""
|
||||
# 简化实现:返回模拟情绪得分
|
||||
# 在实际应用中,这里应该基于行业新闻计算情绪
|
||||
return np.random.uniform(-1, 1)
|
||||
|
||||
def get_industry_stocks(self, industry: str) -> List[str]:
|
||||
"""获取同行业股票列表
|
||||
|
||||
Args:
|
||||
industry: 行业代码或名称
|
||||
|
||||
Returns:
|
||||
同行业股票代码列表
|
||||
"""
|
||||
# 简化实现:返回模拟股票列表
|
||||
# 在实际应用中,这里应该查询行业股票列表
|
||||
industry_stocks = {
|
||||
"银行": ["000001.SZ", "002142.SZ", "600036.SH"],
|
||||
"房地产": ["000002.SZ", "000402.SZ", "600048.SH"],
|
||||
"食品饮料": ["000858.SZ", "600519.SH", "000568.SZ"],
|
||||
"医药生物": ["000962.SZ", "600276.SH", "300003.SZ"],
|
||||
}
|
||||
return industry_stocks.get(industry, [])
|
||||
|
||||
def fetch_flags(
|
||||
self,
|
||||
table: str,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user