354 lines
13 KiB
Python
354 lines
13 KiB
Python
"""Reinforcement-learning style environment wrapping the backtest engine."""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import math
|
|
from dataclasses import dataclass, replace
|
|
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
|
|
|
|
from .engine import BacktestEngine, BacktestResult, BtConfig
|
|
from app.agents.game import Decision
|
|
from app.agents.registry import weight_map
|
|
from app.utils.db import db_session
|
|
from app.utils.logging import get_logger
|
|
|
|
LOGGER = get_logger(__name__)
# Passed as ``extra=`` to this module's log calls so records are tagged with their stage.
LOG_EXTRA = dict(stage="decision_env")
|
|
|
|
|
|
@dataclass(frozen=True)
class ParameterSpec:
    """Defines how a scalar action dimension maps to strategy parameters.

    An action value is interpreted on the unit interval and rescaled onto
    ``[minimum, maximum]`` by :meth:`clamp`.
    """

    # Identifier for this action dimension.
    name: str
    # Dotted destination of the value; DecisionEnv currently only applies
    # targets of the form "agent_weights.<agent_name>".
    target: str
    # Value produced for an action of 0 (or anything below 0).
    minimum: float = 0.0
    # Value produced for an action of 1 (or anything above 1).
    maximum: float = 1.0

    def clamp(self, value: float) -> float:
        """Clip ``value`` to ``[0, 1]`` and linearly rescale it onto ``[minimum, maximum]``.

        Args:
            value: Raw action component; anything convertible by ``float()``.

        Returns:
            ``minimum + clip(value, 0, 1) * (maximum - minimum)``.

        Raises:
            ValueError: if ``value`` is NaN. Without this check NaN would silently
                map to ``maximum`` because ``min(1.0, nan)`` evaluates to ``1.0``.
        """
        raw = float(value)
        if math.isnan(raw):
            raise ValueError(f"parameter {self.name!r} received a NaN action value")
        clipped = max(0.0, min(1.0, raw))
        return self.minimum + clipped * (self.maximum - self.minimum)
|
|
|
|
|
|
@dataclass
class EpisodeMetrics:
    """Summary statistics describing the outcome of one backtest episode."""

    # Final NAV relative to the base NAV, minus 1.
    total_return: float
    # Largest peak-to-trough decline, as a fraction of the running peak.
    max_drawdown: float
    # Dispersion of NAV changes, normalized by the base NAV.
    volatility: float
    # Per-period rows (each carrying at least a "nav" entry).
    nav_series: List[Dict[str, float]]
    # Trade records produced during the episode.
    trades: List[Dict[str, object]]
    # Average per-period turnover ratio.
    turnover: float
    # Summed absolute turnover across all periods.
    turnover_value: float
    # Number of entries in ``trades``.
    trade_count: int
    # Number of risk events raised by the engine.
    risk_count: int
    # Risk-event counts keyed by reason.
    risk_breakdown: Dict[str, int]

    @property
    def sharpe_like(self) -> float:
        """Return ``total_return / volatility``; 0.0 when volatility is at most 1e-9."""
        vol = self.volatility
        return 0.0 if vol <= 1e-9 else self.total_return / vol
