211 lines
5.9 KiB
Python
211 lines
5.9 KiB
Python
"""Tests for BacktestEngine risk-aware execution."""
|
|
from __future__ import annotations
|
|
|
|
from datetime import date
|
|
|
|
import pytest
|
|
|
|
import json
|
|
|
|
from app.agents.base import AgentAction, AgentContext
|
|
from app.agents.game import Decision
|
|
from app.backtest.engine import (
|
|
BacktestEngine,
|
|
BacktestResult,
|
|
BtConfig,
|
|
PortfolioState,
|
|
_persist_backtest_results,
|
|
)
|
|
from app.data.schema import initialize_database
|
|
from app.utils.config import DataPaths, get_config
|
|
from app.utils.db import db_session
|
|
|
|
|
|
def _make_context(price: float, features: dict | None = None) -> AgentContext:
    """Build a minimal ``AgentContext`` for ``000001.SZ`` on ``2025-01-10``.

    Args:
        price: Value stored under ``raw["scope_values"]["daily.close"]``.
        features: Feature mapping handed to the context; a falsy value
            (``None`` or an empty dict) is replaced by a fresh empty dict.

    Returns:
        The constructed context with an empty ``market_snapshot``.
    """
    feature_map = features if features else {}
    raw_payload = {"scope_values": {"daily.close": price}}
    return AgentContext(
        ts_code="000001.SZ",
        trade_date="2025-01-10",
        features=feature_map,
        market_snapshot={},
        raw=raw_payload,
    )
|
|
|
|
|
|
def _make_decision(action: AgentAction, target_weight: float = 0.0) -> Decision:
    """Create a ``Decision`` for ``action`` with a fixed confidence of 0.8.

    Args:
        action: The agent action the decision carries.
        target_weight: Target weight forwarded unchanged (defaults to 0.0).

    Returns:
        A ``Decision`` whose ``feasible_actions`` list and ``utilities``
        mapping are both empty.
    """
    decision_fields = dict(
        action=action,
        confidence=0.8,
        target_weight=target_weight,
        feasible_actions=[],
        utilities={},
    )
    return Decision(**decision_fields)
|
|
|
|
|
|
def _engine_with_params(params: dict[str, float]) -> BacktestEngine:
    """Return a ``BacktestEngine`` built from a one-day, one-symbol config.

    The config uses id/name ``"test"``, starts and ends on 2025-01-10, and
    has ``["000001.SZ"]`` as its universe; ``params`` is passed through as the
    config's ``params``.
    """
    trading_day = date(2025, 1, 10)
    config = BtConfig(
        id="test",
        name="test",
        start_date=trading_day,
        end_date=trading_day,
        universe=["000001.SZ"],
        params=params,
    )
    return BacktestEngine(config)
|
|
|
|
|
|
@pytest.fixture()
def isolated_db(tmp_path):
    """Redirect the config's data paths to a temporary directory for one test.

    Creates ``<tmp_path>/data``, assigns ``DataPaths(root=...)`` for it to the
    object returned by ``get_config()``, and calls ``initialize_database()``
    before yielding to the test. The original ``data_paths`` value is always
    put back afterwards.
    """
    cfg = get_config()
    original_paths = cfg.data_paths
    tmp_root = tmp_path / "data"
    tmp_root.mkdir(parents=True, exist_ok=True)
    try:
        # The swap and the initialisation live inside the ``try`` so that a
        # failure in ``initialize_database()`` still restores the original
        # paths instead of leaving the temporary ones on the config object.
        cfg.data_paths = DataPaths(root=tmp_root)
        initialize_database()
        yield
    finally:
        cfg.data_paths = original_paths
|
|
|
|
|
|
def test_buy_respects_risk_caps():
    """Check the fill size of a ``BUY_L`` decision that targets a 0.5 weight.

    With ``max_position_weight`` 0.2, a ``risk_penalty`` feature of 0.25 and
    zero fee/slippage, the test expects
    ``cash * 0.2 * (1 - 0.25) / price`` shares to be held, the matching cash
    to be spent, one executed trade, and the same notional as turnover.
    """
    starting_cash = 100_000.0
    close_price = 100.0
    weight_cap = 0.2
    penalty = 0.25

    risk_params = {
        "max_position_weight": weight_cap,
        "fee_rate": 0.0,
        "slippage_bps": 0.0,
        "max_daily_turnover_ratio": 1.0,
    }
    engine = _engine_with_params(risk_params)
    portfolio = PortfolioState(cash=starting_cash)
    outcome = BacktestResult()
    ctx = _make_context(
        close_price,
        {"liquidity_score": 0.7, "risk_penalty": penalty},
    )
    buy = _make_decision(AgentAction.BUY_L, target_weight=0.5)
    records = [("000001.SZ", ctx, buy)]

    engine._apply_portfolio_updates(date(2025, 1, 10), portfolio, records, outcome)

    expected_qty = (starting_cash * weight_cap * (1 - penalty)) / close_price
    expected_spend = expected_qty * close_price
    assert portfolio.holdings["000001.SZ"] == pytest.approx(expected_qty)
    assert portfolio.cash == pytest.approx(starting_cash - expected_spend)
    assert outcome.trades and outcome.trades[0]["status"] == "executed"
    assert outcome.nav_series[0]["turnover"] == pytest.approx(expected_spend)
