111 lines
3.9 KiB
Python
111 lines
3.9 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""简化版trend_adx因子测试"""
|
||
|
||
def compute_trend_adx(close_series):
|
||
"""简化版trend_adx计算实现"""
|
||
# 标准ADX计算实现
|
||
window = 14
|
||
if len(close_series) < window + 1:
|
||
return None
|
||
|
||
# 计算+DI和-DI
|
||
plus_di = 0
|
||
minus_di = 0
|
||
tr_sum = 0
|
||
|
||
# 计算初始TR、+DM、-DM
|
||
for i in range(window):
|
||
if i + 1 >= len(close_series):
|
||
break
|
||
|
||
# 计算真实波幅(TR)
|
||
today_high = close_series[i]
|
||
today_low = close_series[i]
|
||
prev_close = close_series[i + 1]
|
||
|
||
tr = max(
|
||
abs(today_high - today_low),
|
||
abs(today_high - prev_close),
|
||
abs(today_low - prev_close)
|
||
)
|
||
tr_sum += tr
|
||
|
||
# 计算方向运动
|
||
prev_high = close_series[i + 1] if i + 1 < len(close_series) else close_series[i]
|
||
prev_low = close_series[i + 1] if i + 1 < len(close_series) else close_series[i]
|
||
|
||
plus_dm = max(0, close_series[i] - prev_high)
|
||
minus_dm = max(0, prev_low - close_series[i])
|
||
|
||
# 确保只有一项为正值
|
||
if plus_dm > minus_dm:
|
||
minus_dm = 0
|
||
elif minus_dm > plus_dm:
|
||
plus_dm = 0
|
||
else:
|
||
plus_dm = minus_dm = 0
|
||
|
||
plus_di += plus_dm
|
||
minus_di += minus_dm
|
||
|
||
# 计算+DI和-DI
|
||
if tr_sum > 0:
|
||
plus_di = (plus_di / tr_sum) * 100
|
||
minus_di = (minus_di / tr_sum) * 100
|
||
else:
|
||
plus_di = minus_di = 0
|
||
|
||
# 计算DX
|
||
dx = 0
|
||
if plus_di + minus_di > 0:
|
||
dx = (abs(plus_di - minus_di) / (plus_di + minus_di)) * 100
|
||
|
||
# ADX是DX的移动平均,这里简化为直接返回DX值,确保在0-100范围内
|
||
return max(0, min(100, dx))
|
||
|
||
|
||
def test_trend_adx():
|
||
"""测试trend_adx因子计算"""
|
||
|
||
# 测试上涨趋势
|
||
print("测试上涨趋势:")
|
||
close_prices_up = [115, 114, 113, 112, 111, 110, 109, 108, 107, 106, 105, 104, 103, 102, 101, 100]
|
||
result_up = compute_trend_adx(close_prices_up)
|
||
print(f"上涨趋势trend_adx值: {result_up}")
|
||
assert result_up is not None and 0 <= result_up <= 100, f"上涨趋势结果错误: {result_up}"
|
||
|
||
# 测试下跌趋势
|
||
print("\n测试下跌趋势:")
|
||
close_prices_down = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115]
|
||
result_down = compute_trend_adx(close_prices_down)
|
||
print(f"下跌趋势trend_adx值: {result_down}")
|
||
assert result_down is not None and 0 <= result_down <= 100, f"下跌趋势结果错误: {result_down}"
|
||
|
||
# 测试震荡市场
|
||
print("\n测试震荡市场:")
|
||
close_prices_sideways = [100, 100.5, 99.5, 101, 99, 100.5, 99.5, 100, 100.5, 99.5, 100, 100.5, 99.5, 100, 100.5, 99.5]
|
||
result_sideways = compute_trend_adx(close_prices_sideways)
|
||
print(f"震荡市场trend_adx值: {result_sideways}")
|
||
assert result_sideways is not None and 0 <= result_sideways <= 100, f"震荡市场结果错误: {result_sideways}"
|
||
|
||
# 测试数据不足
|
||
print("\n测试数据不足:")
|
||
close_prices_insufficient = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]
|
||
result_insufficient = compute_trend_adx(close_prices_insufficient)
|
||
print(f"数据不足时结果: {result_insufficient}")
|
||
assert result_insufficient is None, f"数据不足时应该返回None: {result_insufficient}"
|
||
|
||
# 测试平盘市场
|
||
print("\n测试平盘市场:")
|
||
close_prices_flat = [100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100]
|
||
result_flat = compute_trend_adx(close_prices_flat)
|
||
print(f"平盘市场trend_adx值: {result_flat}")
|
||
assert result_flat is not None and 0 <= result_flat <= 100, f"平盘市场结果错误: {result_flat}"
|
||
|
||
print("\n所有测试通过!")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
test_trend_adx() |