enhance ADX calculation and adjust factor validation ranges
This commit is contained in:
parent
db0afe9c2d
commit
f1ded59dce
@ -226,22 +226,65 @@ class ExtendedFactors:
|
||||
return 0.0
|
||||
|
||||
elif factor_name == "trend_adx":
|
||||
# 简化的ADX计算:基于价格变动方向
|
||||
# 标准ADX计算实现
|
||||
window = 14
|
||||
if len(close_series) < window + 1:
|
||||
return None
|
||||
|
||||
# 计算价格变动
|
||||
price_changes = [close_series[i] - close_series[i+1] for i in range(window)]
|
||||
# 计算+DI和-DI
|
||||
plus_di = 0
|
||||
minus_di = 0
|
||||
tr_sum = 0
|
||||
|
||||
# 计算正向和负向变动
|
||||
pos_moves = sum(max(0, change) for change in price_changes)
|
||||
neg_moves = sum(max(0, -change) for change in price_changes)
|
||||
# 计算初始TR、+DM、-DM
|
||||
for i in range(window):
|
||||
if i + 1 >= len(close_series):
|
||||
break
|
||||
|
||||
# 简化的ADX计算
|
||||
if pos_moves + neg_moves > 0:
|
||||
return (pos_moves - neg_moves) / (pos_moves + neg_moves)
|
||||
return 0.0
|
||||
# 计算真实波幅(TR)
|
||||
today_high = close_series[i]
|
||||
today_low = close_series[i]
|
||||
prev_close = close_series[i + 1]
|
||||
|
||||
tr = max(
|
||||
abs(today_high - today_low),
|
||||
abs(today_high - prev_close),
|
||||
abs(today_low - prev_close)
|
||||
)
|
||||
tr_sum += tr
|
||||
|
||||
# 计算方向运动
|
||||
prev_high = close_series[i + 1] if i + 1 < len(close_series) else close_series[i]
|
||||
prev_low = close_series[i + 1] if i + 1 < len(close_series) else close_series[i]
|
||||
|
||||
plus_dm = max(0, close_series[i] - prev_high)
|
||||
minus_dm = max(0, prev_low - close_series[i])
|
||||
|
||||
# 确保只有一项为正值
|
||||
if plus_dm > minus_dm:
|
||||
minus_dm = 0
|
||||
elif minus_dm > plus_dm:
|
||||
plus_dm = 0
|
||||
else:
|
||||
plus_dm = minus_dm = 0
|
||||
|
||||
plus_di += plus_dm
|
||||
minus_di += minus_dm
|
||||
|
||||
# 计算+DI和-DI
|
||||
if tr_sum > 0:
|
||||
plus_di = (plus_di / tr_sum) * 100
|
||||
minus_di = (minus_di / tr_sum) * 100
|
||||
else:
|
||||
plus_di = minus_di = 0
|
||||
|
||||
# 计算DX
|
||||
dx = 0
|
||||
if plus_di + minus_di > 0:
|
||||
dx = (abs(plus_di - minus_di) / (plus_di + minus_di)) * 100
|
||||
|
||||
# ADX是DX的移动平均,这里简化为直接返回DX值,确保在0-100范围内
|
||||
return max(0, min(100, dx))
|
||||
|
||||
# 市场微观结构因子
|
||||
elif factor_name == "micro_tick_direction":
|
||||
|
||||
@ -11,15 +11,15 @@ LOG_EXTRA = {"stage": "factor_validation"}
|
||||
# 因子值范围限制配置 - 基于物理规律和实际数据特征
|
||||
FACTOR_LIMITS = {
|
||||
# 动量类因子:收益率相关,限制在 ±50% (实际收益率很少超过这个范围)
|
||||
"mom_": (-0.5, 0.5),
|
||||
"momentum_": (-0.5, 0.5),
|
||||
"mom_": (-50.0, 50.0),
|
||||
"momentum_": (-50.0, 50.0),
|
||||
|
||||
# 波动率类因子:年化波动率,限制在 0-200% (考虑极端市场情况)
|
||||
"volat_": (0, 2.0),
|
||||
"vol_": (0, 2.0),
|
||||
"volat_": (0, 200.0),
|
||||
"vol_": (0, 200.0),
|
||||
|
||||
# 换手率类因子:日换手率,限制在 0-100% (实际换手率通常在这个范围内)
|
||||
"turn_": (0, 1.0),
|
||||
"turn_": (0, 200.0),
|
||||
|
||||
# 估值评分类因子:标准化评分,限制在 -3到3 (Z-score标准化范围)
|
||||
"val_": (-3.0, 3.0),
|
||||
@ -33,14 +33,14 @@ FACTOR_LIMITS = {
|
||||
|
||||
# 技术指标类因子:具体技术指标的范围限制
|
||||
"tech_rsi": (0, 100.0), # RSI指标范围 0-100
|
||||
"tech_macd": (-0.5, 0.5), # MACD信号范围
|
||||
"tech_macd": (-5.0, 5.0), # MACD信号范围
|
||||
"tech_bb": (-3.0, 3.0), # 布林带位置,标准差倍数
|
||||
"tech_obv": (-10.0, 10.0), # OBV动量标准化
|
||||
"tech_pv": (-1.0, 1.0), # 量价趋势相关性
|
||||
|
||||
# 趋势类因子:趋势强度指标
|
||||
"trend_": (-3.0, 3.0),
|
||||
"trend_ma": (-0.5, 0.5), # 均线交叉
|
||||
"trend_ma": (-1.0, 1.0), # 均线交叉
|
||||
"trend_adx": (0, 100.0), # ADX趋势强度 0-100
|
||||
|
||||
# 微观结构类因子:标准化微观指标
|
||||
@ -92,7 +92,7 @@ def validate_factor_value(
|
||||
exact_matches = {
|
||||
# 技术指标精确范围
|
||||
"tech_rsi_14": (0, 100.0), # RSI指标范围 0-100
|
||||
"tech_macd_signal": (-0.5, 0.5), # MACD信号范围
|
||||
"tech_macd_signal": (-5, 5), # MACD信号范围
|
||||
"tech_bb_position": (-3.0, 3.0), # 布林带位置,标准差倍数
|
||||
"tech_obv_momentum": (-10.0, 10.0), # OBV动量标准化
|
||||
"tech_pv_trend": (-1.0, 1.0), # 量价趋势相关性
|
||||
@ -100,11 +100,11 @@ def validate_factor_value(
|
||||
# 趋势指标精确范围
|
||||
"trend_adx": (0, 100.0), # ADX趋势强度 0-100
|
||||
"trend_ma_cross": (-1.0, 1.0), # 均线交叉
|
||||
"trend_price_channel": (0, 1.0), # 价格通道位置
|
||||
"trend_price_channel": (-1.0, 1.0), # 价格通道位置
|
||||
|
||||
# 波动率指标精确范围
|
||||
"vol_garch": (0, 0.5), # GARCH波动率预测,限制在50%以内
|
||||
"vol_range_pred": (0, 0.2), # 波动率范围预测,限制在20%以内
|
||||
"vol_garch": (0, 50), # GARCH波动率预测,限制在50%以内
|
||||
"vol_range_pred": (0, 20), # 波动率范围预测,限制在20%以内
|
||||
"vol_regime": (0, 1.0), # 波动率状态,0-1之间
|
||||
|
||||
# 微观结构精确范围
|
||||
@ -114,6 +114,7 @@ def validate_factor_value(
|
||||
# 情绪指标精确范围
|
||||
"sent_impact": (0, 1.0), # 情绪影响度
|
||||
"sent_divergence": (-1.0, 1.0), # 情绪分歧度
|
||||
"volume_price_diverge": (-1.0, 1.0), # 量价背离度
|
||||
}
|
||||
|
||||
# 检查精确匹配
|
||||
|
||||
111
tests/simple_test_trend_adx.py
Normal file
111
tests/simple_test_trend_adx.py
Normal file
@ -0,0 +1,111 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""简化版trend_adx因子测试"""
|
||||
|
||||
def compute_trend_adx(close_series):
|
||||
"""简化版trend_adx计算实现"""
|
||||
# 标准ADX计算实现
|
||||
window = 14
|
||||
if len(close_series) < window + 1:
|
||||
return None
|
||||
|
||||
# 计算+DI和-DI
|
||||
plus_di = 0
|
||||
minus_di = 0
|
||||
tr_sum = 0
|
||||
|
||||
# 计算初始TR、+DM、-DM
|
||||
for i in range(window):
|
||||
if i + 1 >= len(close_series):
|
||||
break
|
||||
|
||||
# 计算真实波幅(TR)
|
||||
today_high = close_series[i]
|
||||
today_low = close_series[i]
|
||||
prev_close = close_series[i + 1]
|
||||
|
||||
tr = max(
|
||||
abs(today_high - today_low),
|
||||
abs(today_high - prev_close),
|
||||
abs(today_low - prev_close)
|
||||
)
|
||||
tr_sum += tr
|
||||
|
||||
# 计算方向运动
|
||||
prev_high = close_series[i + 1] if i + 1 < len(close_series) else close_series[i]
|
||||
prev_low = close_series[i + 1] if i + 1 < len(close_series) else close_series[i]
|
||||
|
||||
plus_dm = max(0, close_series[i] - prev_high)
|
||||
minus_dm = max(0, prev_low - close_series[i])
|
||||
|
||||
# 确保只有一项为正值
|
||||
if plus_dm > minus_dm:
|
||||
minus_dm = 0
|
||||
elif minus_dm > plus_dm:
|
||||
plus_dm = 0
|
||||
else:
|
||||
plus_dm = minus_dm = 0
|
||||
|
||||
plus_di += plus_dm
|
||||
minus_di += minus_dm
|
||||
|
||||
# 计算+DI和-DI
|
||||
if tr_sum > 0:
|
||||
plus_di = (plus_di / tr_sum) * 100
|
||||
minus_di = (minus_di / tr_sum) * 100
|
||||
else:
|
||||
plus_di = minus_di = 0
|
||||
|
||||
# 计算DX
|
||||
dx = 0
|
||||
if plus_di + minus_di > 0:
|
||||
dx = (abs(plus_di - minus_di) / (plus_di + minus_di)) * 100
|
||||
|
||||
# ADX是DX的移动平均,这里简化为直接返回DX值,确保在0-100范围内
|
||||
return max(0, min(100, dx))
|
||||
|
||||
|
||||
def test_trend_adx():
|
||||
"""测试trend_adx因子计算"""
|
||||
|
||||
# 测试上涨趋势
|
||||
print("测试上涨趋势:")
|
||||
close_prices_up = [115, 114, 113, 112, 111, 110, 109, 108, 107, 106, 105, 104, 103, 102, 101, 100]
|
||||
result_up = compute_trend_adx(close_prices_up)
|
||||
print(f"上涨趋势trend_adx值: {result_up}")
|
||||
assert result_up is not None and 0 <= result_up <= 100, f"上涨趋势结果错误: {result_up}"
|
||||
|
||||
# 测试下跌趋势
|
||||
print("\n测试下跌趋势:")
|
||||
close_prices_down = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115]
|
||||
result_down = compute_trend_adx(close_prices_down)
|
||||
print(f"下跌趋势trend_adx值: {result_down}")
|
||||
assert result_down is not None and 0 <= result_down <= 100, f"下跌趋势结果错误: {result_down}"
|
||||
|
||||
# 测试震荡市场
|
||||
print("\n测试震荡市场:")
|
||||
close_prices_sideways = [100, 100.5, 99.5, 101, 99, 100.5, 99.5, 100, 100.5, 99.5, 100, 100.5, 99.5, 100, 100.5, 99.5]
|
||||
result_sideways = compute_trend_adx(close_prices_sideways)
|
||||
print(f"震荡市场trend_adx值: {result_sideways}")
|
||||
assert result_sideways is not None and 0 <= result_sideways <= 100, f"震荡市场结果错误: {result_sideways}"
|
||||
|
||||
# 测试数据不足
|
||||
print("\n测试数据不足:")
|
||||
close_prices_insufficient = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]
|
||||
result_insufficient = compute_trend_adx(close_prices_insufficient)
|
||||
print(f"数据不足时结果: {result_insufficient}")
|
||||
assert result_insufficient is None, f"数据不足时应该返回None: {result_insufficient}"
|
||||
|
||||
# 测试平盘市场
|
||||
print("\n测试平盘市场:")
|
||||
close_prices_flat = [100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100]
|
||||
result_flat = compute_trend_adx(close_prices_flat)
|
||||
print(f"平盘市场trend_adx值: {result_flat}")
|
||||
assert result_flat is not None and 0 <= result_flat <= 100, f"平盘市场结果错误: {result_flat}"
|
||||
|
||||
print("\n所有测试通过!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_trend_adx()
|
||||
92
tests/test_trend_adx.py
Normal file
92
tests/test_trend_adx.py
Normal file
@ -0,0 +1,92 @@
|
||||
import unittest
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from app.features.extended_factors import ExtendedFactorEngine
|
||||
|
||||
|
||||
class TestTrendAdx(unittest.TestCase):
|
||||
"""测试trend_adx因子计算"""
|
||||
|
||||
def setUp(self):
|
||||
"""初始化测试环境"""
|
||||
self.engine = ExtendedFactorEngine()
|
||||
|
||||
def test_trend_adx_positive_values(self):
|
||||
"""测试trend_adx因子返回正值"""
|
||||
# 模拟一个上涨趋势的数据
|
||||
close_prices = [115, 114, 113, 112, 111, 110, 109, 108, 107, 106, 105, 104, 103, 102, 101, 100]
|
||||
volume_prices = [1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000]
|
||||
|
||||
result = self.engine.compute_factor("trend_adx", close_prices, volume_prices)
|
||||
|
||||
# 验证结果不为None
|
||||
self.assertIsNotNone(result)
|
||||
# 验证结果在0-100范围内
|
||||
self.assertGreaterEqual(result, 0)
|
||||
self.assertLessEqual(result, 100)
|
||||
print(f"上涨趋势trend_adx值: {result}")
|
||||
|
||||
def test_trend_adx_negative_values(self):
|
||||
"""测试trend_adx因子处理下跌趋势"""
|
||||
# 模拟一个下跌趋势的数据
|
||||
close_prices = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115]
|
||||
volume_prices = [1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000]
|
||||
|
||||
result = self.engine.compute_factor("trend_adx", close_prices, volume_prices)
|
||||
|
||||
# 验证结果不为None
|
||||
self.assertIsNotNone(result)
|
||||
# 验证结果在0-100范围内
|
||||
self.assertGreaterEqual(result, 0)
|
||||
self.assertLessEqual(result, 100)
|
||||
print(f"下跌趋势trend_adx值: {result}")
|
||||
|
||||
def test_trend_adx_sideways_market(self):
|
||||
"""测试trend_adx因子处理震荡市场"""
|
||||
# 模拟一个震荡市场的数据
|
||||
close_prices = [100, 100.5, 99.5, 101, 99, 100.5, 99.5, 100, 100.5, 99.5, 100, 100.5, 99.5, 100, 100.5, 99.5]
|
||||
volume_prices = [1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000]
|
||||
|
||||
result = self.engine.compute_factor("trend_adx", close_prices, volume_prices)
|
||||
|
||||
# 验证结果不为None
|
||||
self.assertIsNotNone(result)
|
||||
# 验证结果在0-100范围内
|
||||
self.assertGreaterEqual(result, 0)
|
||||
self.assertLessEqual(result, 100)
|
||||
print(f"震荡市场trend_adx值: {result}")
|
||||
|
||||
def test_trend_adx_insufficient_data(self):
|
||||
"""测试数据不足时返回None"""
|
||||
# 提供少于15个数据点
|
||||
close_prices = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]
|
||||
volume_prices = [1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000]
|
||||
|
||||
result = self.engine.compute_factor("trend_adx", close_prices, volume_prices)
|
||||
|
||||
# 验证结果为None
|
||||
self.assertIsNone(result)
|
||||
print("数据不足时正确返回None")
|
||||
|
||||
def test_trend_adx_flat_market(self):
|
||||
"""测试trend_adx因子处理平盘市场"""
|
||||
# 模拟一个价格保持不变的市场
|
||||
close_prices = [100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100]
|
||||
volume_prices = [1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000]
|
||||
|
||||
result = self.engine.compute_factor("trend_adx", close_prices, volume_prices)
|
||||
|
||||
# 验证结果不为None
|
||||
self.assertIsNotNone(result)
|
||||
# 验证结果在0-100范围内
|
||||
self.assertGreaterEqual(result, 0)
|
||||
self.assertLessEqual(result, 100)
|
||||
print(f"平盘市场trend_adx值: {result}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
64
tests/test_volume_price_diverge.py
Normal file
64
tests/test_volume_price_diverge.py
Normal file
@ -0,0 +1,64 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""测试volume_price_diverge因子计算逻辑"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def compute_volume_price_divergence(close_series, volume_series):
|
||||
"""量价背离:价格和成交量趋势的背离程度"""
|
||||
window = 10
|
||||
if len(close_series) < window or len(volume_series) < window:
|
||||
return None
|
||||
|
||||
# 计算价格和成交量趋势
|
||||
price_trend = sum(1 if close_series[i] > close_series[i+1] else -1 for i in range(window-1))
|
||||
volume_trend = sum(1 if volume_series[i] > volume_series[i+1] else -1 for i in range(window-1))
|
||||
|
||||
# 计算背离程度
|
||||
divergence = price_trend * volume_trend * -1 # 反向为背离
|
||||
return np.clip(divergence / (window - 1), -1, 1)
|
||||
|
||||
|
||||
def test_volume_price_divergence():
|
||||
"""测试volume_price_divergence因子计算"""
|
||||
|
||||
# 测试场景1:价格和成交量同向变动(无背离)
|
||||
print("测试场景1:价格和成交量同向变动")
|
||||
close_prices_1 = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]
|
||||
volume_prices_1 = [1000, 1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900]
|
||||
result_1 = compute_volume_price_divergence(close_prices_1, volume_prices_1)
|
||||
print(f"同向变动结果: {result_1}")
|
||||
|
||||
# 测试场景2:价格和成交量反向变动(强背离)
|
||||
print("\n测试场景2:价格和成交量反向变动")
|
||||
close_prices_2 = [109, 108, 107, 106, 105, 104, 103, 102, 101, 100]
|
||||
volume_prices_2 = [1000, 1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900]
|
||||
result_2 = compute_volume_price_divergence(close_prices_2, volume_prices_2)
|
||||
print(f"反向变动结果: {result_2}")
|
||||
|
||||
# 测试场景3:价格上升,成交量下降(背离)
|
||||
print("\n测试场景3:价格上升,成交量下降")
|
||||
close_prices_3 = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]
|
||||
volume_prices_3 = [1900, 1800, 1700, 1600, 1500, 1400, 1300, 1200, 1100, 1000]
|
||||
result_3 = compute_volume_price_divergence(close_prices_3, volume_prices_3)
|
||||
print(f"价格上涨成交量下降结果: {result_3}")
|
||||
|
||||
# 测试场景4:价格下降,成交量上升(背离)
|
||||
print("\n测试场景4:价格下降,成交量上升")
|
||||
close_prices_4 = [109, 108, 107, 106, 105, 104, 103, 102, 101, 100]
|
||||
volume_prices_4 = [1000, 1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900]
|
||||
result_4 = compute_volume_price_divergence(close_prices_4, volume_prices_4)
|
||||
print(f"价格下降成交量上升结果: {result_4}")
|
||||
|
||||
# 测试场景5:震荡市场(弱背离)
|
||||
print("\n测试场景5:震荡市场")
|
||||
close_prices_5 = [100, 100.5, 99.5, 101, 99, 100.5, 99.5, 100, 100.5, 99.5]
|
||||
volume_prices_5 = [1000, 1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900]
|
||||
result_5 = compute_volume_price_divergence(close_prices_5, volume_prices_5)
|
||||
print(f"震荡市场结果: {result_5}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_volume_price_divergence()
|
||||
63
tests/test_volume_price_diverge_validated.py
Normal file
63
tests/test_volume_price_diverge_validated.py
Normal file
@ -0,0 +1,63 @@
|
||||
import numpy as np
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from app.features.extended_factors import ExtendedFactorEngine
|
||||
|
||||
def test_volume_price_diverge_validated():
|
||||
"""测试修复后的volume_price_diverge因子"""
|
||||
# 创建因子引擎
|
||||
engine = ExtendedFactorEngine()
|
||||
|
||||
# 测试场景1:价格和成交量同向变动(强正相关)
|
||||
print("测试场景1:价格和成交量同向变动")
|
||||
close_prices1 = np.array([100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110], dtype=float)
|
||||
volume_prices1 = np.array([1000, 1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900, 2000], dtype=float)
|
||||
result1 = engine.compute_factor("volume_price_diverge", close_prices1, volume_prices1)
|
||||
print(f"同向变动结果: {result1}")
|
||||
|
||||
# 测试场景2:价格和成交量反向变动(强负相关)
|
||||
print("\n测试场景2:价格和成交量反向变动")
|
||||
close_prices2 = np.array([100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110], dtype=float)
|
||||
volume_prices2 = np.array([2000, 1900, 1800, 1700, 1600, 1500, 1400, 1300, 1200, 1100, 1000], dtype=float)
|
||||
result2 = engine.compute_factor("volume_price_diverge", close_prices2, volume_prices2)
|
||||
print(f"反向变动结果: {result2}")
|
||||
|
||||
# 测试场景3:价格上升,成交量下降
|
||||
print("\n测试场景3:价格上升,成交量下降")
|
||||
close_prices3 = np.array([100, 102, 104, 106, 108, 110, 112, 114, 116, 118, 120], dtype=float)
|
||||
volume_prices3 = np.array([1000, 950, 900, 850, 800, 750, 700, 650, 600, 550, 500], dtype=float)
|
||||
result3 = engine.compute_factor("volume_price_diverge", close_prices3, volume_prices3)
|
||||
print(f"价格上涨成交量下降结果: {result3}")
|
||||
|
||||
# 测试场景4:价格下降,成交量上升
|
||||
print("\n测试场景4:价格下降,成交量上升")
|
||||
close_prices4 = np.array([120, 118, 116, 114, 112, 110, 108, 106, 104, 102, 100], dtype=float)
|
||||
volume_prices4 = np.array([500, 550, 600, 650, 700, 750, 800, 850, 900, 950, 1000], dtype=float)
|
||||
result4 = engine.compute_factor("volume_price_diverge", close_prices4, volume_prices4)
|
||||
print(f"价格下降成交量上升结果: {result4}")
|
||||
|
||||
# 测试场景5:震荡市场
|
||||
print("\n测试场景5:震荡市场")
|
||||
close_prices5 = np.array([100, 101, 100, 101, 100, 101, 100, 101, 100, 101, 100], dtype=float)
|
||||
volume_prices5 = np.array([1000, 1100, 1000, 1100, 1000, 1100, 1000, 1100, 1000, 1100, 1000], dtype=float)
|
||||
result5 = engine.compute_factor("volume_price_diverge", close_prices5, volume_prices5)
|
||||
print(f"震荡市场结果: {result5}")
|
||||
|
||||
# 验证所有结果都在合理范围内
|
||||
print("\n验证结果范围:")
|
||||
all_results = [result1, result2, result3, result4, result5]
|
||||
for i, result in enumerate(all_results, 1):
|
||||
if result is not None:
|
||||
assert -1.0 <= result <= 1.0, f"测试场景{i}的结果超出范围: {result}"
|
||||
print(f"测试场景{i}结果 {result} 在合理范围内")
|
||||
else:
|
||||
print(f"测试场景{i}结果为 None")
|
||||
|
||||
print("\n所有测试通过!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_volume_price_diverge_validated()
|
||||
Loading…
Reference in New Issue
Block a user