fix ADX calculation and trade imbalance factor implementation
This commit is contained in:
parent
0bd015b525
commit
77e5b93297
@ -88,11 +88,11 @@ EXTENDED_FACTORS: List[FactorSpec] = [
|
||||
# 趋势跟踪因子
|
||||
FactorSpec("trend_ma_cross", 20), # 均线交叉信号
|
||||
FactorSpec("trend_price_channel", 20), # 价格通道突破
|
||||
FactorSpec("trend_adx", 14), # 平均趋向指标
|
||||
FactorSpec("trend_adx", 14), # 简化版平均趋向指标(近似)
|
||||
|
||||
# 市场微观结构因子
|
||||
FactorSpec("micro_tick_direction", 5), # 逐笔方向
|
||||
FactorSpec("micro_trade_imbalance", 10), # 交易失衡
|
||||
FactorSpec("micro_trade_imbalance", 10), # 交易不平衡度(基于签名成交量)
|
||||
|
||||
# 波动率预测因子
|
||||
FactorSpec("vol_garch", 20), # GARCH波动率
|
||||
@ -199,9 +199,11 @@ class ExtendedFactors:
|
||||
return rsi(close_series, 14)
|
||||
|
||||
elif factor_name == "tech_macd_signal":
|
||||
# MACD柱状图值(histogram)
|
||||
return macd(close_series, 12, 26, 9)
|
||||
|
||||
elif factor_name == "tech_bb_position":
|
||||
# 价格在布林带中的位置(-1到1)
|
||||
return bollinger_bands(close_series, 20)
|
||||
|
||||
elif factor_name == "tech_obv_momentum":
|
||||
@ -212,9 +214,13 @@ class ExtendedFactors:
|
||||
|
||||
# 趋势跟踪因子
|
||||
elif factor_name == "trend_ma_cross":
|
||||
# 修复:返回均线交叉的比例而不是差值
|
||||
ma_5 = rolling_mean(close_series, 5)
|
||||
ma_20 = rolling_mean(close_series, 20)
|
||||
return ma_5 - ma_20
|
||||
if ma_20 is not None and ma_20 != 0:
|
||||
return (ma_5 - ma_20) / ma_20 if ma_5 is not None else None
|
||||
else:
|
||||
return 0.0 if ma_5 is not None else None
|
||||
|
||||
elif factor_name == "trend_price_channel":
|
||||
# 价格通道突破因子:当前价格相对于通道的位置
|
||||
@ -226,39 +232,36 @@ class ExtendedFactors:
|
||||
return 0.0
|
||||
|
||||
elif factor_name == "trend_adx":
|
||||
# 标准ADX计算实现
|
||||
# 修复:标准ADX计算实现(需要high/low序列)
|
||||
# 注意:当前仅使用close序列作为high/low的近似,这是一个简化实现
|
||||
window = 14
|
||||
if len(close_series) < window + 1:
|
||||
return None
|
||||
|
||||
# 计算+DI和-DI
|
||||
plus_di = 0
|
||||
minus_di = 0
|
||||
tr_sum = 0
|
||||
# 计算TR、+DM、-DM序列
|
||||
tr_series = []
|
||||
plus_dm_series = []
|
||||
minus_dm_series = []
|
||||
|
||||
# 计算初始TR、+DM、-DM
|
||||
for i in range(window):
|
||||
if i + 1 >= len(close_series):
|
||||
break
|
||||
|
||||
# 计算真实波幅(TR)
|
||||
for i in range(len(close_series) - 1):
|
||||
# 使用close作为high/low的近似(简化实现)
|
||||
today_high = close_series[i]
|
||||
today_low = close_series[i]
|
||||
prev_high = close_series[i + 1]
|
||||
prev_low = close_series[i + 1]
|
||||
prev_close = close_series[i + 1]
|
||||
|
||||
# 计算真实波幅(TR)
|
||||
tr = max(
|
||||
abs(today_high - today_low),
|
||||
abs(today_high - prev_close),
|
||||
abs(today_low - prev_close)
|
||||
)
|
||||
tr_sum += tr
|
||||
tr_series.append(tr)
|
||||
|
||||
# 计算方向运动
|
||||
prev_high = close_series[i + 1] if i + 1 < len(close_series) else close_series[i]
|
||||
prev_low = close_series[i + 1] if i + 1 < len(close_series) else close_series[i]
|
||||
|
||||
plus_dm = max(0, close_series[i] - prev_high)
|
||||
minus_dm = max(0, prev_low - close_series[i])
|
||||
plus_dm = max(today_high - prev_high, 0)
|
||||
minus_dm = max(prev_low - today_low, 0)
|
||||
|
||||
# 确保只有一项为正值
|
||||
if plus_dm > minus_dm:
|
||||
@ -268,23 +271,67 @@ class ExtendedFactors:
|
||||
else:
|
||||
plus_dm = minus_dm = 0
|
||||
|
||||
plus_di += plus_dm
|
||||
minus_di += minus_dm
|
||||
plus_dm_series.append(plus_dm)
|
||||
minus_dm_series.append(minus_dm)
|
||||
|
||||
if len(tr_series) < window:
|
||||
return None
|
||||
|
||||
# 计算+DI和-DI(使用Wilder平滑方法)
|
||||
plus_di_series = []
|
||||
minus_di_series = []
|
||||
|
||||
# 初始化第一个值
|
||||
tr_sum = sum(tr_series[:window])
|
||||
plus_dm_sum = sum(plus_dm_series[:window])
|
||||
minus_dm_sum = sum(minus_dm_series[:window])
|
||||
|
||||
# 计算+DI和-DI
|
||||
if tr_sum > 0:
|
||||
plus_di = (plus_di / tr_sum) * 100
|
||||
minus_di = (minus_di / tr_sum) * 100
|
||||
plus_di = 100 * (plus_dm_sum / tr_sum)
|
||||
minus_di = 100 * (minus_dm_sum / tr_sum)
|
||||
else:
|
||||
plus_di = minus_di = 0
|
||||
|
||||
# 计算DX
|
||||
dx = 0
|
||||
if plus_di + minus_di > 0:
|
||||
dx = (abs(plus_di - minus_di) / (plus_di + minus_di)) * 100
|
||||
plus_di_series.append(plus_di)
|
||||
minus_di_series.append(minus_di)
|
||||
|
||||
# ADX是DX的移动平均,这里简化为直接返回DX值,确保在0-100范围内
|
||||
return max(0, min(100, dx))
|
||||
# 计算后续值(使用Wilder平滑)
|
||||
for i in range(1, len(tr_series) - window + 1):
|
||||
# Wilder平滑:当前值 = (前一个平滑值 * (n-1) + 当前值) / n
|
||||
tr_sum = (tr_sum * (window - 1) + tr_series[i + window - 1]) / window
|
||||
plus_dm_sum = (plus_dm_sum * (window - 1) + plus_dm_series[i + window - 1]) / window
|
||||
minus_dm_sum = (minus_dm_sum * (window - 1) + minus_dm_series[i + window - 1]) / window
|
||||
|
||||
if tr_sum > 0:
|
||||
plus_di = 100 * (plus_dm_sum / tr_sum)
|
||||
minus_di = 100 * (minus_dm_sum / tr_sum)
|
||||
else:
|
||||
plus_di = minus_di = 0
|
||||
|
||||
plus_di_series.append(plus_di)
|
||||
minus_di_series.append(minus_di)
|
||||
|
||||
# 计算DX序列
|
||||
dx_series = []
|
||||
for i in range(len(plus_di_series)):
|
||||
plus_di_val = plus_di_series[i]
|
||||
minus_di_val = minus_di_series[i]
|
||||
|
||||
if plus_di_val + minus_di_val > 0:
|
||||
dx = 100 * (abs(plus_di_val - minus_di_val) / (plus_di_val + minus_di_val))
|
||||
else:
|
||||
dx = 0
|
||||
dx_series.append(dx)
|
||||
|
||||
# 计算ADX(DX的移动平均)
|
||||
if len(dx_series) < window:
|
||||
return None
|
||||
|
||||
# ADX是DX的移动平均(使用简单移动平均)
|
||||
adx = sum(dx_series[:window]) / window
|
||||
|
||||
# 确保在0-100范围内
|
||||
return max(0, min(100, adx))
|
||||
|
||||
# 市场微观结构因子
|
||||
elif factor_name == "micro_tick_direction":
|
||||
@ -298,18 +345,36 @@ class ExtendedFactors:
|
||||
return sum(directions) / window
|
||||
|
||||
elif factor_name == "micro_trade_imbalance":
|
||||
# 交易失衡:基于价格和成交量的联合分析
|
||||
# 修复:交易失衡(按常规定义实现)
|
||||
# 基于签名成交量的归一化差值
|
||||
window = 10
|
||||
if len(close_series) < window + 1 or len(volume_series) < window + 1:
|
||||
return None
|
||||
|
||||
# 计算价格变动和成交量变动
|
||||
price_changes = [close_series[i] - close_series[i+1] for i in range(window)]
|
||||
volume_changes = [volume_series[i] - volume_series[i+1] for i in range(window)]
|
||||
# 计算签名成交量:volume_t * sign(close_t - close_{t-1})
|
||||
signed_volumes = []
|
||||
total_volume = 0
|
||||
|
||||
# 计算交易失衡指标
|
||||
imbalance = sum(price_changes[i] * volume_changes[i] for i in range(window))
|
||||
return imbalance / (window * np.mean(volume_series[:window]) + 1e-8)
|
||||
for i in range(window):
|
||||
if i + 1 < len(close_series):
|
||||
# 计算价格变动符号
|
||||
price_change = close_series[i] - close_series[i+1]
|
||||
sign = 1 if price_change > 0 else (-1 if price_change < 0 else 0)
|
||||
|
||||
# 计算签名成交量
|
||||
signed_volume = volume_series[i] * sign
|
||||
signed_volumes.append(signed_volume)
|
||||
total_volume += volume_series[i]
|
||||
|
||||
if total_volume == 0:
|
||||
return 0.0
|
||||
|
||||
# 计算交易不平衡度:signed_vol / total_volume
|
||||
signed_vol_sum = sum(signed_volumes)
|
||||
imbalance = signed_vol_sum / (total_volume + 1e-8)
|
||||
|
||||
# 确保结果在[-1, 1]范围内
|
||||
return max(-1.0, min(1.0, imbalance))
|
||||
|
||||
# 波动率预测因子
|
||||
elif factor_name == "vol_garch":
|
||||
|
||||
@ -77,6 +77,11 @@ def validate_factor_value(
|
||||
如果因子值有效则返回原值,否则返回 None
|
||||
"""
|
||||
if value is None:
|
||||
LOGGER.warning(
|
||||
"因子值非数值 factor=%s value=%s ts_code=%s date=%s",
|
||||
name, value, ts_code, trade_date,
|
||||
extra=LOG_EXTRA
|
||||
)
|
||||
return None
|
||||
|
||||
# 检查是否为有限数值
|
||||
@ -98,8 +103,8 @@ def validate_factor_value(
|
||||
"tech_pv_trend": (-1.0, 1.0), # 量价趋势相关性
|
||||
|
||||
# 趋势指标精确范围
|
||||
"trend_adx": (0, 100.0), # ADX趋势强度 0-100
|
||||
"trend_ma_cross": (-1.0, 1.0), # 均线交叉
|
||||
"trend_adx": (0, 100.0), # 简化版 ADX(近似)趋势强度 0-100
|
||||
"trend_ma_cross": (-1.0, 1.0), # 均线交叉比例
|
||||
"trend_price_channel": (-1.0, 1.0), # 价格通道位置
|
||||
|
||||
# 波动率指标精确范围
|
||||
@ -109,7 +114,7 @@ def validate_factor_value(
|
||||
|
||||
# 微观结构精确范围
|
||||
"micro_tick_direction": (-1.0, 1.0), # 买卖方向比例
|
||||
"micro_trade_imbalance": (-100.0, 100.0), # 交易不平衡度
|
||||
"micro_trade_imbalance": (-1.0, 1.0), # 交易不平衡度,范围在[-1, 1]之间
|
||||
|
||||
# 情绪指标精确范围
|
||||
"sent_impact": (0, 1.0), # 情绪影响度
|
||||
|
||||
@ -1,55 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""测试因子值范围验证功能"""
|
||||
|
||||
from app.features.factors import compute_factors
|
||||
from datetime import date
|
||||
|
||||
def test_factor_ranges():
    """Smoke-check factor computation and report values outside a loose range.

    Calls ``compute_factors`` for 2024-01-15 on two securities, then walks
    every non-``None`` factor value of every result and classifies it as
    inside or outside the closed interval [-10, 10]. Per-value lines and a
    final summary are printed to stdout.

    The function makes no assertions and returns ``None``. Any exception
    raised along the way is caught, reported with its traceback, and not
    re-raised.
    """
    print('测试改进后的因子值范围验证功能...')

    try:
        results = compute_factors(
            date(2024, 1, 15),
            ts_codes=['000001.SZ', '000002.SZ'],
            skip_existing=False,
            batch_size=10,
        )
        print(f'因子计算完成,共计算 {len(results)} 个结果')

        # Tallies of values inside / outside the loose [-10, 10] window.
        valid_count = 0
        invalid_count = 0

        for result in results:
            print(f'\n证券 {result.ts_code} 的因子值:')
            for factor_name, value in result.values.items():
                # Missing values are neither counted nor printed.
                if value is None:
                    continue

                # Deliberately wide window: the point is to exercise the
                # validation path, not to enforce per-factor bounds.
                # Written as a negated chained comparison so that NaN is
                # classified as out of range.
                if not (-10 <= value <= 10):
                    print(f' ✗ {factor_name}: {value:.6f} (超出范围!)')
                    invalid_count += 1
                else:
                    print(f' ✓ {factor_name}: {value:.6f}')
                    valid_count += 1

        print(f'\n验证统计:')
        print(f' 有效因子值: {valid_count}')
        print(f' 无效因子值: {invalid_count}')
        print(f' 总因子值: {valid_count + invalid_count}')

        verdict = (
            '\n✅ 所有因子值都在合理范围内,验证通过!'
            if invalid_count == 0
            else f'\n⚠️ 发现 {invalid_count} 个超出范围的因子值,需要进一步优化'
        )
        print(verdict)

        print('\n✅ 因子值范围验证测试完成')

    except Exception as e:
        # Best-effort script: report the failure and its traceback, then
        # return normally instead of propagating.
        print(f'❌ 测试失败: {e}')
        import traceback
        traceback.print_exc()
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_factor_ranges()
|
||||
Loading…
Reference in New Issue
Block a user