diff --git a/app/backtest/decision_env.py b/app/backtest/decision_env.py index fb44906..cb1aa40 100644 --- a/app/backtest/decision_env.py +++ b/app/backtest/decision_env.py @@ -7,8 +7,9 @@ import copy from dataclasses import dataclass, replace from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple -from .engine import BacktestEngine, BacktestResult, BtConfig -from app.agents.game import Decision +from datetime import date + +from .engine import BacktestEngine, BacktestResult, BacktestSession, BtConfig from app.agents.registry import weight_map from app.utils.db import db_session from app.utils.logging import get_logger @@ -82,6 +83,10 @@ class DecisionEnv: self._last_department_controls: Optional[Dict[str, Dict[str, Any]]] = None self._episode = 0 self._disable_departments = bool(disable_departments) + self._engine: Optional[BacktestEngine] = None + self._session: Optional[BacktestSession] = None + self._cumulative_reward = 0.0 + self._day_index = 0 @property def action_dim(self) -> int: @@ -96,9 +101,30 @@ class DecisionEnv: self._last_metrics = None self._last_action = None self._last_department_controls = None + self._cumulative_reward = 0.0 + self._day_index = 0 + + cfg = replace(self._template_cfg) + self._engine = BacktestEngine(cfg) + self._engine.weights = weight_map(self._baseline_weights) + if self._disable_departments: + self._engine.department_manager = None + + self._clear_portfolio_records() + + self._session = self._engine.start_session() return { "episode": float(self._episode), - "baseline_return": 0.0, + "day_index": 0.0, + "date_ord": float(self._template_cfg.start_date.toordinal()), + "nav": float(self._session.state.cash), + "total_return": 0.0, + "max_drawdown": 0.0, + "volatility": 0.0, + "turnover": 0.0, + "sharpe_like": 0.0, + "trade_count": 0.0, + "risk_count": 0.0, } def step(self, action: Sequence[float]) -> Tuple[Dict[str, float], float, bool, Dict[str, object]]: @@ -117,55 +143,56 @@ class DecisionEnv: extra=LOG_EXTRA, ) - cfg = replace(self._template_cfg) - engine = BacktestEngine(cfg) + engine = self._engine + session = self._session + if engine is None or session is None: + raise RuntimeError("environment not initialised; call reset() before step()") + engine.weights = weight_map(weights) if self._disable_departments: + applied_controls = {} engine.department_manager = None - applied_controls: Dict[str, Dict[str, Any]] = {} else: applied_controls = self._apply_department_controls(engine, department_controls) - self._clear_portfolio_records() - try: - result = engine.run() + records, done = engine.step_session(session) except Exception as exc: # noqa: BLE001 LOGGER.exception("backtest failed under action", extra={**LOG_EXTRA, "error": str(exc)}) info = {"error": str(exc)} return {"failure": 1.0}, -1.0, True, info + records_list = list(records) if records is not None else [] + snapshots, trades_override = self._fetch_portfolio_records() metrics = self._compute_metrics( - result, + session.result, nav_override=snapshots if snapshots else None, trades_override=trades_override if trades_override else None, ) - reward = float(self._reward_fn(metrics)) + + total_reward = float(self._reward_fn(metrics)) + reward = total_reward - self._cumulative_reward + self._cumulative_reward = total_reward self._last_metrics = metrics - observation = { - "total_return": metrics.total_return, - "max_drawdown": metrics.max_drawdown, - "volatility": metrics.volatility, - "sharpe_like": metrics.sharpe_like, - "turnover": metrics.turnover, - 
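A minimal driver sketch for the reworked multi-step API, assuming an already-constructed `DecisionEnv` named `env` (construction as in `tests/test_decision_env.py`) and a fixed unit-interval `action` of length `env.action_dim`; `step()` now requires a prior `reset()` and returns per-day reward increments, so summing them recovers the episode-level reward:

```python
# Assumes `env` is an already-constructed DecisionEnv; step() raises RuntimeError
# if reset() was not called first.
def run_episode(env, action):
    obs = env.reset()              # starts a fresh BacktestSession
    total, done = 0.0, False
    while not done:
        obs, reward, done, info = env.step(action)
        total += reward            # per-day increments telescope to the episode reward
    return total, obs
```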
"turnover_value": metrics.turnover_value, - "trade_count": float(metrics.trade_count), - "risk_count": float(metrics.risk_count), - } + observation = self._build_observation(metrics, records_list, done) + observation["turnover_value"] = metrics.turnover_value info = { "nav_series": metrics.nav_series, "trades": metrics.trades, "weights": weights, "risk_breakdown": metrics.risk_breakdown, - "risk_events": getattr(result, "risk_events", []), + "risk_events": getattr(session.result, "risk_events", []), "portfolio_snapshots": snapshots, "portfolio_trades": trades_override, "department_controls": applied_controls, + "session_done": done, + "raw_records": records_list, } self._last_department_controls = applied_controls - return observation, reward, True, info + self._day_index += 1 + return observation, reward, done, info def _prepare_actions( self, @@ -408,6 +435,51 @@ class DecisionEnv: penalty = 0.5 * metrics.max_drawdown + risk_penalty + turnover_penalty return metrics.total_return - penalty + def _build_observation( + self, + metrics: EpisodeMetrics, + records: Sequence[Dict[str, Any]] | None, + done: bool, + ) -> Dict[str, float]: + observation: Dict[str, float] = { + "day_index": float(self._day_index + 1), + "total_return": metrics.total_return, + "max_drawdown": metrics.max_drawdown, + "volatility": metrics.volatility, + "sharpe_like": metrics.sharpe_like, + "turnover": metrics.turnover, + "trade_count": float(metrics.trade_count), + "risk_count": float(metrics.risk_count), + "done": 1.0 if done else 0.0, + } + + latest_snapshot = metrics.nav_series[-1] if metrics.nav_series else None + if latest_snapshot: + observation["nav"] = float(latest_snapshot.get("nav", 0.0) or 0.0) + observation["cash"] = float(latest_snapshot.get("cash", 0.0) or 0.0) + observation["market_value"] = float(latest_snapshot.get("market_value", 0.0) or 0.0) + trade_date = latest_snapshot.get("trade_date") + if isinstance(trade_date, date): + observation["date_ord"] = float(trade_date.toordinal()) + elif isinstance(trade_date, str): + try: + parsed = date.fromisoformat(trade_date) + except ValueError: + parsed = None + if parsed: + observation["date_ord"] = float(parsed.toordinal()) + if "turnover_ratio" in latest_snapshot and latest_snapshot["turnover_ratio"] is not None: + try: + observation["turnover_ratio"] = float(latest_snapshot["turnover_ratio"]) + except (TypeError, ValueError): + observation["turnover_ratio"] = 0.0 + + # Include a simple proxy for action effect size when available + if records: + observation["record_count"] = float(len(records)) + + return observation + @property def last_metrics(self) -> Optional[EpisodeMetrics]: return self._last_metrics diff --git a/app/backtest/engine.py b/app/backtest/engine.py index 52cba53..ab68657 100644 --- a/app/backtest/engine.py +++ b/app/backtest/engine.py @@ -4,7 +4,7 @@ from __future__ import annotations import json from dataclasses import dataclass, field from datetime import date -from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple from app.agents.base import AgentAction, AgentContext from app.agents.departments import DepartmentManager @@ -72,6 +72,15 @@ class BacktestResult: risk_events: List[Dict[str, object]] = field(default_factory=list) +@dataclass +class BacktestSession: + """Holds the mutable state for incremental backtest execution.""" + + state: PortfolioState + result: BacktestResult + current_date: date + + class BacktestEngine: """Runs the 
multi-agent game inside a daily event-driven loop.""" @@ -892,18 +901,48 @@ class BacktestEngine: ], ) + def start_session(self) -> BacktestSession: + """Initialise a new incremental backtest session.""" + + return BacktestSession( + state=PortfolioState(), + result=BacktestResult(), + current_date=self.cfg.start_date, + ) + + def step_session( + self, + session: BacktestSession, + decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None, + ) -> Tuple[Iterable[Dict[str, Any]], bool]: + """Advance the session by a single trade date. + + Returns ``(records, done)`` where ``records`` is the raw output of + :meth:`simulate_day` and ``done`` indicates whether the session + reached the end date after this step. + """ + + if session.current_date > self.cfg.end_date: + return [], True + + trade_date = session.current_date + records = self.simulate_day(trade_date, session.state, decision_callback) + self._apply_portfolio_updates(trade_date, session.state, records, session.result) + session.current_date = date.fromordinal(trade_date.toordinal() + 1) + done = session.current_date > self.cfg.end_date + return records, done + def run( self, decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None, ) -> BacktestResult: - state = PortfolioState() - result = BacktestResult() - current = self.cfg.start_date - while current <= self.cfg.end_date: - records = self.simulate_day(current, state, decision_callback) - self._apply_portfolio_updates(current, state, records, result) - current = date.fromordinal(current.toordinal() + 1) - return result + session = self.start_session() + if session.current_date > self.cfg.end_date: + return session.result + + while session.current_date <= self.cfg.end_date: + self.step_session(session, decision_callback) + return session.result def run_backtest( diff --git a/app/backtest/optimizer.py b/app/backtest/optimizer.py index 3e7bde9..3fa3c14 100644 --- a/app/backtest/optimizer.py +++ b/app/backtest/optimizer.py @@ -71,7 +71,14 @@ class EpsilonGreedyBandit: for episode in range(1, self.config.episodes + 1): action = self._select_action() self.env.reset() - obs, reward, done, info = self.env.step(action) + done = False + cumulative_reward = 0.0 + obs = {} + info: Dict[str, Any] = {} + while not done: + obs, reward, done, info = self.env.step(action) + cumulative_reward += reward + metrics = self.env.last_metrics if metrics is None: raise RuntimeError("DecisionEnv did not populate last_metrics") @@ -79,7 +86,7 @@ class EpsilonGreedyBandit: old_estimate = self._value_estimates.get(key, 0.0) count = self._counts.get(key, 0) + 1 self._counts[key] = count - self._value_estimates[key] = old_estimate + (reward - old_estimate) / count + self._value_estimates[key] = old_estimate + (cumulative_reward - old_estimate) / count action_payload = self._raw_action_mapping(action) resolved_action = self._resolved_action_mapping(action) @@ -93,7 +100,7 @@ class EpsilonGreedyBandit: experiment_id=self.config.experiment_id, strategy=self.config.strategy, action=action_payload, - reward=reward, + reward=cumulative_reward, metrics=metrics_payload, weights=info.get("weights"), ) @@ -103,7 +110,7 @@ class EpsilonGreedyBandit: episode_record = BanditEpisode( action=action_payload, resolved_action=resolved_action, - reward=reward, + reward=cumulative_reward, metrics=metrics, observation=obs, weights=info.get("weights"), @@ -113,7 +120,7 @@ class EpsilonGreedyBandit: LOGGER.info( "Bandit episode=%s reward=%.4f action=%s", episode, - reward, + 
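For reference, a sketch of how the incremental session API is meant to be driven directly, assuming `engine` is a configured `BacktestEngine`; this is functionally what `run()` now does internally, but it lets the caller inspect portfolio state between trade dates:

```python
# Assumes `engine` is a configured BacktestEngine; equivalent to engine.run().
session = engine.start_session()
done = session.current_date > engine.cfg.end_date
while not done:
    records, done = engine.step_session(session)
    # session.state and session.result are updated in place after each day
result = session.result
```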
cumulative_reward, action_payload, extra=LOG_EXTRA, ) diff --git a/app/rl/__init__.py b/app/rl/__init__.py new file mode 100644 index 0000000..a0ae04c --- /dev/null +++ b/app/rl/__init__.py @@ -0,0 +1,11 @@ +"""Reinforcement learning utilities for DecisionEnv.""" + +from .adapters import DecisionEnvAdapter +from .ppo import PPOConfig, PPOTrainer, train_ppo + +__all__ = [ + "DecisionEnvAdapter", + "PPOConfig", + "PPOTrainer", + "train_ppo", +] diff --git a/app/rl/adapters.py b/app/rl/adapters.py new file mode 100644 index 0000000..187c6e7 --- /dev/null +++ b/app/rl/adapters.py @@ -0,0 +1,57 @@ +"""Environment adapters bridging DecisionEnv to tensor-friendly interfaces.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, Iterable, List, Mapping, Sequence, Tuple + +import numpy as np + +from app.backtest.decision_env import DecisionEnv + + +@dataclass +class DecisionEnvAdapter: + """Wraps :class:`DecisionEnv` to emit numpy arrays for RL algorithms.""" + + env: DecisionEnv + observation_keys: Sequence[str] | None = None + + def __post_init__(self) -> None: + if self.observation_keys is None: + reset_obs = self.env.reset() + # Exclude bookkeeping fields not useful for learning policy values + exclude = {"episode"} + self._keys = [key for key in sorted(reset_obs.keys()) if key not in exclude] + self._last_reset_obs = reset_obs + else: + self._keys = list(self.observation_keys) + self._last_reset_obs = None + + @property + def action_dim(self) -> int: + return self.env.action_dim + + @property + def observation_dim(self) -> int: + return len(self._keys) + + def reset(self) -> Tuple[np.ndarray, Dict[str, float]]: + raw = self.env.reset() + self._last_reset_obs = raw + return self._to_array(raw), raw + + def step( + self, action: Sequence[float] + ) -> Tuple[np.ndarray, float, bool, Mapping[str, object], Mapping[str, float]]: + obs_dict, reward, done, info = self.env.step(action) + return self._to_array(obs_dict), reward, done, info, obs_dict + + def _to_array(self, payload: Mapping[str, float]) -> np.ndarray: + buffer = np.zeros(len(self._keys), dtype=np.float32) + for idx, key in enumerate(self._keys): + value = payload.get(key) + buffer[idx] = float(value) if value is not None else 0.0 + return buffer + + def keys(self) -> List[str]: + return list(self._keys) diff --git a/app/rl/ppo.py b/app/rl/ppo.py new file mode 100644 index 0000000..9198bdc --- /dev/null +++ b/app/rl/ppo.py @@ -0,0 +1,265 @@ +"""Lightweight PPO trainer tailored for DecisionEnv.""" +from __future__ import annotations + +import math +from dataclasses import dataclass, field +from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Tuple + +import numpy as np +import torch +from torch import nn +from torch.distributions import Beta + +from .adapters import DecisionEnvAdapter + + +def _init_layer(layer: nn.Module, std: float = 1.0) -> nn.Module: + if isinstance(layer, nn.Linear): + nn.init.orthogonal_(layer.weight, gain=std) + nn.init.zeros_(layer.bias) + return layer + + +class ActorNetwork(nn.Module): + def __init__(self, obs_dim: int, action_dim: int, hidden_sizes: Sequence[int]) -> None: + super().__init__() + layers: List[nn.Module] = [] + last_dim = obs_dim + for size in hidden_sizes: + layers.append(_init_layer(nn.Linear(last_dim, size), std=math.sqrt(2))) + layers.append(nn.Tanh()) + last_dim = size + self.body = nn.Sequential(*layers) + self.alpha_head = _init_layer(nn.Linear(last_dim, action_dim), std=0.01) + self.beta_head = _init_layer(nn.Linear(last_dim, 
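A usage sketch of the adapter contract, assuming `env` is an already-constructed `DecisionEnv`; note that building the adapter itself calls `env.reset()` once so the observation keys are frozen (sorted, `"episode"` excluded) for the rest of training:

```python
import numpy as np

from app.rl.adapters import DecisionEnvAdapter

# `env` is an assumed, already-constructed DecisionEnv.
adapter = DecisionEnvAdapter(env)  # constructing the adapter resets env once to freeze key order

obs_vec, raw_obs = adapter.reset()
assert obs_vec.shape == (adapter.observation_dim,)
assert obs_vec.dtype == np.float32

action = np.full(adapter.action_dim, 0.5, dtype=np.float32)  # unit-interval action vector
next_vec, reward, done, info, next_raw = adapter.step(action)
```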
action_dim), std=0.01) + + def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + hidden = self.body(obs) + alpha = torch.nn.functional.softplus(self.alpha_head(hidden)) + 1.0 + beta = torch.nn.functional.softplus(self.beta_head(hidden)) + 1.0 + return alpha, beta + + +class CriticNetwork(nn.Module): + def __init__(self, obs_dim: int, hidden_sizes: Sequence[int]) -> None: + super().__init__() + layers: List[nn.Module] = [] + last_dim = obs_dim + for size in hidden_sizes: + layers.append(_init_layer(nn.Linear(last_dim, size), std=math.sqrt(2))) + layers.append(nn.Tanh()) + last_dim = size + layers.append(_init_layer(nn.Linear(last_dim, 1), std=1.0)) + self.model = nn.Sequential(*layers) + + def forward(self, obs: torch.Tensor) -> torch.Tensor: + return self.model(obs).squeeze(-1) + + +@dataclass +class PPOConfig: + total_timesteps: int = 4096 + rollout_steps: int = 256 + gamma: float = 0.99 + gae_lambda: float = 0.95 + clip_range: float = 0.2 + policy_lr: float = 3e-4 + value_lr: float = 3e-4 + epochs: int = 8 + minibatch_size: int = 128 + entropy_coef: float = 0.01 + value_coef: float = 0.5 + max_grad_norm: float = 0.5 + hidden_sizes: Sequence[int] = (128, 128) + device: str = "cpu" + seed: Optional[int] = None + + +@dataclass +class TrainingSummary: + timesteps: int + episode_rewards: List[float] = field(default_factory=list) + episode_lengths: List[int] = field(default_factory=list) + diagnostics: List[Dict[str, float]] = field(default_factory=list) + + +class RolloutBuffer: + def __init__(self, size: int, obs_dim: int, action_dim: int, device: torch.device) -> None: + self.size = size + self.device = device + self.obs = torch.zeros((size, obs_dim), dtype=torch.float32, device=device) + self.actions = torch.zeros((size, action_dim), dtype=torch.float32, device=device) + self.log_probs = torch.zeros(size, dtype=torch.float32, device=device) + self.rewards = torch.zeros(size, dtype=torch.float32, device=device) + self.dones = torch.zeros(size, dtype=torch.float32, device=device) + self.values = torch.zeros(size, dtype=torch.float32, device=device) + self.advantages = torch.zeros(size, dtype=torch.float32, device=device) + self.returns = torch.zeros(size, dtype=torch.float32, device=device) + self._pos = 0 + + def add( + self, + obs: torch.Tensor, + action: torch.Tensor, + log_prob: torch.Tensor, + reward: float, + done: bool, + value: torch.Tensor, + ) -> None: + if self._pos >= self.size: + raise RuntimeError("rollout buffer overflow; check rollout_steps") + self.obs[self._pos].copy_(obs) + self.actions[self._pos].copy_(action) + self.log_probs[self._pos] = log_prob + self.rewards[self._pos] = reward + self.dones[self._pos] = float(done) + self.values[self._pos] = value + self._pos += 1 + + def finish(self, last_value: float, gamma: float, gae_lambda: float) -> None: + last_advantage = 0.0 + for idx in reversed(range(self._pos)): + mask = 1.0 - float(self.dones[idx]) + value = float(self.values[idx]) + delta = float(self.rewards[idx]) + gamma * last_value * mask - value + last_advantage = delta + gamma * gae_lambda * mask * last_advantage + self.advantages[idx] = last_advantage + self.returns[idx] = last_advantage + value + last_value = value + + if self._pos: + adv_view = self.advantages[: self._pos] + adv_mean = adv_view.mean() + adv_std = adv_view.std(unbiased=False) + 1e-8 + adv_view.sub_(adv_mean).div_(adv_std) + + def get_minibatches(self, batch_size: int) -> Iterable[Tuple[torch.Tensor, ...]]: + if self._pos == 0: + return [] + indices = 
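`RolloutBuffer.finish()` implements standard GAE; a plain-Python restatement (before the advantage normalisation step) with made-up numbers, useful for sanity-checking the recursion outside of torch:

```python
# delta_t = r_t + gamma * V_{t+1} * (1 - d_t) - V_t
# A_t     = delta_t + gamma * lam * (1 - d_t) * A_{t+1}
def gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    advantages, last_adv, next_value = [0.0] * len(rewards), 0.0, last_value
    for t in reversed(range(len(rewards))):
        mask = 1.0 - dones[t]                                  # cut the bootstrap at episode end
        delta = rewards[t] + gamma * next_value * mask - values[t]
        last_adv = delta + gamma * lam * mask * last_adv
        advantages[t] = last_adv
        next_value = values[t]
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns

adv, ret = gae([0.1, 0.2, -0.05], [0.0, 0.05, 0.1], [0.0, 0.0, 1.0], last_value=0.12)
```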
torch.randperm(self._pos, device=self.device) + for start in range(0, self._pos, batch_size): + end = min(start + batch_size, self._pos) + batch_idx = indices[start:end] + yield ( + self.obs[batch_idx], + self.actions[batch_idx], + self.log_probs[batch_idx], + self.advantages[batch_idx], + self.returns[batch_idx], + self.values[batch_idx], + ) + + def reset(self) -> None: + self._pos = 0 + + +class PPOTrainer: + def __init__(self, adapter: DecisionEnvAdapter, config: PPOConfig) -> None: + self.adapter = adapter + self.config = config + device = torch.device(config.device) + obs_dim = adapter.observation_dim + action_dim = adapter.action_dim + self.actor = ActorNetwork(obs_dim, action_dim, config.hidden_sizes).to(device) + self.critic = CriticNetwork(obs_dim, config.hidden_sizes).to(device) + self.policy_optimizer = torch.optim.Adam(self.actor.parameters(), lr=config.policy_lr) + self.value_optimizer = torch.optim.Adam(self.critic.parameters(), lr=config.value_lr) + self.device = device + if config.seed is not None: + torch.manual_seed(config.seed) + np.random.seed(config.seed) + + def train(self) -> TrainingSummary: + cfg = self.config + obs_array, _ = self.adapter.reset() + obs = torch.from_numpy(obs_array).to(self.device) + rollout = RolloutBuffer(cfg.rollout_steps, self.adapter.observation_dim, self.adapter.action_dim, self.device) + timesteps = 0 + episode_rewards: List[float] = [] + episode_lengths: List[int] = [] + diagnostics: List[Dict[str, float]] = [] + current_return = 0.0 + current_length = 0 + + while timesteps < cfg.total_timesteps: + rollout.reset() + steps_to_collect = min(cfg.rollout_steps, cfg.total_timesteps - timesteps) + for _ in range(steps_to_collect): + with torch.no_grad(): + alpha, beta = self.actor(obs.unsqueeze(0)) + dist = Beta(alpha, beta) + action = dist.rsample().squeeze(0) + log_prob = dist.log_prob(action).sum() + value = self.critic(obs.unsqueeze(0)).squeeze(0) + action_np = action.cpu().numpy() + next_obs_array, reward, done, info, _ = self.adapter.step(action_np) + next_obs = torch.from_numpy(next_obs_array).to(self.device) + + rollout.add(obs, action, log_prob, reward, done, value) + timesteps += 1 + current_return += reward + current_length += 1 + + if done: + episode_rewards.append(current_return) + episode_lengths.append(current_length) + current_return = 0.0 + current_length = 0 + next_obs_array, _ = self.adapter.reset() + next_obs = torch.from_numpy(next_obs_array).to(self.device) + + obs = next_obs + + if timesteps >= cfg.total_timesteps or rollout._pos >= steps_to_collect: + break + + with torch.no_grad(): + next_value = self.critic(obs.unsqueeze(0)).squeeze(0).item() + rollout.finish(last_value=next_value, gamma=cfg.gamma, gae_lambda=cfg.gae_lambda) + + for _ in range(cfg.epochs): + for (mb_obs, mb_actions, mb_log_probs, mb_adv, mb_returns, _) in rollout.get_minibatches( + cfg.minibatch_size + ): + alpha, beta = self.actor(mb_obs) + dist = Beta(alpha, beta) + new_log_probs = dist.log_prob(mb_actions).sum(dim=-1) + entropy = dist.entropy().sum(dim=-1) + ratios = torch.exp(new_log_probs - mb_log_probs) + surrogate1 = ratios * mb_adv + surrogate2 = torch.clamp(ratios, 1.0 - cfg.clip_range, 1.0 + cfg.clip_range) * mb_adv + policy_loss = -torch.min(surrogate1, surrogate2).mean() - cfg.entropy_coef * entropy.mean() + + self.policy_optimizer.zero_grad() + policy_loss.backward() + nn.utils.clip_grad_norm_(self.actor.parameters(), cfg.max_grad_norm) + self.policy_optimizer.step() + + values = self.critic(mb_obs) + value_loss = 
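The policy step in the update loop is the usual PPO clipped surrogate; a tiny standalone illustration of the clipping behaviour with made-up ratios and advantages:

```python
# L = E[min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t)], r_t = exp(new_logp - old_logp).
import torch

ratios = torch.tensor([0.5, 1.0, 1.6])
advantages = torch.tensor([1.0, 1.0, 1.0])
clip_range = 0.2
unclipped = ratios * advantages
clipped = torch.clamp(ratios, 1.0 - clip_range, 1.0 + clip_range) * advantages
policy_loss = -torch.min(unclipped, clipped).mean()  # -(0.5 + 1.0 + 1.2) / 3 = -0.9
```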
torch.nn.functional.mse_loss(values, mb_returns) + self.value_optimizer.zero_grad() + value_loss.backward() + nn.utils.clip_grad_norm_(self.critic.parameters(), cfg.max_grad_norm) + self.value_optimizer.step() + + diagnostics.append( + { + "policy_loss": float(policy_loss.detach().cpu()), + "value_loss": float(value_loss.detach().cpu()), + "entropy": float(entropy.mean().detach().cpu()), + } + ) + + return TrainingSummary( + timesteps=timesteps, + episode_rewards=episode_rewards, + episode_lengths=episode_lengths, + diagnostics=diagnostics, + ) + + +def train_ppo(adapter: DecisionEnvAdapter, config: PPOConfig) -> TrainingSummary: + """Convenience helper to run PPO training with defaults.""" + + trainer = PPOTrainer(adapter, config) + return trainer.train() diff --git a/docs/TODO.md b/docs/TODO.md index 4670b95..bebfb6e 100644 --- a/docs/TODO.md +++ b/docs/TODO.md @@ -20,10 +20,21 @@ ## 3. 决策优化与强化学习 - ✅ 扩展 `DecisionEnv` 的动作空间(提示版本、部门温度、function 调用策略等),不仅限于代理权重调节。 - 引入 Bandit / 贝叶斯优化或 RL 算法探索动作空间,并将 `portfolio_snapshots`、`portfolio_trades` 指标纳入奖励约束。 +- 将 `DecisionEnv` 改造为多步 episode,逐日输出状态(行情特征、持仓、风险事件)与动作,充分利用历史序列训练强化学习策略。 +- ✅ 基于多步环境接入 PPO / SAC 等连续动作算法,结合收益、回撤、成交成本设定奖励与约束,提升收益最大化的稳定性。 +- 在整段回测层面引入并行贝叶斯优化(TPE/BOHB)或其他全局搜索,为强化学习提供高收益初始权重与参数候选。 +- 建立离线验证与滚动前向测试流程,对新策略做回测 vs. 实盘对照,防止收益最大化策略过拟合历史数据。 - 构建实时持仓/成交数据写入链路,使线上监控与离线调参共用同一数据源。 - 借鉴 TradingAgents-CN 的做法:拆分环境与策略、提供训练脚本/配置,并输出丰富的评估指标(如 Sharpe、Sortino、基准对比)。 - 完善 `BacktestEngine` 的成交撮合、风险阈值与指标输出,让回测信号直接对接执行端,形成无人值守的自动闭环。 +### 3.1 实施步骤(建议顺序) +1. 环境重构:扩展 `DecisionEnv` 支持逐日状态/动作/奖励,完善 `BacktestEngine` 的状态保存与恢复接口,并补充必要的数据库读写钩子。 +2. 训练基线:实现基于多步环境的 PPO(或 SAC)训练脚本,定义网络结构、奖励项(收益/回撤/成交成本)和超参,先在小规模标的上验证收敛。 +3. 全局搜索:在整段回测模式下并行运行 TPE/BOHB 等贝叶斯优化,产出高收益参数作为 RL 的初始化权重或候选策略。 +4. 验证闭环:搭建滚动前向测试流水线,自动记录训练策略的回测表现与准实时对照,接入监控面板并输出风险/收益指标。 +5. 上线准备:结合实时持仓/成交链路,完善回滚与安全阈值机制,准备 A/B 或影子跟投实验,确认收益最大化策略的稳健性。 + ## 4. 
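Illustrative sketch only for the roadmap item on parallel Bayesian search (TPE/BOHB) over whole-backtest runs: one possible prototype with Optuna, which is not a dependency added by this PR; `env` is an assumed, already-constructed `DecisionEnv`, and one trial equals one full episode (the summed step rewards equal the episode-level reward, as in `EpsilonGreedyBandit`):

```python
# Illustrative only: optuna is an assumed extra dependency, `env` an assumed DecisionEnv.
import optuna


def objective(trial: optuna.Trial) -> float:
    action = [trial.suggest_float(f"a{i}", 0.0, 1.0) for i in range(env.action_dim)]
    env.reset()
    total, done = 0.0, False
    while not done:
        _, reward, done, _ = env.step(action)
        total += reward
    return total


study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=50)
best_action = [study.best_params[f"a{i}"] for i in range(env.action_dim)]
```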
测试与验证 - 补充部门上下文构造、多模型调用、回测指标生成等核心路径的单元 / 集成测试。 - 建立决策流程的回归测试用例,确保提示模板或配置调整后行为可复现。 diff --git a/requirements.txt b/requirements.txt index 333df38..624bf93 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ pytest>=7.0 feedparser>=6.0 arch>=6.1.0 scipy>=1.11.0 +torch>=2.3.0 diff --git a/scripts/train_ppo.py b/scripts/train_ppo.py new file mode 100644 index 0000000..22eef4a --- /dev/null +++ b/scripts/train_ppo.py @@ -0,0 +1,133 @@ +"""Command-line entrypoint for PPO training on DecisionEnv.""" +from __future__ import annotations + +import argparse +import json +from datetime import datetime +from pathlib import Path +from typing import List + +import numpy as np + +from app.agents.registry import default_agents +from app.backtest.decision_env import DecisionEnv, ParameterSpec +from app.backtest.engine import BtConfig +from app.rl import DecisionEnvAdapter, PPOConfig, train_ppo +from app.ui.shared import default_backtest_range +from app.utils.config import get_config + + +def _parse_universe(raw: str) -> List[str]: + return [item.strip() for item in raw.split(",") if item.strip()] + + +def build_env(args: argparse.Namespace) -> DecisionEnvAdapter: + app_cfg = get_config() + start = datetime.strptime(args.start_date, "%Y-%m-%d").date() + end = datetime.strptime(args.end_date, "%Y-%m-%d").date() + universe = _parse_universe(args.universe) + if not universe: + raise ValueError("universe must contain at least one ts_code") + + agents = default_agents() + baseline_weights = app_cfg.agent_weights.as_dict() + for agent in agents: + baseline_weights.setdefault(agent.name, 1.0) + + specs: List[ParameterSpec] = [] + for name in sorted(baseline_weights): + specs.append( + ParameterSpec( + name=f"weight_{name}", + target=f"agent_weights.{name}", + minimum=args.weight_min, + maximum=args.weight_max, + ) + ) + + bt_cfg = BtConfig( + id=args.experiment_id, + name=f"PPO-{args.experiment_id}", + start_date=start, + end_date=end, + universe=universe, + params={ + "target": args.target, + "stop": args.stop, + "hold_days": args.hold_days, + }, + method=app_cfg.decision_method, + ) + env = DecisionEnv( + bt_config=bt_cfg, + parameter_specs=specs, + baseline_weights=baseline_weights, + disable_departments=args.disable_departments, + ) + return DecisionEnvAdapter(env) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Train PPO policy on DecisionEnv") + default_start, default_end = default_backtest_range(window_days=60) + parser.add_argument("--start-date", default=str(default_start)) + parser.add_argument("--end-date", default=str(default_end)) + parser.add_argument("--universe", default="000001.SZ") + parser.add_argument("--experiment-id", default=f"ppo_{datetime.now().strftime('%Y%m%d_%H%M%S')}") + parser.add_argument("--hold-days", type=int, default=10) + parser.add_argument("--target", type=float, default=0.035) + parser.add_argument("--stop", type=float, default=-0.015) + parser.add_argument("--total-timesteps", type=int, default=4096) + parser.add_argument("--rollout-steps", type=int, default=256) + parser.add_argument("--epochs", type=int, default=8) + parser.add_argument("--minibatch-size", type=int, default=128) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--gae-lambda", type=float, default=0.95) + parser.add_argument("--clip-range", type=float, default=0.2) + parser.add_argument("--policy-lr", type=float, default=3e-4) + parser.add_argument("--value-lr", type=float, default=3e-4) + parser.add_argument("--entropy-coef", 
type=float, default=0.01) + parser.add_argument("--value-coef", type=float, default=0.5) + parser.add_argument("--max-grad-norm", type=float, default=0.5) + parser.add_argument("--hidden-sizes", default="128,128") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--weight-min", type=float, default=0.0) + parser.add_argument("--weight-max", type=float, default=1.5) + parser.add_argument("--disable-departments", action="store_true") + parser.add_argument("--output", type=Path, default=Path("ppo_training_summary.json")) + + args = parser.parse_args() + hidden_sizes = tuple(int(x) for x in args.hidden_sizes.split(",") if x.strip()) + adapter = build_env(args) + + config = PPOConfig( + total_timesteps=args.total_timesteps, + rollout_steps=args.rollout_steps, + gamma=args.gamma, + gae_lambda=args.gae_lambda, + clip_range=args.clip_range, + policy_lr=args.policy_lr, + value_lr=args.value_lr, + epochs=args.epochs, + minibatch_size=args.minibatch_size, + entropy_coef=args.entropy_coef, + value_coef=args.value_coef, + max_grad_norm=args.max_grad_norm, + hidden_sizes=hidden_sizes, + seed=args.seed, + ) + + summary = train_ppo(adapter, config) + payload = { + "timesteps": summary.timesteps, + "episode_rewards": summary.episode_rewards, + "episode_lengths": summary.episode_lengths, + "diagnostics_tail": summary.diagnostics[-10:], + "observation_keys": adapter.keys(), + } + args.output.write_text(json.dumps(payload, indent=2, ensure_ascii=False)) + print(f"Training finished. Summary written to {args.output}") + + +if __name__ == "__main__": + main() diff --git a/tests/test_decision_env.py b/tests/test_decision_env.py index d87d1f7..5a551ed 100644 --- a/tests/test_decision_env.py +++ b/tests/test_decision_env.py @@ -6,7 +6,7 @@ from datetime import date import pytest from app.backtest.decision_env import DecisionEnv, EpisodeMetrics, ParameterSpec -from app.backtest.engine import BacktestResult, BtConfig +from app.backtest.engine import BacktestResult, BacktestSession, BtConfig, PortfolioState from app.utils.config import DepartmentSettings, LLMConfig, LLMEndpoint @@ -60,42 +60,50 @@ class _StubEngine: self.department_manager = _StubManager() _StubEngine.last_instance = self - def run(self) -> BacktestResult: - result = BacktestResult() - result.nav_series = [ - { - "trade_date": "2025-01-10", - "nav": 102.0, - "cash": 50.0, - "market_value": 52.0, - "realized_pnl": 1.0, - "unrealized_pnl": 1.0, - "turnover": 20000.0, - "turnover_ratio": 0.2, - } - ] - result.trades = [ - { - "trade_date": "2025-01-10", - "ts_code": "000001.SZ", - "action": "buy", - "quantity": 100.0, - "price": 100.0, - "value": 10000.0, - "fee": 5.0, - } - ] - result.risk_events = [ - { - "trade_date": "2025-01-10", - "ts_code": "000002.SZ", - "reason": "limit_up", - "action": "buy_l", - "confidence": 0.7, - "target_weight": 0.2, - } - ] - return result + def start_session(self) -> BacktestSession: + return BacktestSession( + state=PortfolioState(), + result=BacktestResult(), + current_date=self.cfg.start_date, + ) + + def step_session(self, session: BacktestSession, *_args, **_kwargs): + if session.current_date > self.cfg.end_date: + return [], True + + payload_nav = { + "trade_date": session.current_date.isoformat(), + "nav": 102.0, + "cash": 50.0, + "market_value": 52.0, + "realized_pnl": 1.0, + "unrealized_pnl": 1.0, + "turnover": 20000.0, + "turnover_ratio": 0.2, + } + payload_trade = { + "trade_date": session.current_date.isoformat(), + "ts_code": "000001.SZ", + "action": "buy", + "quantity": 100.0, + 
"price": 100.0, + "value": 10000.0, + "fee": 5.0, + } + payload_risk = { + "trade_date": session.current_date.isoformat(), + "ts_code": "000002.SZ", + "reason": "limit_up", + "action": "buy_l", + "confidence": 0.7, + "target_weight": 0.2, + } + session.result.nav_series.append(payload_nav) + session.result.trades.append(payload_trade) + session.result.risk_events.append(payload_risk) + session.current_date = date.fromordinal(session.current_date.toordinal() + 1) + done = session.current_date > self.cfg.end_date + return [payload_nav], done _StubEngine.last_instance: _StubEngine | None = None @@ -117,6 +125,7 @@ def test_decision_env_returns_risk_metrics(monkeypatch): monkeypatch.setattr(DecisionEnv, "_clear_portfolio_records", lambda self: None) monkeypatch.setattr(DecisionEnv, "_fetch_portfolio_records", lambda self: ([], [])) + env.reset() obs, reward, done, info = env.step([0.8]) assert done is True @@ -187,6 +196,7 @@ def test_decision_env_department_controls(monkeypatch): monkeypatch.setattr(DecisionEnv, "_clear_portfolio_records", lambda self: None) monkeypatch.setattr(DecisionEnv, "_fetch_portfolio_records", lambda self: ([], [])) + env.reset() obs, reward, done, info = env.step([0.3, 1.0, 0.75, 0.0, 1.0]) assert done is True @@ -198,14 +208,33 @@ def test_decision_env_department_controls(monkeypatch): assert momentum_ctrl["prompt"] == "aggressive" assert momentum_ctrl["temperature"] == pytest.approx(0.7, abs=1e-6) assert momentum_ctrl["tool_choice"] == "none" - assert momentum_ctrl["max_rounds"] == 5 - assert env.last_department_controls == controls - engine = _StubEngine.last_instance - assert engine is not None - agent = engine.department_manager.agents["momentum"] - assert agent.settings.prompt == "aggressive" - assert agent.settings.llm.primary.temperature == pytest.approx(0.7, abs=1e-6) - assert agent.tool_choice == "none" - assert agent.max_rounds == 5 +def test_decision_env_multistep_session(monkeypatch): + cfg = BtConfig( + id="stub", + name="stub", + start_date=date(2025, 1, 10), + end_date=date(2025, 1, 12), + universe=["000001.SZ"], + params={}, + ) + specs = [ParameterSpec(name="w_mom", target="agent_weights.A_mom", minimum=0.0, maximum=1.0)] + env = DecisionEnv(bt_config=cfg, parameter_specs=specs, baseline_weights={"A_mom": 0.5}) + + monkeypatch.setattr("app.backtest.decision_env.BacktestEngine", _StubEngine) + monkeypatch.setattr(DecisionEnv, "_clear_portfolio_records", lambda self: None) + monkeypatch.setattr(DecisionEnv, "_fetch_portfolio_records", lambda self: ([], [])) + + env.reset() + obs, reward, done, info = env.step([0.6]) + assert done is False + assert obs["day_index"] == pytest.approx(1.0) + + obs2, reward2, done2, _ = env.step([0.6]) + assert done2 is False + assert obs2["day_index"] == pytest.approx(2.0) + + obs3, reward3, done3, _ = env.step([0.6]) + assert done3 is True + assert obs3["day_index"] == pytest.approx(3.0) diff --git a/tests/test_ppo_trainer.py b/tests/test_ppo_trainer.py new file mode 100644 index 0000000..879eda6 --- /dev/null +++ b/tests/test_ppo_trainer.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import math + +from app.rl.adapters import DecisionEnvAdapter +from app.rl.ppo import PPOConfig, train_ppo + + +class _DummyDecisionEnv: + action_dim = 1 + + def __init__(self) -> None: + self._step = 0 + self._episode = 0 + + def reset(self): + self._step = 0 + self._episode += 1 + return { + "day_index": 0.0, + "total_return": 0.0, + "max_drawdown": 0.0, + "volatility": 0.0, + "turnover": 0.0, + "sharpe_like": 0.0, + 
"trade_count": 0.0, + "risk_count": 0.0, + "nav": 1.0, + "cash": 1.0, + "market_value": 0.0, + "done": 0.0, + } + + def step(self, action): + value = float(action[0]) + reward = 1.0 - abs(value - 0.8) + self._step += 1 + done = self._step >= 3 + obs = { + "day_index": float(self._step), + "total_return": reward, + "max_drawdown": 0.1, + "volatility": 0.05, + "turnover": 0.1, + "sharpe_like": reward / 0.05, + "trade_count": float(self._step), + "risk_count": 0.0, + "nav": 1.0 + 0.01 * self._step, + "cash": 1.0, + "market_value": 0.0, + "done": 1.0 if done else 0.0, + } + info = {} + return obs, reward, done, info + + +def test_train_ppo_completes_with_dummy_env(): + adapter = DecisionEnvAdapter(_DummyDecisionEnv()) + config = PPOConfig( + total_timesteps=64, + rollout_steps=16, + epochs=2, + minibatch_size=8, + hidden_sizes=(32, 32), + seed=123, + ) + summary = train_ppo(adapter, config) + + assert summary.timesteps == config.total_timesteps + assert summary.episode_rewards + assert not math.isnan(summary.episode_rewards[-1]) + assert summary.diagnostics