|
|
|
|
|
|
def test_buy_blocked_by_limit_up_records_risk():
    """Check that a ``BUY_M`` on a ``limit_up`` context produces no position.

    The test expects no holding and no trade for ``000001.SZ``, and a first
    risk event whose ``reason`` is ``"limit_up"``.
    """
    engine = _engine_with_params({})
    portfolio = PortfolioState(cash=50_000.0)
    outcome = BacktestResult()
    ctx = _make_context(100.0, {"limit_up": True})
    order = _make_decision(AgentAction.BUY_M, target_weight=0.1)
    records = [("000001.SZ", ctx, order)]

    engine._apply_portfolio_updates(date(2025, 1, 10), portfolio, records, outcome)

    assert "000001.SZ" not in portfolio.holdings
    assert not outcome.trades
    assert outcome.risk_events
    first_event = outcome.risk_events[0]
    assert first_event["reason"] == "limit_up"
|
|
|
|
|
|
def test_sell_applies_slippage_and_fee():
    """Check price, fee, cash and realized PnL when a full position is sold.

    The portfolio starts with 100 shares of ``000001.SZ`` at a cost basis of
    90.0 and no cash. With ``slippage_bps`` 20 and ``fee_rate`` 0.001 at a
    close of 100.0, the test expects a fill price of ``100.0 * (1 - 0.002)``,
    a fee of 0.1% of the trade value, cash equal to value minus fee, an empty
    holdings map, turnover equal to the trade value, and no risk events.
    """
    engine = _engine_with_params(
        {
            "max_position_weight": 0.3,
            "fee_rate": 0.001,
            "slippage_bps": 20.0,
            "max_daily_turnover_ratio": 1.0,
        }
    )
    state = PortfolioState(
        cash=0.0,
        holdings={"000001.SZ": 100.0},
        cost_basis={"000001.SZ": 90.0},
        opened_dates={"000001.SZ": "2024-12-01"},
    )
    result = BacktestResult()
    context = _make_context(100.0, {})
    decision = _make_decision(AgentAction.SELL)

    engine._apply_portfolio_updates(
        date(2025, 1, 10),
        state,
        [("000001.SZ", context, decision)],
        result,
    )

    # Fail with a clear assertion (not an IndexError) if nothing was traded.
    assert result.trades
    trade = result.trades[0]
    # ``approx`` wraps the expected value so the tolerance is derived from the
    # reference figure, matching the other assertions in this module.
    assert trade["price"] == pytest.approx(100.0 * (1 - 0.002), rel=1e-6)
    assert trade["fee"] == pytest.approx(trade["value"] * 0.001, rel=1e-6)
    assert state.cash == pytest.approx(trade["value"] - trade["fee"])
    assert state.realized_pnl == pytest.approx((trade["price"] - 90.0) * 100 - trade["fee"])
    assert not state.holdings
    assert result.nav_series[0]["turnover"] == pytest.approx(trade["value"])
    assert not result.risk_events
|
|
|
|
|
|
def test_persist_backtest_results_saves_risk_events(isolated_db):
    """Check that persisting a result writes its risk events and summary.

    A ``BacktestResult`` with one NAV row and one ``limit_up`` risk event is
    passed to ``_persist_backtest_results``. The test then reads
    ``bt_risk_events`` and ``bt_report`` for the config id and expects the
    stored reason, the ``action`` inside the JSON metadata, and the summary's
    ``risk_events`` / ``risk_breakdown`` counts to match.
    """
    cfg = BtConfig(
        id="risk_cfg",
        name="risk",
        start_date=date(2025, 1, 10),
        end_date=date(2025, 1, 10),
        universe=["000001.SZ"],
        params={},
    )
    result = BacktestResult()
    result.nav_series = [
        {
            "trade_date": "2025-01-10",
            "nav": 100.0,
            "cash": 100.0,
            "market_value": 0.0,
            "realized_pnl": 0.0,
            "unrealized_pnl": 0.0,
            "turnover": 0.0,
        }
    ]
    result.risk_events = [
        {
            "trade_date": "2025-01-10",
            "ts_code": "000001.SZ",
            "reason": "limit_up",
            "action": "buy_l",
            "target_weight": 0.3,
            "confidence": 0.8,
        }
    ]

    _persist_backtest_results(cfg, result)

    with db_session(read_only=True) as conn:
        risk_row = conn.execute(
            "SELECT reason, metadata FROM bt_risk_events WHERE cfg_id = ?",
            (cfg.id,),
        ).fetchone()
        assert risk_row is not None
        assert risk_row["reason"] == "limit_up"
        metadata = json.loads(risk_row["metadata"])
        assert metadata["action"] == "buy_l"

        summary_row = conn.execute(
            "SELECT summary FROM bt_report WHERE cfg_id = ?",
            (cfg.id,),
        ).fetchone()
        # Guard like ``risk_row`` above: a missing report row should fail as
        # an assertion rather than as a TypeError when subscripting ``None``.
        assert summary_row is not None
        summary = json.loads(summary_row["summary"])
        assert summary["risk_events"] == 1
        assert summary["risk_breakdown"]["limit_up"] == 1
|