"""Reinforcement-learning style environment wrapping the backtest engine.""" from __future__ import annotations import json import math from dataclasses import dataclass, replace from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple from .engine import BacktestEngine, BacktestResult, BtConfig from app.agents.game import Decision from app.agents.registry import weight_map from app.utils.db import db_session from app.utils.logging import get_logger LOGGER = get_logger(__name__) LOG_EXTRA = {"stage": "decision_env"} @dataclass(frozen=True) class ParameterSpec: """Defines how a scalar action dimension maps to strategy parameters.""" name: str target: str minimum: float = 0.0 maximum: float = 1.0 def clamp(self, value: float) -> float: clipped = max(0.0, min(1.0, float(value))) return self.minimum + clipped * (self.maximum - self.minimum) @dataclass class EpisodeMetrics: total_return: float max_drawdown: float volatility: float nav_series: List[Dict[str, float]] trades: List[Dict[str, object]] turnover: float turnover_value: float trade_count: int risk_count: int risk_breakdown: Dict[str, int] @property def sharpe_like(self) -> float: if self.volatility <= 1e-9: return 0.0 return self.total_return / self.volatility class DecisionEnv: """Thin RL-friendly wrapper that evaluates parameter actions via backtest.""" def __init__( self, *, bt_config: BtConfig, parameter_specs: Sequence[ParameterSpec], baseline_weights: Mapping[str, float], reward_fn: Optional[Callable[[EpisodeMetrics], float]] = None, disable_departments: bool = False, ) -> None: self._template_cfg = bt_config self._specs = list(parameter_specs) self._baseline_weights = dict(baseline_weights) self._reward_fn = reward_fn or self._default_reward self._last_metrics: Optional[EpisodeMetrics] = None self._last_action: Optional[Tuple[float, ...]] = None self._episode = 0 self._disable_departments = bool(disable_departments) @property def action_dim(self) -> int: return len(self._specs) def reset(self) -> Dict[str, float]: self._episode += 1 self._last_metrics = None self._last_action = None return { "episode": float(self._episode), "baseline_return": 0.0, } def step(self, action: Sequence[float]) -> Tuple[Dict[str, float], float, bool, Dict[str, object]]: if len(action) != self.action_dim: raise ValueError(f"expected action length {self.action_dim}, got {len(action)}") action_array = [float(val) for val in action] self._last_action = tuple(action_array) weights = self._build_weights(action_array) LOGGER.info("episode=%s action=%s weights=%s", self._episode, action_array, weights, extra=LOG_EXTRA) cfg = replace(self._template_cfg) engine = BacktestEngine(cfg) engine.weights = weight_map(weights) if self._disable_departments: engine.department_manager = None self._clear_portfolio_records() try: result = engine.run() except Exception as exc: # noqa: BLE001 LOGGER.exception("backtest failed under action", extra={**LOG_EXTRA, "error": str(exc)}) info = {"error": str(exc)} return {"failure": 1.0}, -1.0, True, info snapshots, trades_override = self._fetch_portfolio_records() metrics = self._compute_metrics( result, nav_override=snapshots if snapshots else None, trades_override=trades_override if trades_override else None, ) reward = float(self._reward_fn(metrics)) self._last_metrics = metrics observation = { "total_return": metrics.total_return, "max_drawdown": metrics.max_drawdown, "volatility": metrics.volatility, "sharpe_like": metrics.sharpe_like, "turnover": metrics.turnover, "turnover_value": metrics.turnover_value, "trade_count": float(metrics.trade_count), "risk_count": float(metrics.risk_count), } info = { "nav_series": metrics.nav_series, "trades": metrics.trades, "weights": weights, "risk_breakdown": metrics.risk_breakdown, "risk_events": getattr(result, "risk_events", []), "portfolio_snapshots": snapshots, "portfolio_trades": trades_override, } return observation, reward, True, info def _build_weights(self, action: Sequence[float]) -> Dict[str, float]: weights = dict(self._baseline_weights) for idx, spec in enumerate(self._specs): value = spec.clamp(action[idx]) if spec.target.startswith("agent_weights."): agent_name = spec.target.split(".", 1)[1] weights[agent_name] = value else: LOGGER.debug("暂未支持的参数目标:%s", spec.target, extra=LOG_EXTRA) return weights def _compute_metrics( self, result: BacktestResult, *, nav_override: Optional[List[Dict[str, Any]]] = None, trades_override: Optional[List[Dict[str, Any]]] = None, ) -> EpisodeMetrics: nav_series = nav_override if nav_override is not None else result.nav_series or [] trades = trades_override if trades_override is not None else result.trades if not nav_series: risk_breakdown: Dict[str, int] = {} for event in getattr(result, "risk_events", []) or []: reason = str(event.get("reason") or "unknown") risk_breakdown[reason] = risk_breakdown.get(reason, 0) + 1 return EpisodeMetrics( total_return=0.0, max_drawdown=0.0, volatility=0.0, nav_series=[], trades=trades or [], turnover=0.0, turnover_value=0.0, trade_count=len(trades or []), risk_count=len(getattr(result, "risk_events", []) or []), risk_breakdown=risk_breakdown, ) nav_values = [row.get("nav", 0.0) for row in nav_series] if not nav_values or nav_values[0] == 0: base_nav = nav_values[0] if nav_values else 1.0 else: base_nav = nav_values[0] returns = [(nav / base_nav) - 1.0 for nav in nav_values] total_return = returns[-1] peak = nav_values[0] max_drawdown = 0.0 for nav in nav_values: if nav > peak: peak = nav drawdown = (peak - nav) / peak if peak else 0.0 max_drawdown = max(max_drawdown, drawdown) diffs = [nav_values[idx] - nav_values[idx - 1] for idx in range(1, len(nav_values))] if diffs: mean_diff = sum(diffs) / len(diffs) variance = sum((diff - mean_diff) ** 2 for diff in diffs) / len(diffs) volatility = math.sqrt(variance) / base_nav else: volatility = 0.0 turnover_value = 0.0 turnover_ratios: List[float] = [] for row in nav_series: turnover_raw = float(row.get("turnover", 0.0) or 0.0) turnover_value += turnover_raw ratio = row.get("turnover_ratio") if ratio is not None: try: turnover_ratios.append(float(ratio)) continue except (TypeError, ValueError): turnover_ratios.append(0.0) continue nav_val = float(row.get("nav", 0.0) or 0.0) if nav_val > 0: turnover_ratios.append(turnover_raw / nav_val) else: turnover_ratios.append(0.0) avg_turnover_ratio = sum(turnover_ratios) / len(turnover_ratios) if turnover_ratios else 0.0 risk_events = getattr(result, "risk_events", []) or [] risk_breakdown: Dict[str, int] = {} for event in risk_events: reason = str(event.get("reason") or "unknown") risk_breakdown[reason] = risk_breakdown.get(reason, 0) + 1 return EpisodeMetrics( total_return=float(total_return), max_drawdown=float(max_drawdown), volatility=volatility, nav_series=nav_series, trades=trades or [], turnover=float(avg_turnover_ratio), turnover_value=float(turnover_value), trade_count=len(trades or []), risk_count=len(risk_events), risk_breakdown=risk_breakdown, ) @staticmethod def _default_reward(metrics: EpisodeMetrics) -> float: risk_penalty = 0.05 * metrics.risk_count turnover_penalty = 0.1 * metrics.turnover penalty = 0.5 * metrics.max_drawdown + risk_penalty + turnover_penalty return metrics.total_return - penalty @property def last_metrics(self) -> Optional[EpisodeMetrics]: return self._last_metrics @property def last_action(self) -> Optional[Tuple[float, ...]]: return self._last_action def _clear_portfolio_records(self) -> None: start = self._template_cfg.start_date.isoformat() end = self._template_cfg.end_date.isoformat() try: with db_session() as conn: conn.execute("DELETE FROM portfolio_positions") conn.execute( "DELETE FROM portfolio_snapshots WHERE trade_date BETWEEN ? AND ?", (start, end), ) conn.execute( "DELETE FROM portfolio_trades WHERE trade_date BETWEEN ? AND ?", (start, end), ) except Exception: # noqa: BLE001 LOGGER.exception("清理投资组合记录失败", extra=LOG_EXTRA) def _fetch_portfolio_records(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: start = self._template_cfg.start_date.isoformat() end = self._template_cfg.end_date.isoformat() snapshots: List[Dict[str, Any]] = [] trades: List[Dict[str, Any]] = [] try: with db_session(read_only=True) as conn: snapshot_rows = conn.execute( """ SELECT trade_date, total_value, cash, invested_value, unrealized_pnl, realized_pnl, net_flow, exposure, metadata FROM portfolio_snapshots WHERE trade_date BETWEEN ? AND ? ORDER BY trade_date """, (start, end), ).fetchall() trade_rows = conn.execute( """ SELECT id, trade_date, ts_code, action, quantity, price, fee, source, metadata FROM portfolio_trades WHERE trade_date BETWEEN ? AND ? ORDER BY trade_date, id """, (start, end), ).fetchall() except Exception: # noqa: BLE001 LOGGER.exception("读取投资组合记录失败", extra=LOG_EXTRA) return snapshots, trades for row in snapshot_rows: metadata = self._loads(row["metadata"], {}) snapshots.append( { "trade_date": row["trade_date"], "nav": float(row["total_value"] or 0.0), "cash": float(row["cash"] or 0.0), "market_value": float(row["invested_value"] or 0.0), "unrealized_pnl": float(row["unrealized_pnl"] or 0.0), "realized_pnl": float(row["realized_pnl"] or 0.0), "net_flow": float(row["net_flow"] or 0.0), "exposure": float(row["exposure"] or 0.0), "turnover": float(metadata.get("turnover_value", 0.0) or 0.0), "turnover_ratio": float(metadata.get("turnover_ratio", 0.0) or 0.0), "holdings": metadata.get("holdings"), "trade_count": metadata.get("trade_count"), } ) for row in trade_rows: metadata = self._loads(row["metadata"], {}) trades.append( { "id": row["id"], "trade_date": row["trade_date"], "ts_code": row["ts_code"], "action": row["action"], "quantity": float(row["quantity"] or 0.0), "price": float(row["price"] or 0.0), "fee": float(row["fee"] or 0.0), "source": row["source"], "metadata": metadata, } ) return snapshots, trades @staticmethod def _loads(payload: Any, default: Any) -> Any: if not payload: return default if isinstance(payload, (dict, list)): return payload if isinstance(payload, str): try: return json.loads(payload) except json.JSONDecodeError: return default return default