From 77e5b932973a2c88e729b652df6590218d7938de Mon Sep 17 00:00:00 2001 From: sam Date: Thu, 9 Oct 2025 10:09:52 +0800 Subject: [PATCH] fix ADX calculation and trade imbalance factor implementation --- app/features/extended_factors.py | 143 ++++++++++++++++++++++--------- app/features/validation.py | 11 ++- test_factor_ranges.py | 55 ------------ 3 files changed, 112 insertions(+), 97 deletions(-) delete mode 100644 test_factor_ranges.py diff --git a/app/features/extended_factors.py b/app/features/extended_factors.py index b3183e0..4f05b16 100644 --- a/app/features/extended_factors.py +++ b/app/features/extended_factors.py @@ -88,11 +88,11 @@ EXTENDED_FACTORS: List[FactorSpec] = [ # 趋势跟踪因子 FactorSpec("trend_ma_cross", 20), # 均线交叉信号 FactorSpec("trend_price_channel", 20), # 价格通道突破 - FactorSpec("trend_adx", 14), # 平均趋向指标 + FactorSpec("trend_adx", 14), # 简化版平均趋向指标(近似) # 市场微观结构因子 FactorSpec("micro_tick_direction", 5), # 逐笔方向 - FactorSpec("micro_trade_imbalance", 10), # 交易失衡 + FactorSpec("micro_trade_imbalance", 10), # 交易不平衡度(基于签名成交量) # 波动率预测因子 FactorSpec("vol_garch", 20), # GARCH波动率 @@ -199,9 +199,11 @@ class ExtendedFactors: return rsi(close_series, 14) elif factor_name == "tech_macd_signal": + # MACD柱状图值(histogram) return macd(close_series, 12, 26, 9) elif factor_name == "tech_bb_position": + # 价格在布林带中的位置(-1到1) return bollinger_bands(close_series, 20) elif factor_name == "tech_obv_momentum": @@ -212,9 +214,13 @@ class ExtendedFactors: # 趋势跟踪因子 elif factor_name == "trend_ma_cross": + # 修复:返回均线交叉的比例而不是差值 ma_5 = rolling_mean(close_series, 5) ma_20 = rolling_mean(close_series, 20) - return ma_5 - ma_20 + if ma_20 is not None and ma_20 != 0: + return (ma_5 - ma_20) / ma_20 if ma_5 is not None else None + else: + return 0.0 if ma_5 is not None else None elif factor_name == "trend_price_channel": # 价格通道突破因子:当前价格相对于通道的位置 @@ -226,39 +232,36 @@ class ExtendedFactors: return 0.0 elif factor_name == "trend_adx": - # 标准ADX计算实现 + # 修复:标准ADX计算实现(需要high/low序列) + # 注意:当前仅使用close序列作为high/low的近似,这是一个简化实现 window = 14 if len(close_series) < window + 1: return None - # 计算+DI和-DI - plus_di = 0 - minus_di = 0 - tr_sum = 0 + # 计算TR、+DM、-DM序列 + tr_series = [] + plus_dm_series = [] + minus_dm_series = [] - # 计算初始TR、+DM、-DM - for i in range(window): - if i + 1 >= len(close_series): - break - - # 计算真实波幅(TR) + for i in range(len(close_series) - 1): + # 使用close作为high/low的近似(简化实现) today_high = close_series[i] today_low = close_series[i] + prev_high = close_series[i + 1] + prev_low = close_series[i + 1] prev_close = close_series[i + 1] + # 计算真实波幅(TR) tr = max( abs(today_high - today_low), abs(today_high - prev_close), abs(today_low - prev_close) ) - tr_sum += tr + tr_series.append(tr) # 计算方向运动 - prev_high = close_series[i + 1] if i + 1 < len(close_series) else close_series[i] - prev_low = close_series[i + 1] if i + 1 < len(close_series) else close_series[i] - - plus_dm = max(0, close_series[i] - prev_high) - minus_dm = max(0, prev_low - close_series[i]) + plus_dm = max(today_high - prev_high, 0) + minus_dm = max(prev_low - today_low, 0) # 确保只有一项为正值 if plus_dm > minus_dm: @@ -267,24 +270,68 @@ class ExtendedFactors: plus_dm = 0 else: plus_dm = minus_dm = 0 - - plus_di += plus_dm - minus_di += minus_dm + + plus_dm_series.append(plus_dm) + minus_dm_series.append(minus_dm) + + if len(tr_series) < window: + return None + + # 计算+DI和-DI(使用Wilder平滑方法) + plus_di_series = [] + minus_di_series = [] + + # 初始化第一个值 + tr_sum = sum(tr_series[:window]) + plus_dm_sum = sum(plus_dm_series[:window]) + minus_dm_sum = sum(minus_dm_series[:window]) - # 计算+DI和-DI if tr_sum > 0: - plus_di = (plus_di / tr_sum) * 100 - minus_di = (minus_di / tr_sum) * 100 + plus_di = 100 * (plus_dm_sum / tr_sum) + minus_di = 100 * (minus_dm_sum / tr_sum) else: plus_di = minus_di = 0 + + plus_di_series.append(plus_di) + minus_di_series.append(minus_di) - # 计算DX - dx = 0 - if plus_di + minus_di > 0: - dx = (abs(plus_di - minus_di) / (plus_di + minus_di)) * 100 + # 计算后续值(使用Wilder平滑) + for i in range(1, len(tr_series) - window + 1): + # Wilder平滑:当前值 = (前一个平滑值 * (n-1) + 当前值) / n + tr_sum = (tr_sum * (window - 1) + tr_series[i + window - 1]) / window + plus_dm_sum = (plus_dm_sum * (window - 1) + plus_dm_series[i + window - 1]) / window + minus_dm_sum = (minus_dm_sum * (window - 1) + minus_dm_series[i + window - 1]) / window + + if tr_sum > 0: + plus_di = 100 * (plus_dm_sum / tr_sum) + minus_di = 100 * (minus_dm_sum / tr_sum) + else: + plus_di = minus_di = 0 + + plus_di_series.append(plus_di) + minus_di_series.append(minus_di) - # ADX是DX的移动平均,这里简化为直接返回DX值,确保在0-100范围内 - return max(0, min(100, dx)) + # 计算DX序列 + dx_series = [] + for i in range(len(plus_di_series)): + plus_di_val = plus_di_series[i] + minus_di_val = minus_di_series[i] + + if plus_di_val + minus_di_val > 0: + dx = 100 * (abs(plus_di_val - minus_di_val) / (plus_di_val + minus_di_val)) + else: + dx = 0 + dx_series.append(dx) + + # 计算ADX(DX的移动平均) + if len(dx_series) < window: + return None + + # ADX是DX的移动平均(使用简单移动平均) + adx = sum(dx_series[:window]) / window + + # 确保在0-100范围内 + return max(0, min(100, adx)) # 市场微观结构因子 elif factor_name == "micro_tick_direction": @@ -298,18 +345,36 @@ class ExtendedFactors: return sum(directions) / window elif factor_name == "micro_trade_imbalance": - # 交易失衡:基于价格和成交量的联合分析 + # 修复:交易失衡(按常规定义实现) + # 基于签名成交量的归一化差值 window = 10 if len(close_series) < window + 1 or len(volume_series) < window + 1: return None - # 计算价格变动和成交量变动 - price_changes = [close_series[i] - close_series[i+1] for i in range(window)] - volume_changes = [volume_series[i] - volume_series[i+1] for i in range(window)] + # 计算签名成交量:volume_t * sign(close_t - close_{t-1}) + signed_volumes = [] + total_volume = 0 - # 计算交易失衡指标 - imbalance = sum(price_changes[i] * volume_changes[i] for i in range(window)) - return imbalance / (window * np.mean(volume_series[:window]) + 1e-8) + for i in range(window): + if i + 1 < len(close_series): + # 计算价格变动符号 + price_change = close_series[i] - close_series[i+1] + sign = 1 if price_change > 0 else (-1 if price_change < 0 else 0) + + # 计算签名成交量 + signed_volume = volume_series[i] * sign + signed_volumes.append(signed_volume) + total_volume += volume_series[i] + + if total_volume == 0: + return 0.0 + + # 计算交易不平衡度:signed_vol / total_volume + signed_vol_sum = sum(signed_volumes) + imbalance = signed_vol_sum / (total_volume + 1e-8) + + # 确保结果在[-1, 1]范围内 + return max(-1.0, min(1.0, imbalance)) # 波动率预测因子 elif factor_name == "vol_garch": diff --git a/app/features/validation.py b/app/features/validation.py index c836262..289cfa1 100644 --- a/app/features/validation.py +++ b/app/features/validation.py @@ -77,6 +77,11 @@ def validate_factor_value( 如果因子值有效则返回原值,否则返回 None """ if value is None: + LOGGER.warning( + "因子值非数值 factor=%s value=%s ts_code=%s date=%s", + name, value, ts_code, trade_date, + extra=LOG_EXTRA + ) return None # 检查是否为有限数值 @@ -98,8 +103,8 @@ def validate_factor_value( "tech_pv_trend": (-1.0, 1.0), # 量价趋势相关性 # 趋势指标精确范围 - "trend_adx": (0, 100.0), # ADX趋势强度 0-100 - "trend_ma_cross": (-1.0, 1.0), # 均线交叉 + "trend_adx": (0, 100.0), # 简化版 ADX(近似)趋势强度 0-100 + "trend_ma_cross": (-1.0, 1.0), # 均线交叉比例 "trend_price_channel": (-1.0, 1.0), # 价格通道位置 # 波动率指标精确范围 @@ -109,7 +114,7 @@ def validate_factor_value( # 微观结构精确范围 "micro_tick_direction": (-1.0, 1.0), # 买卖方向比例 - "micro_trade_imbalance": (-100.0, 100.0), # 交易不平衡度 + "micro_trade_imbalance": (-1.0, 1.0), # 交易不平衡度,范围在[-1, 1]之间 # 情绪指标精确范围 "sent_impact": (0, 1.0), # 情绪影响度 diff --git a/test_factor_ranges.py b/test_factor_ranges.py deleted file mode 100644 index 1db0b79..0000000 --- a/test_factor_ranges.py +++ /dev/null @@ -1,55 +0,0 @@ -#!/usr/bin/env python3 -"""测试因子值范围验证功能""" - -from app.features.factors import compute_factors -from datetime import date - -def test_factor_ranges(): - """测试因子值范围验证功能""" - print('测试改进后的因子值范围验证功能...') - - try: - results = compute_factors( - date(2024, 1, 15), - ts_codes=['000001.SZ', '000002.SZ'], - skip_existing=False, - batch_size=10 - ) - - print(f'因子计算完成,共计算 {len(results)} 个结果') - - # 检查每个因子的值范围 - valid_count = 0 - invalid_count = 0 - - for result in results: - print(f'\n证券 {result.ts_code} 的因子值:') - for factor_name, value in result.values.items(): - if value is not None: - # 检查值是否在合理范围内 - if -10 <= value <= 10: # 放宽检查范围,主要看验证逻辑 - print(f' ✓ {factor_name}: {value:.6f}') - valid_count += 1 - else: - print(f' ✗ {factor_name}: {value:.6f} (超出范围!)') - invalid_count += 1 - - print(f'\n验证统计:') - print(f' 有效因子值: {valid_count}') - print(f' 无效因子值: {invalid_count}') - print(f' 总因子值: {valid_count + invalid_count}') - - if invalid_count == 0: - print('\n✅ 所有因子值都在合理范围内,验证通过!') - else: - print(f'\n⚠️ 发现 {invalid_count} 个超出范围的因子值,需要进一步优化') - - print('\n✅ 因子值范围验证测试完成') - - except Exception as e: - print(f'❌ 测试失败: {e}') - import traceback - traceback.print_exc() - -if __name__ == "__main__": - test_factor_ranges() \ No newline at end of file