fix ADX calculation and trade imbalance factor implementation

This commit is contained in:
sam 2025-10-09 10:09:52 +08:00
parent 0bd015b525
commit 77e5b93297
3 changed files with 112 additions and 97 deletions

View File

@ -88,11 +88,11 @@ EXTENDED_FACTORS: List[FactorSpec] = [
# 趋势跟踪因子
FactorSpec("trend_ma_cross", 20), # 均线交叉信号
FactorSpec("trend_price_channel", 20), # 价格通道突破
FactorSpec("trend_adx", 14), # 平均趋向指标
FactorSpec("trend_adx", 14), # 简化版平均趋向指标(近似)
# 市场微观结构因子
FactorSpec("micro_tick_direction", 5), # 逐笔方向
FactorSpec("micro_trade_imbalance", 10), # 交易失衡
FactorSpec("micro_trade_imbalance", 10), # 交易不平衡度(基于签名成交量)
# 波动率预测因子
FactorSpec("vol_garch", 20), # GARCH波动率
@ -199,9 +199,11 @@ class ExtendedFactors:
return rsi(close_series, 14)
elif factor_name == "tech_macd_signal":
# MACD柱状图值histogram
return macd(close_series, 12, 26, 9)
elif factor_name == "tech_bb_position":
# 价格在布林带中的位置(-1到1)
return bollinger_bands(close_series, 20)
elif factor_name == "tech_obv_momentum":
@ -212,9 +214,13 @@ class ExtendedFactors:
# 趋势跟踪因子
elif factor_name == "trend_ma_cross":
# 修复:返回均线交叉的比例而不是差值
ma_5 = rolling_mean(close_series, 5)
ma_20 = rolling_mean(close_series, 20)
return ma_5 - ma_20
if ma_20 is not None and ma_20 != 0:
return (ma_5 - ma_20) / ma_20 if ma_5 is not None else None
else:
return 0.0 if ma_5 is not None else None
elif factor_name == "trend_price_channel":
# 价格通道突破因子:当前价格相对于通道的位置
@ -226,39 +232,36 @@ class ExtendedFactors:
return 0.0
elif factor_name == "trend_adx":
# 标准ADX计算实现
# 修复标准ADX计算实现需要high/low序列
# 注意当前仅使用close序列作为high/low的近似这是一个简化实现
window = 14
if len(close_series) < window + 1:
return None
# 计算+DI和-DI
plus_di = 0
minus_di = 0
tr_sum = 0
# 计算TR、+DM、-DM序列
tr_series = []
plus_dm_series = []
minus_dm_series = []
# 计算初始TR、+DM、-DM
for i in range(window):
if i + 1 >= len(close_series):
break
# 计算真实波幅(TR)
for i in range(len(close_series) - 1):
# 使用close作为high/low的近似简化实现
today_high = close_series[i]
today_low = close_series[i]
prev_high = close_series[i + 1]
prev_low = close_series[i + 1]
prev_close = close_series[i + 1]
# 计算真实波幅(TR)
tr = max(
abs(today_high - today_low),
abs(today_high - prev_close),
abs(today_low - prev_close)
)
tr_sum += tr
tr_series.append(tr)
# 计算方向运动
prev_high = close_series[i + 1] if i + 1 < len(close_series) else close_series[i]
prev_low = close_series[i + 1] if i + 1 < len(close_series) else close_series[i]
plus_dm = max(0, close_series[i] - prev_high)
minus_dm = max(0, prev_low - close_series[i])
plus_dm = max(today_high - prev_high, 0)
minus_dm = max(prev_low - today_low, 0)
# 确保只有一项为正值
if plus_dm > minus_dm:
@ -267,24 +270,68 @@ class ExtendedFactors:
plus_dm = 0
else:
plus_dm = minus_dm = 0
plus_di += plus_dm
minus_di += minus_dm
plus_dm_series.append(plus_dm)
minus_dm_series.append(minus_dm)
if len(tr_series) < window:
return None
# 计算+DI和-DI使用Wilder平滑方法
plus_di_series = []
minus_di_series = []
# 初始化第一个值
tr_sum = sum(tr_series[:window])
plus_dm_sum = sum(plus_dm_series[:window])
minus_dm_sum = sum(minus_dm_series[:window])
# 计算+DI和-DI
if tr_sum > 0:
plus_di = (plus_di / tr_sum) * 100
minus_di = (minus_di / tr_sum) * 100
plus_di = 100 * (plus_dm_sum / tr_sum)
minus_di = 100 * (minus_dm_sum / tr_sum)
else:
plus_di = minus_di = 0
plus_di_series.append(plus_di)
minus_di_series.append(minus_di)
# 计算DX
dx = 0
if plus_di + minus_di > 0:
dx = (abs(plus_di - minus_di) / (plus_di + minus_di)) * 100
# 计算后续值使用Wilder平滑
for i in range(1, len(tr_series) - window + 1):
# Wilder平滑当前值 = (前一个平滑值 * (n-1) + 当前值) / n
tr_sum = (tr_sum * (window - 1) + tr_series[i + window - 1]) / window
plus_dm_sum = (plus_dm_sum * (window - 1) + plus_dm_series[i + window - 1]) / window
minus_dm_sum = (minus_dm_sum * (window - 1) + minus_dm_series[i + window - 1]) / window
if tr_sum > 0:
plus_di = 100 * (plus_dm_sum / tr_sum)
minus_di = 100 * (minus_dm_sum / tr_sum)
else:
plus_di = minus_di = 0
plus_di_series.append(plus_di)
minus_di_series.append(minus_di)
# ADX是DX的移动平均这里简化为直接返回DX值确保在0-100范围内
return max(0, min(100, dx))
# 计算DX序列
dx_series = []
for i in range(len(plus_di_series)):
plus_di_val = plus_di_series[i]
minus_di_val = minus_di_series[i]
if plus_di_val + minus_di_val > 0:
dx = 100 * (abs(plus_di_val - minus_di_val) / (plus_di_val + minus_di_val))
else:
dx = 0
dx_series.append(dx)
# 计算ADXDX的移动平均
if len(dx_series) < window:
return None
# ADX是DX的移动平均使用简单移动平均
adx = sum(dx_series[:window]) / window
# 确保在0-100范围内
return max(0, min(100, adx))
# 市场微观结构因子
elif factor_name == "micro_tick_direction":
@ -298,18 +345,36 @@ class ExtendedFactors:
return sum(directions) / window
elif factor_name == "micro_trade_imbalance":
# 交易失衡:基于价格和成交量的联合分析
# 修复:交易失衡(按常规定义实现)
# 基于签名成交量的归一化差值
window = 10
if len(close_series) < window + 1 or len(volume_series) < window + 1:
return None
# 计算价格变动和成交量变动
price_changes = [close_series[i] - close_series[i+1] for i in range(window)]
volume_changes = [volume_series[i] - volume_series[i+1] for i in range(window)]
# 计算签名成交量volume_t * sign(close_t - close_{t-1})
signed_volumes = []
total_volume = 0
# 计算交易失衡指标
imbalance = sum(price_changes[i] * volume_changes[i] for i in range(window))
return imbalance / (window * np.mean(volume_series[:window]) + 1e-8)
for i in range(window):
if i + 1 < len(close_series):
# 计算价格变动符号
price_change = close_series[i] - close_series[i+1]
sign = 1 if price_change > 0 else (-1 if price_change < 0 else 0)
# 计算签名成交量
signed_volume = volume_series[i] * sign
signed_volumes.append(signed_volume)
total_volume += volume_series[i]
if total_volume == 0:
return 0.0
# 计算交易不平衡度signed_vol / total_volume
signed_vol_sum = sum(signed_volumes)
imbalance = signed_vol_sum / (total_volume + 1e-8)
# 确保结果在[-1, 1]范围内
return max(-1.0, min(1.0, imbalance))
# 波动率预测因子
elif factor_name == "vol_garch":

View File

@ -77,6 +77,11 @@ def validate_factor_value(
如果因子值有效则返回原值否则返回 None
"""
if value is None:
LOGGER.warning(
"因子值非数值 factor=%s value=%s ts_code=%s date=%s",
name, value, ts_code, trade_date,
extra=LOG_EXTRA
)
return None
# 检查是否为有限数值
@ -98,8 +103,8 @@ def validate_factor_value(
"tech_pv_trend": (-1.0, 1.0), # 量价趋势相关性
# 趋势指标精确范围
"trend_adx": (0, 100.0), # ADX趋势强度 0-100
"trend_ma_cross": (-1.0, 1.0), # 均线交叉
"trend_adx": (0, 100.0), # 简化版 ADX(近似)趋势强度 0-100
"trend_ma_cross": (-1.0, 1.0), # 均线交叉比例
"trend_price_channel": (-1.0, 1.0), # 价格通道位置
# 波动率指标精确范围
@ -109,7 +114,7 @@ def validate_factor_value(
# 微观结构精确范围
"micro_tick_direction": (-1.0, 1.0), # 买卖方向比例
"micro_trade_imbalance": (-100.0, 100.0), # 交易不平衡度
"micro_trade_imbalance": (-1.0, 1.0), # 交易不平衡度,范围在[-1, 1]之间
# 情绪指标精确范围
"sent_impact": (0, 1.0), # 情绪影响度

View File

@ -1,55 +0,0 @@
#!/usr/bin/env python3
"""测试因子值范围验证功能"""
from app.features.factors import compute_factors
from datetime import date
def test_factor_ranges():
"""测试因子值范围验证功能"""
print('测试改进后的因子值范围验证功能...')
try:
results = compute_factors(
date(2024, 1, 15),
ts_codes=['000001.SZ', '000002.SZ'],
skip_existing=False,
batch_size=10
)
print(f'因子计算完成,共计算 {len(results)} 个结果')
# 检查每个因子的值范围
valid_count = 0
invalid_count = 0
for result in results:
print(f'\n证券 {result.ts_code} 的因子值:')
for factor_name, value in result.values.items():
if value is not None:
# 检查值是否在合理范围内
if -10 <= value <= 10: # 放宽检查范围,主要看验证逻辑
print(f'{factor_name}: {value:.6f}')
valid_count += 1
else:
print(f'{factor_name}: {value:.6f} (超出范围!)')
invalid_count += 1
print(f'\n验证统计:')
print(f' 有效因子值: {valid_count}')
print(f' 无效因子值: {invalid_count}')
print(f' 总因子值: {valid_count + invalid_count}')
if invalid_count == 0:
print('\n✅ 所有因子值都在合理范围内,验证通过!')
else:
print(f'\n⚠️ 发现 {invalid_count} 个超出范围的因子值,需要进一步优化')
print('\n✅ 因子值范围验证测试完成')
except Exception as e:
print(f'❌ 测试失败: {e}')
import traceback
traceback.print_exc()
if __name__ == "__main__":
test_factor_ranges()