llm-quant/app/backtest/engine.py
"""Backtest engine skeleton for daily bar simulation."""
from __future__ import annotations
import json
from dataclasses import dataclass, field
from datetime import date, timedelta
from statistics import mean, pstdev
from typing import Any, Callable, Dict, List, Mapping, Optional
from app.agents.base import AgentContext
from app.agents.departments import DepartmentManager
from app.agents.game import Decision, decide
from app.llm.metrics import record_decision as metrics_record_decision
from app.agents.registry import default_agents
from app.utils.data_access import DataBroker
from app.utils.config import get_config
from app.utils.db import db_session
from app.utils.logging import get_logger
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "backtest"}
def _compute_momentum(values: List[float], window: int) -> float:
    """Return the simple return over ``window`` bars; ``values[0]`` is the most recent bar."""
    if window <= 0 or len(values) < window:
        return 0.0
    latest = values[0]
    past = values[window - 1]
    if latest is None or past is None or past == 0:
        return 0.0
    return (latest / past) - 1.0
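# Example (the helpers treat values[0] as the most recent bar):
#   _compute_momentum([11.0, 10.5, 10.0], window=3) -> 0.1  # 11.0 / 10.0 - 1.0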
def _compute_volatility(values: List[float], window: int) -> float:
    """Population std-dev of day-over-day returns over at most ``window`` bars."""
    if len(values) < 2 or window <= 1:
        return 0.0
    limit = min(window, len(values) - 1)
    returns: List[float] = []
    for idx in range(limit):
        current = values[idx]
        previous = values[idx + 1]
        if current is None or previous is None or previous == 0:
            continue
        returns.append((current / previous) - 1.0)
    if len(returns) < 2:
        return 0.0
    return float(pstdev(returns))
def _normalize(value: Any, factor: float) -> float:
try:
numeric = float(value)
except (TypeError, ValueError):
return 0.0
if factor <= 0:
return max(0.0, min(1.0, numeric))
return max(0.0, min(1.0, numeric / factor))
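# Example: _normalize(10.0, factor=20.0) -> 0.5; _normalize(30.0, factor=20.0) -> 1.0 (clamped).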
@dataclass
class BtConfig:
id: str
name: str
start_date: date
end_date: date
universe: List[str]
params: Dict[str, float]
method: str = "nash"
@dataclass
class PortfolioState:
cash: float = 1_000_000.0
holdings: Dict[str, float] = field(default_factory=dict)
@dataclass
class BacktestResult:
nav_series: List[Dict[str, float]] = field(default_factory=list)
trades: List[Dict[str, str]] = field(default_factory=list)
class BacktestEngine:
"""Runs the multi-agent game inside a daily event-driven loop."""
def __init__(self, cfg: BtConfig) -> None:
self.cfg = cfg
self.agents = default_agents()
app_cfg = get_config()
weight_config = app_cfg.agent_weights.as_dict() if app_cfg.agent_weights else {}
if weight_config:
self.weights = weight_config
else:
self.weights = {agent.name: 1.0 for agent in self.agents}
self.department_manager = (
DepartmentManager(app_cfg) if app_cfg.departments else None
)
self.data_broker = DataBroker()
department_scope: set[str] = set()
        for settings in (app_cfg.departments or {}).values():
department_scope.update(settings.data_scope)
base_scope = {
"daily.close",
"daily.open",
"daily.high",
"daily.low",
"daily.pct_chg",
"daily.vol",
"daily.amount",
"daily_basic.turnover_rate",
"daily_basic.turnover_rate_f",
"daily_basic.volume_ratio",
"stk_limit.up_limit",
"stk_limit.down_limit",
}
self.required_fields = sorted(base_scope | department_scope)
def load_market_data(self, trade_date: date) -> Mapping[str, Dict[str, Any]]:
"""Load per-stock feature vectors and context slices for the trade date."""
trade_date_str = trade_date.strftime("%Y%m%d")
feature_map: Dict[str, Dict[str, Any]] = {}
universe = self.cfg.universe or []
for ts_code in universe:
scope_values = self.data_broker.fetch_latest(
ts_code,
trade_date_str,
self.required_fields,
)
closes = self.data_broker.fetch_series(
"daily",
"close",
ts_code,
trade_date_str,
window=60,
)
close_values = [value for _date, value in closes]
mom20 = _compute_momentum(close_values, 20)
mom60 = _compute_momentum(close_values, 60)
volat20 = _compute_volatility(close_values, 20)
turnover_series = self.data_broker.fetch_series(
"daily_basic",
"turnover_rate",
ts_code,
trade_date_str,
window=20,
)
turnover_values = [value for _date, value in turnover_series]
turn20 = mean(turnover_values) if turnover_values else 0.0
liquidity_score = _normalize(turn20, factor=20.0)
cost_penalty = _normalize(scope_values.get("daily_basic.volume_ratio", 0.0), factor=50.0)
latest_close = scope_values.get("daily.close", 0.0)
latest_pct = scope_values.get("daily.pct_chg", 0.0)
latest_turnover = scope_values.get("daily_basic.turnover_rate", 0.0)
up_limit = scope_values.get("stk_limit.up_limit")
limit_up = False
if up_limit and latest_close:
limit_up = latest_close >= up_limit * 0.999
down_limit = scope_values.get("stk_limit.down_limit")
limit_down = False
if down_limit and latest_close:
limit_down = latest_close <= down_limit * 1.001
is_suspended = self.data_broker.fetch_flags(
"suspend",
ts_code,
trade_date_str,
"suspend_date <= ? AND (resume_date IS NULL OR resume_date > ?)",
(trade_date_str, trade_date_str),
)
features = {
"mom_20": mom20,
"mom_60": mom60,
"volat_20": volat20,
"turn_20": turn20,
"liquidity_score": liquidity_score,
"cost_penalty": cost_penalty,
"news_heat": scope_values.get("news.heat_score", 0.0),
"news_sentiment": scope_values.get("news.sentiment_index", 0.0),
"industry_heat": scope_values.get("macro.industry_heat", 0.0),
"industry_relative_mom": scope_values.get(
"macro.relative_strength",
scope_values.get("index.performance_peers", 0.0),
),
"risk_penalty": min(1.0, volat20 * 5.0),
"is_suspended": is_suspended,
"limit_up": limit_up,
"limit_down": limit_down,
"position_limit": False,
}
market_snapshot = {
"close": latest_close,
"pct_chg": latest_pct,
"turnover_rate": latest_turnover,
"volume": scope_values.get("daily.vol", 0.0),
"amount": scope_values.get("daily.amount", 0.0),
"up_limit": up_limit,
"down_limit": down_limit,
}
raw_payload = {
"scope_values": scope_values,
"close_series": closes,
"turnover_series": turnover_series,
"required_fields": self.required_fields,
}
feature_map[ts_code] = {
"features": features,
"market_snapshot": market_snapshot,
"raw": raw_payload,
}
return feature_map
def simulate_day(
self,
trade_date: date,
state: PortfolioState,
decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None,
) -> List[Decision]:
feature_map = self.load_market_data(trade_date)
decisions: List[Decision] = []
for ts_code, payload in feature_map.items():
features = payload.get("features", {})
market_snapshot = payload.get("market_snapshot", {})
raw = payload.get("raw", {})
context = AgentContext(
ts_code=ts_code,
trade_date=trade_date.isoformat(),
features=features,
market_snapshot=market_snapshot,
raw=raw,
)
decision = decide(
context,
self.agents,
self.weights,
method=self.cfg.method,
department_manager=self.department_manager,
)
try:
metrics_record_decision(
ts_code=ts_code,
trade_date=context.trade_date,
action=decision.action.value,
confidence=decision.confidence,
summary=_extract_summary(decision),
source="backtest",
departments={
code: dept.to_dict()
for code, dept in decision.department_decisions.items()
},
)
except Exception: # noqa: BLE001
LOGGER.debug("记录决策指标失败", extra=LOG_EXTRA)
decisions.append(decision)
self.record_agent_state(context, decision)
if decision_callback:
try:
decision_callback(ts_code, trade_date, context, decision)
except Exception: # noqa: BLE001
LOGGER.exception("决策回调执行失败", extra=LOG_EXTRA)
# TODO: translate decisions into fills, holdings, and NAV updates.
_ = state
return decisions
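    # A minimal sketch of the fill/NAV step the TODO in simulate_day leaves open.
    # It is not wired into simulate_day and rests on two assumptions the real
    # engine may not share: PortfolioState.holdings maps ts_code to market value
    # (not share count), and orders fill instantly at the decision's target
    # weight with no commissions or slippage.
    def _apply_decision_sketch(
        self,
        state: PortfolioState,
        ts_code: str,
        target_weight: float,
    ) -> Optional[Dict[str, Any]]:
        portfolio_value = state.cash + sum(state.holdings.values())
        if portfolio_value <= 0:
            return None
        target_value = max(0.0, float(target_weight or 0.0)) * portfolio_value
        delta = target_value - state.holdings.get(ts_code, 0.0)
        if delta > 0:
            delta = min(delta, state.cash)  # never spend more cash than available
        if abs(delta) < 1e-9:
            return None
        state.cash -= delta
        state.holdings[ts_code] = state.holdings.get(ts_code, 0.0) + delta
        return {"ts_code": ts_code, "notional": round(delta, 2)}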
def record_agent_state(self, context: AgentContext, decision: Decision) -> None:
payload = {
"trade_date": context.trade_date,
"ts_code": context.ts_code,
"action": decision.action.value,
"confidence": decision.confidence,
"department_votes": decision.department_votes,
"requires_review": decision.requires_review,
"departments": {
code: dept.to_dict()
for code, dept in decision.department_decisions.items()
},
}
combined_weights = dict(self.weights)
if self.department_manager:
for code, agent in self.department_manager.agents.items():
key = f"dept_{code}"
combined_weights[key] = agent.settings.weight
feasible_json = json.dumps(
[action.value for action in decision.feasible_actions],
ensure_ascii=False,
)
rows = []
for agent_name, weight in combined_weights.items():
action_scores = {
action.value: float(decision.utilities.get(action, {}).get(agent_name, 0.0))
for action in decision.utilities.keys()
}
best_action = decision.action.value
if action_scores:
best_action = max(action_scores.items(), key=lambda item: item[1])[0]
metadata: Dict[str, object] = {}
if agent_name.startswith("dept_"):
dept_code = agent_name.split("dept_", 1)[-1]
dept_decision = decision.department_decisions.get(dept_code)
if dept_decision:
metadata = {
"_summary": dept_decision.summary,
"_signals": dept_decision.signals,
"_risks": dept_decision.risks,
"_confidence": dept_decision.confidence,
}
if dept_decision.supplements:
metadata["_supplements"] = dept_decision.supplements
if dept_decision.dialogue:
metadata["_dialogue"] = dept_decision.dialogue
if dept_decision.telemetry:
metadata["_telemetry"] = dept_decision.telemetry
payload_json = {**action_scores, **metadata}
rows.append(
(
context.trade_date,
context.ts_code,
agent_name,
best_action,
json.dumps(payload_json, ensure_ascii=False),
feasible_json,
float(weight),
)
)
global_payload = {
"_confidence": decision.confidence,
"_target_weight": decision.target_weight,
"_department_votes": decision.department_votes,
"_requires_review": decision.requires_review,
"_scope_values": context.raw.get("scope_values", {}),
"_close_series": context.raw.get("close_series", []),
"_turnover_series": context.raw.get("turnover_series", []),
"_department_supplements": {
code: dept.supplements
for code, dept in decision.department_decisions.items()
if dept.supplements
},
"_department_dialogue": {
code: dept.dialogue
for code, dept in decision.department_decisions.items()
if dept.dialogue
},
"_department_telemetry": {
code: dept.telemetry
for code, dept in decision.department_decisions.items()
if dept.telemetry
},
}
rows.append(
(
context.trade_date,
context.ts_code,
"global",
decision.action.value,
json.dumps(global_payload, ensure_ascii=False),
feasible_json,
1.0,
)
)
try:
with db_session() as conn:
conn.executemany(
"""
INSERT OR REPLACE INTO agent_utils
(trade_date, ts_code, agent, action, utils, feasible, weight)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
rows,
)
        except Exception:  # noqa: BLE001
            LOGGER.exception("Failed to write agent_utils rows", extra=LOG_EXTRA)
_ = payload
# TODO: persist payload into bt_trades / audit tables when schema is ready.
def run(
self,
decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None,
) -> BacktestResult:
state = PortfolioState()
result = BacktestResult()
current = self.cfg.start_date
while current <= self.cfg.end_date:
decisions = self.simulate_day(current, state, decision_callback)
_ = decisions
            current += timedelta(days=1)
return result
def run_backtest(
cfg: BtConfig,
*,
decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None,
) -> BacktestResult:
engine = BacktestEngine(cfg)
result = engine.run(decision_callback=decision_callback)
with db_session() as conn:
_ = conn
# Implementation should persist bt_nav, bt_trades, and bt_report rows.
return result
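# Illustrative sketch of the persistence step run_backtest still leaves open.
# The bt_nav schema is not defined in this module, so the column names
# (cfg_id, trade_date, nav) and the nav_series dict keys below are assumptions.
def _persist_nav_sketch(cfg: BtConfig, result: BacktestResult) -> None:
    rows = [
        (cfg.id, str(point.get("trade_date", "")), float(point.get("nav", 0.0)))
        for point in result.nav_series
    ]
    if not rows:
        return
    with db_session() as conn:
        conn.executemany(
            "INSERT OR REPLACE INTO bt_nav (cfg_id, trade_date, nav) VALUES (?, ?, ?)",
            rows,
        )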
def _extract_summary(decision: Decision) -> str:
for dept_decision in decision.department_decisions.values():
summary = getattr(dept_decision, "summary", "")
if summary:
return str(summary)
return ""