|
|
|
|
|
|
class DecisionEnv:
    """Thin RL-friendly wrapper that evaluates parameter actions via backtest.

    Each :meth:`step` builds agent weights from the action, runs one full
    backtest, and returns ``(observation, reward, done, info)``. Episodes are a
    single step long: ``done`` is always ``True``.

    NOTE(review): ``step`` deletes and re-reads rows in the ``portfolio_*``
    tables through ``db_session``; presumably several environments sharing one
    database would interfere with each other — confirm before parallelizing.
    """

    def __init__(
        self,
        *,
        bt_config: BtConfig,
        parameter_specs: Sequence[ParameterSpec],
        baseline_weights: Mapping[str, float],
        reward_fn: Optional[Callable[[EpisodeMetrics], float]] = None,
        disable_departments: bool = False,
    ) -> None:
        """Create the environment.

        Args:
            bt_config: Template backtest config; each step runs on a shallow
                ``dataclasses.replace`` copy of it.
            parameter_specs: One spec per action dimension, in action order.
            baseline_weights: Agent weights used for every agent that no spec overrides.
            reward_fn: Maps :class:`EpisodeMetrics` to a scalar reward; defaults to
                :meth:`_default_reward`.
            disable_departments: When true, the engine's ``department_manager`` is
                set to ``None`` before each run.
        """
        self._template_cfg = bt_config
        self._specs = list(parameter_specs)
        self._baseline_weights = dict(baseline_weights)
        self._reward_fn = reward_fn or self._default_reward
        self._last_metrics: Optional[EpisodeMetrics] = None
        self._last_action: Optional[Tuple[float, ...]] = None
        self._episode = 0
        self._disable_departments = bool(disable_departments)

    @property
    def action_dim(self) -> int:
        """Number of scalar components an action must have (one per spec)."""
        return len(self._specs)

    def reset(self) -> Dict[str, float]:
        """Start a new episode, forget the previous action/metrics, return the initial observation."""
        self._episode += 1
        self._last_metrics = None
        self._last_action = None
        return {
            "episode": float(self._episode),
            "baseline_return": 0.0,
        }

    def step(self, action: Sequence[float]) -> Tuple[Dict[str, float], float, bool, Dict[str, object]]:
        """Run one backtest under ``action`` and score it.

        Args:
            action: One value per parameter spec (see :attr:`action_dim`).

        Returns:
            ``(observation, reward, done, info)``. If the backtest raises, the
            observation is ``{"failure": 1.0}``, the reward is ``-1.0`` and
            ``info`` carries the error message and the weights that were tried.

        Raises:
            ValueError: if ``len(action)`` differs from :attr:`action_dim`.
        """
        if len(action) != self.action_dim:
            raise ValueError(f"expected action length {self.action_dim}, got {len(action)}")
        action_array = [float(val) for val in action]
        self._last_action = tuple(action_array)

        weights = self._build_weights(action_array)
        LOGGER.info("episode=%s action=%s weights=%s", self._episode, action_array, weights, extra=LOG_EXTRA)

        cfg = replace(self._template_cfg)
        engine = BacktestEngine(cfg)
        engine.weights = weight_map(weights)
        if self._disable_departments:
            engine.department_manager = None

        # Clear before the run so the rows read back afterwards belong to this run.
        self._clear_portfolio_records()

        try:
            result = engine.run()
        except Exception as exc:  # noqa: BLE001
            LOGGER.exception("backtest failed under action", extra={**LOG_EXTRA, "error": str(exc)})
            # Don't let metrics from an earlier step masquerade as this step's.
            self._last_metrics = None
            info = {"error": str(exc), "weights": weights}
            return {"failure": 1.0}, -1.0, True, info

        snapshots, trades_override = self._fetch_portfolio_records()
        metrics = self._compute_metrics(
            result,
            nav_override=snapshots if snapshots else None,
            trades_override=trades_override if trades_override else None,
        )
        reward = float(self._reward_fn(metrics))
        self._last_metrics = metrics

        observation = {
            "total_return": metrics.total_return,
            "max_drawdown": metrics.max_drawdown,
            "volatility": metrics.volatility,
            "sharpe_like": metrics.sharpe_like,
            "turnover": metrics.turnover,
            "turnover_value": metrics.turnover_value,
            "trade_count": float(metrics.trade_count),
            "risk_count": float(metrics.risk_count),
        }
        info = {
            "nav_series": metrics.nav_series,
            "trades": metrics.trades,
            "weights": weights,
            "risk_breakdown": metrics.risk_breakdown,
            "risk_events": getattr(result, "risk_events", []),
            "portfolio_snapshots": snapshots,
            "portfolio_trades": trades_override,
        }
        return observation, reward, True, info

    def _build_weights(self, action: Sequence[float]) -> Dict[str, float]:
        """Overlay the rescaled action values onto a copy of the baseline weights.

        Only specs whose target starts with ``"agent_weights."`` are applied; any
        other target is logged at debug level and ignored.
        """
        weights = dict(self._baseline_weights)
        for idx, spec in enumerate(self._specs):
            value = spec.clamp(action[idx])
            if spec.target.startswith("agent_weights."):
                agent_name = spec.target.split(".", 1)[1]
                weights[agent_name] = value
            else:
                LOGGER.debug("暂未支持的参数目标:%s", spec.target, extra=LOG_EXTRA)
        return weights

    def _compute_metrics(
        self,
        result: BacktestResult,
        *,
        nav_override: Optional[List[Dict[str, Any]]] = None,
        trades_override: Optional[List[Dict[str, Any]]] = None,
    ) -> EpisodeMetrics:
        """Derive :class:`EpisodeMetrics` from a finished backtest.

        Args:
            result: Backtest result; ``nav_series``, ``trades`` and (if present)
                ``risk_events`` are read from it.
            nav_override: Rows used instead of ``result.nav_series`` when not None.
                Each row needs a ``nav`` entry and may carry ``turnover`` and
                ``turnover_ratio``.
            trades_override: Used instead of ``result.trades`` when not None.

        The metrics are:

        - ``total_return``: last NAV / base NAV - 1. The base NAV is the first
          non-zero NAV, or 1.0 if every NAV is zero.
        - ``max_drawdown``: the largest decline from the running peak, as a fraction of that peak.
        - ``volatility``: the population std-dev of consecutive NAV differences,
          divided by the base NAV.
        - ``turnover``: the mean per-row turnover ratio.
        - ``turnover_value``: the sum of per-row turnover.
        """
        nav_series = nav_override if nav_override is not None else result.nav_series or []
        trades = (trades_override if trades_override is not None else result.trades) or []
        risk_events = getattr(result, "risk_events", []) or []
        risk_breakdown = self._risk_breakdown(risk_events)

        if not nav_series:
            return EpisodeMetrics(
                total_return=0.0,
                max_drawdown=0.0,
                volatility=0.0,
                nav_series=[],
                trades=trades,
                turnover=0.0,
                turnover_value=0.0,
                trade_count=len(trades),
                risk_count=len(risk_events),
                risk_breakdown=risk_breakdown,
            )

        nav_values = [self._to_float(row.get("nav")) for row in nav_series]
        # A zero first NAV previously led to division by zero below.
        base_nav = self._base_nav(nav_values)
        total_return = nav_values[-1] / base_nav - 1.0

        peak = nav_values[0]
        max_drawdown = 0.0
        for nav in nav_values:
            peak = max(peak, nav)
            if peak:
                max_drawdown = max(max_drawdown, (peak - nav) / peak)

        diffs = [curr - prev for prev, curr in zip(nav_values, nav_values[1:])]
        if diffs:
            mean_diff = sum(diffs) / len(diffs)
            variance = sum((diff - mean_diff) ** 2 for diff in diffs) / len(diffs)
            volatility = math.sqrt(variance) / base_nav
        else:
            volatility = 0.0

        turnover_value = 0.0
        turnover_ratios: List[float] = []
        for row in nav_series:
            turnover_raw = self._to_float(row.get("turnover"))
            turnover_value += turnover_raw
            ratio = row.get("turnover_ratio")
            if ratio is not None:
                turnover_ratios.append(self._to_float(ratio))
                continue
            # No explicit ratio: derive it from turnover relative to NAV.
            nav_val = self._to_float(row.get("nav"))
            turnover_ratios.append(turnover_raw / nav_val if nav_val > 0 else 0.0)
        avg_turnover_ratio = sum(turnover_ratios) / len(turnover_ratios) if turnover_ratios else 0.0

        return EpisodeMetrics(
            total_return=float(total_return),
            max_drawdown=float(max_drawdown),
            volatility=volatility,
            nav_series=nav_series,
            trades=trades,
            turnover=float(avg_turnover_ratio),
            turnover_value=float(turnover_value),
            trade_count=len(trades),
            risk_count=len(risk_events),
            risk_breakdown=risk_breakdown,
        )

    @staticmethod
    def _base_nav(nav_values: Sequence[float]) -> float:
        """Return the first non-zero NAV to normalize against, or 1.0 when there is none."""
        return next((nav for nav in nav_values if nav), 1.0)

    @staticmethod
    def _risk_breakdown(risk_events: Iterable[Mapping[str, Any]]) -> Dict[str, int]:
        """Count risk events per ``reason``; a missing or empty reason counts as ``"unknown"``."""
        breakdown: Dict[str, int] = {}
        for event in risk_events:
            reason = str(event.get("reason") or "unknown")
            breakdown[reason] = breakdown.get(reason, 0) + 1
        return breakdown

    @staticmethod
    def _to_float(value: Any, default: float = 0.0) -> float:
        """Convert ``value`` to float; return ``default`` for None or unconvertible input."""
        if value is None:
            return default
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _default_reward(metrics: EpisodeMetrics) -> float:
        """Compute the reward as total return minus penalties for drawdown, risk events and turnover.

        reward = total_return - (0.5 * max_drawdown + 0.05 * risk_count + 0.1 * turnover)
        """
        risk_penalty = 0.05 * metrics.risk_count
        turnover_penalty = 0.1 * metrics.turnover
        penalty = 0.5 * metrics.max_drawdown + risk_penalty + turnover_penalty
        return metrics.total_return - penalty

    @property
    def last_metrics(self) -> Optional[EpisodeMetrics]:
        """Metrics of the most recent successful step in this episode, if any."""
        return self._last_metrics

    @property
    def last_action(self) -> Optional[Tuple[float, ...]]:
        """The most recent action passed to :meth:`step` in this episode, as floats."""
        return self._last_action

    def _date_bounds(self) -> Tuple[str, str]:
        """Return the template config's start and end dates as ISO-format strings."""
        return self._template_cfg.start_date.isoformat(), self._template_cfg.end_date.isoformat()

    def _clear_portfolio_records(self) -> None:
        """Best-effort removal of portfolio rows before a run.

        Deletes *all* rows of ``portfolio_positions`` (no date filter). It also
        deletes the ``portfolio_snapshots`` and ``portfolio_trades`` rows whose
        ``trade_date`` lies within the template config's start/end dates.
        Failures are logged, not raised.

        NOTE(review): if this fails, :meth:`_fetch_portfolio_records` may read rows
        left over from an earlier run.
        """
        start, end = self._date_bounds()
        try:
            with db_session() as conn:
                conn.execute("DELETE FROM portfolio_positions")
                conn.execute(
                    "DELETE FROM portfolio_snapshots WHERE trade_date BETWEEN ? AND ?",
                    (start, end),
                )
                conn.execute(
                    "DELETE FROM portfolio_trades WHERE trade_date BETWEEN ? AND ?",
                    (start, end),
                )
        except Exception:  # noqa: BLE001
            LOGGER.exception("清理投资组合记录失败", extra=LOG_EXTRA)

    def _fetch_portfolio_records(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Read the snapshots and trades dated within the config's range.

        Returns:
            ``(snapshots, trades)`` ordered by trade date (trades also by id).
            Both lists are empty if the database read fails; the failure is logged.
        """
        start, end = self._date_bounds()
        try:
            with db_session(read_only=True) as conn:
                snapshot_rows = conn.execute(
                    """
                    SELECT trade_date, total_value, cash, invested_value,
                           unrealized_pnl, realized_pnl, net_flow, exposure, metadata
                    FROM portfolio_snapshots
                    WHERE trade_date BETWEEN ? AND ?
                    ORDER BY trade_date
                    """,
                    (start, end),
                ).fetchall()
                trade_rows = conn.execute(
                    """
                    SELECT id, trade_date, ts_code, action, quantity, price, fee, source, metadata
                    FROM portfolio_trades
                    WHERE trade_date BETWEEN ? AND ?
                    ORDER BY trade_date, id
                    """,
                    (start, end),
                ).fetchall()
        except Exception:  # noqa: BLE001
            LOGGER.exception("读取投资组合记录失败", extra=LOG_EXTRA)
            return [], []

        snapshots = [self._snapshot_from_row(row) for row in snapshot_rows]
        trades = [self._trade_from_row(row) for row in trade_rows]
        return snapshots, trades

    @classmethod
    def _snapshot_from_row(cls, row: Any) -> Dict[str, Any]:
        """Convert one ``portfolio_snapshots`` row (column-name indexable) into a NAV-series dict.

        When the JSON metadata has no ``turnover_ratio``, the ratio is derived as
        ``turnover_value / nav``. Previously it defaulted to 0.0, which silently
        zeroed the turnover penalty.
        """
        metadata = cls._loads(row["metadata"], {})
        if not isinstance(metadata, dict):
            # e.g. a JSON list; treat it as "no metadata" rather than crash on .get().
            metadata = {}
        nav = cls._to_float(row["total_value"])
        turnover = cls._to_float(metadata.get("turnover_value"))
        raw_ratio = metadata.get("turnover_ratio")
        if raw_ratio is None:
            turnover_ratio = turnover / nav if nav > 0 else 0.0
        else:
            turnover_ratio = cls._to_float(raw_ratio)
        return {
            "trade_date": row["trade_date"],
            "nav": nav,
            "cash": cls._to_float(row["cash"]),
            "market_value": cls._to_float(row["invested_value"]),
            "unrealized_pnl": cls._to_float(row["unrealized_pnl"]),
            "realized_pnl": cls._to_float(row["realized_pnl"]),
            "net_flow": cls._to_float(row["net_flow"]),
            "exposure": cls._to_float(row["exposure"]),
            "turnover": turnover,
            "turnover_ratio": turnover_ratio,
            "holdings": metadata.get("holdings"),
            "trade_count": metadata.get("trade_count"),
        }

    @classmethod
    def _trade_from_row(cls, row: Any) -> Dict[str, Any]:
        """Convert one ``portfolio_trades`` row into a trade dict with parsed metadata."""
        return {
            "id": row["id"],
            "trade_date": row["trade_date"],
            "ts_code": row["ts_code"],
            "action": row["action"],
            "quantity": cls._to_float(row["quantity"]),
            "price": cls._to_float(row["price"]),
            "fee": cls._to_float(row["fee"]),
            "source": row["source"],
            "metadata": cls._loads(row["metadata"], {}),
        }

    @staticmethod
    def _loads(payload: Any, default: Any) -> Any:
        """Decode a JSON metadata payload.

        Returns ``payload`` unchanged if it is already a dict or list. Returns the
        parsed value for valid JSON text. Returns ``default`` when the payload is
        empty, invalid JSON, or of any other type.
        """
        if not payload:
            return default
        if isinstance(payload, (dict, list)):
            return payload
        if isinstance(payload, str):
            try:
                return json.loads(payload)
            except json.JSONDecodeError:
                return default
        return default
|