"""Lightweight PPO trainer tailored for DecisionEnv."""

from __future__ import annotations

import math
from dataclasses import dataclass, field
from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Tuple

import numpy as np
import torch
from torch import nn
from torch.distributions import Beta

from app.utils.logging import get_logger

from .adapters import DecisionEnvAdapter

LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "rl_ppo"}


def _init_layer(layer: nn.Module, std: float = 1.0) -> nn.Module:
    if isinstance(layer, nn.Linear):
        nn.init.orthogonal_(layer.weight, gain=std)
        nn.init.zeros_(layer.bias)
    return layer


class ActorNetwork(nn.Module):
    """MLP policy head that outputs Beta concentration parameters per action dimension."""

    def __init__(self, obs_dim: int, action_dim: int, hidden_sizes: Sequence[int]) -> None:
        super().__init__()
        layers: List[nn.Module] = []
        last_dim = obs_dim
        for size in hidden_sizes:
            layers.append(_init_layer(nn.Linear(last_dim, size), std=math.sqrt(2)))
            layers.append(nn.Tanh())
            last_dim = size
        self.body = nn.Sequential(*layers)
        self.alpha_head = _init_layer(nn.Linear(last_dim, action_dim), std=0.01)
        self.beta_head = _init_layer(nn.Linear(last_dim, action_dim), std=0.01)

    def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        hidden = self.body(obs)
        # softplus(...) + 1 keeps both concentrations above 1, so the Beta stays unimodal on (0, 1).
        alpha = torch.nn.functional.softplus(self.alpha_head(hidden)) + 1.0
        beta = torch.nn.functional.softplus(self.beta_head(hidden)) + 1.0
        return alpha, beta


class CriticNetwork(nn.Module):
    """MLP state-value estimator."""

    def __init__(self, obs_dim: int, hidden_sizes: Sequence[int]) -> None:
        super().__init__()
        layers: List[nn.Module] = []
        last_dim = obs_dim
        for size in hidden_sizes:
            layers.append(_init_layer(nn.Linear(last_dim, size), std=math.sqrt(2)))
            layers.append(nn.Tanh())
            last_dim = size
        layers.append(_init_layer(nn.Linear(last_dim, 1), std=1.0))
        self.model = nn.Sequential(*layers)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.model(obs).squeeze(-1)


@dataclass
class PPOConfig:
    total_timesteps: int = 4096
    rollout_steps: int = 256
    gamma: float = 0.99
    gae_lambda: float = 0.95
    clip_range: float = 0.2
    policy_lr: float = 3e-4
    value_lr: float = 3e-4
    epochs: int = 8
    minibatch_size: int = 128
    entropy_coef: float = 0.01
    value_coef: float = 0.5
    max_grad_norm: float = 0.5
    hidden_sizes: Sequence[int] = (128, 128)
    device: str = "cpu"
    seed: Optional[int] = None


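# Arithmetic note on the defaults above (illustrative only, not a recommendation):
# a 256-step rollout with minibatch_size=128 splits into 2 minibatches per epoch,
# so 8 epochs give 16 policy/value gradient updates per rollout, and 4096 total
# timesteps amount to 16 rollout/update cycles per training run.

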
@dataclass
class TrainingSummary:
    timesteps: int
    episode_rewards: List[float] = field(default_factory=list)
    episode_lengths: List[int] = field(default_factory=list)
    diagnostics: List[Dict[str, float]] = field(default_factory=list)


class RolloutBuffer:
    def __init__(self, size: int, obs_dim: int, action_dim: int, device: torch.device) -> None:
        self.size = size
        self.device = device
        self.obs = torch.zeros((size, obs_dim), dtype=torch.float32, device=device)
        self.actions = torch.zeros((size, action_dim), dtype=torch.float32, device=device)
        self.log_probs = torch.zeros(size, dtype=torch.float32, device=device)
        self.rewards = torch.zeros(size, dtype=torch.float32, device=device)
        self.dones = torch.zeros(size, dtype=torch.float32, device=device)
        self.values = torch.zeros(size, dtype=torch.float32, device=device)
        self.advantages = torch.zeros(size, dtype=torch.float32, device=device)
        self.returns = torch.zeros(size, dtype=torch.float32, device=device)
        self._pos = 0

    def add(
        self,
        obs: torch.Tensor,
        action: torch.Tensor,
        log_prob: torch.Tensor,
        reward: float,
        done: bool,
        value: torch.Tensor,
    ) -> None:
        if self._pos >= self.size:
            raise RuntimeError("rollout buffer overflow; check rollout_steps")
        self.obs[self._pos].copy_(obs)
        self.actions[self._pos].copy_(action)
        self.log_probs[self._pos] = log_prob
        self.rewards[self._pos] = reward
        self.dones[self._pos] = float(done)
        self.values[self._pos] = value
        self._pos += 1

    def finish(self, last_value: float, gamma: float, gae_lambda: float) -> None:
        # Walk the rollout backwards, accumulating GAE advantages and bootstrapped returns.
        last_advantage = 0.0
        for idx in reversed(range(self._pos)):
            mask = 1.0 - float(self.dones[idx])
            value = float(self.values[idx])
            delta = float(self.rewards[idx]) + gamma * last_value * mask - value
            last_advantage = delta + gamma * gae_lambda * mask * last_advantage
            self.advantages[idx] = last_advantage
            self.returns[idx] = last_advantage + value
            last_value = value

        if self._pos:
            # Normalize advantages in place over the filled portion of the buffer.
            adv_view = self.advantages[: self._pos]
            adv_mean = adv_view.mean()
            adv_std = adv_view.std(unbiased=False) + 1e-8
            adv_view.sub_(adv_mean).div_(adv_std)

    def get_minibatches(self, batch_size: int) -> Iterable[Tuple[torch.Tensor, ...]]:
        if self._pos == 0:
            return
        indices = torch.randperm(self._pos, device=self.device)
        for start in range(0, self._pos, batch_size):
            end = min(start + batch_size, self._pos)
            batch_idx = indices[start:end]
            yield (
                self.obs[batch_idx],
                self.actions[batch_idx],
                self.log_probs[batch_idx],
                self.advantages[batch_idx],
                self.returns[batch_idx],
                self.values[batch_idx],
            )

    def reset(self) -> None:
        self._pos = 0


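# RolloutBuffer.finish implements standard GAE; as a sketch, the backwards loop computes
#     delta_t = r_t + gamma * V_{t+1} * (1 - done_t) - V_t
#     A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
#     R_t     = A_t + V_t
# and then normalizes the advantages in place over the filled slots.

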
class PPOTrainer:
    """Single-environment PPO trainer pairing a Beta-policy actor with a value critic."""

    def __init__(self, adapter: DecisionEnvAdapter, config: PPOConfig) -> None:
        self.adapter = adapter
        self.config = config
        device = torch.device(config.device)
        obs_dim = adapter.observation_dim
        action_dim = adapter.action_dim
        self.actor = ActorNetwork(obs_dim, action_dim, config.hidden_sizes).to(device)
        self.critic = CriticNetwork(obs_dim, config.hidden_sizes).to(device)
        self.policy_optimizer = torch.optim.Adam(self.actor.parameters(), lr=config.policy_lr)
        self.value_optimizer = torch.optim.Adam(self.critic.parameters(), lr=config.value_lr)
        self.device = device
        if config.seed is not None:
            torch.manual_seed(config.seed)
            np.random.seed(config.seed)
        LOGGER.info(
            "Initialized PPOTrainer obs_dim=%s action_dim=%s total_timesteps=%s rollout=%s device=%s",
            obs_dim,
            action_dim,
            config.total_timesteps,
            config.rollout_steps,
            config.device,
            extra=LOG_EXTRA,
        )

    def train(self) -> TrainingSummary:
        """Collect rollouts and run clipped-PPO updates until cfg.total_timesteps steps are gathered."""
        cfg = self.config
        obs_array, _ = self.adapter.reset()
        obs = torch.from_numpy(obs_array).to(self.device)
        rollout = RolloutBuffer(cfg.rollout_steps, self.adapter.observation_dim, self.adapter.action_dim, self.device)
        timesteps = 0
        episode_rewards: List[float] = []
        episode_lengths: List[int] = []
        diagnostics: List[Dict[str, float]] = []
        current_return = 0.0
        current_length = 0
        LOGGER.info(
            "Starting PPO training total_timesteps=%s rollout_steps=%s epochs=%s minibatch=%s",
            cfg.total_timesteps,
            cfg.rollout_steps,
            cfg.epochs,
            cfg.minibatch_size,
            extra=LOG_EXTRA,
        )

        while timesteps < cfg.total_timesteps:
            rollout.reset()
            steps_to_collect = min(cfg.rollout_steps, cfg.total_timesteps - timesteps)
            for _ in range(steps_to_collect):
                with torch.no_grad():
                    alpha, beta = self.actor(obs.unsqueeze(0))
                    dist = Beta(alpha, beta)
                    action = dist.rsample().squeeze(0)
                    log_prob = dist.log_prob(action).sum()
                    value = self.critic(obs.unsqueeze(0)).squeeze(0)
                action_np = action.cpu().numpy()
                next_obs_array, reward, done, info, _ = self.adapter.step(action_np)
                next_obs = torch.from_numpy(next_obs_array).to(self.device)

                rollout.add(obs, action, log_prob, reward, done, value)
                timesteps += 1
                current_return += reward
                current_length += 1

                if done:
                    episode_rewards.append(current_return)
                    episode_lengths.append(current_length)
                    LOGGER.info(
                        "Episode finished reward=%.4f length=%s episodes=%s timesteps=%s",
                        episode_rewards[-1],
                        episode_lengths[-1],
                        len(episode_rewards),
                        timesteps,
                        extra=LOG_EXTRA,
                    )
                    current_return = 0.0
                    current_length = 0
                    next_obs_array, _ = self.adapter.reset()
                    next_obs = torch.from_numpy(next_obs_array).to(self.device)

                obs = next_obs

                if timesteps >= cfg.total_timesteps or rollout._pos >= steps_to_collect:
                    break

            with torch.no_grad():
                next_value = self.critic(obs.unsqueeze(0)).squeeze(0).item()
            rollout.finish(last_value=next_value, gamma=cfg.gamma, gae_lambda=cfg.gae_lambda)
            LOGGER.debug(
                "Rollout collection finished batch_size=%s timesteps=%s remaining=%s",
                rollout._pos,
                timesteps,
                cfg.total_timesteps - timesteps,
                extra=LOG_EXTRA,
            )

            last_policy_loss = None
            last_value_loss = None
            last_entropy = None
            for _ in range(cfg.epochs):
                for (mb_obs, mb_actions, mb_log_probs, mb_adv, mb_returns, _) in rollout.get_minibatches(
                    cfg.minibatch_size
                ):
                    alpha, beta = self.actor(mb_obs)
                    dist = Beta(alpha, beta)
                    new_log_probs = dist.log_prob(mb_actions).sum(dim=-1)
                    entropy = dist.entropy().sum(dim=-1)
                    ratios = torch.exp(new_log_probs - mb_log_probs)
                    surrogate1 = ratios * mb_adv
                    surrogate2 = torch.clamp(ratios, 1.0 - cfg.clip_range, 1.0 + cfg.clip_range) * mb_adv
                    policy_loss = -torch.min(surrogate1, surrogate2).mean() - cfg.entropy_coef * entropy.mean()

                    self.policy_optimizer.zero_grad()
                    policy_loss.backward()
                    nn.utils.clip_grad_norm_(self.actor.parameters(), cfg.max_grad_norm)
                    self.policy_optimizer.step()

                    values = self.critic(mb_obs)
                    value_loss = torch.nn.functional.mse_loss(values, mb_returns)
                    self.value_optimizer.zero_grad()
                    value_loss.backward()
                    nn.utils.clip_grad_norm_(self.critic.parameters(), cfg.max_grad_norm)
                    self.value_optimizer.step()
                    last_policy_loss = float(policy_loss.detach().cpu())
                    last_value_loss = float(value_loss.detach().cpu())
                    last_entropy = float(entropy.mean().detach().cpu())

            diagnostics.append(
                {
                    "policy_loss": last_policy_loss if last_policy_loss is not None else 0.0,
                    "value_loss": last_value_loss if last_value_loss is not None else 0.0,
                    "entropy": last_entropy if last_entropy is not None else 0.0,
                }
            )
            LOGGER.info(
                "Optimization pass finished timesteps=%s/%s policy_loss=%.4f value_loss=%.4f entropy=%.4f",
                timesteps,
                cfg.total_timesteps,
                last_policy_loss if last_policy_loss is not None else 0.0,
                last_value_loss if last_value_loss is not None else 0.0,
                last_entropy if last_entropy is not None else 0.0,
                extra=LOG_EXTRA,
            )

        summary = TrainingSummary(
            timesteps=timesteps,
            episode_rewards=episode_rewards,
            episode_lengths=episode_lengths,
            diagnostics=diagnostics,
        )
        LOGGER.info(
            "PPO training finished timesteps=%s episodes=%s mean_reward=%.4f",
            summary.timesteps,
            len(summary.episode_rewards),
            float(np.mean(summary.episode_rewards)) if summary.episode_rewards else 0.0,
            extra=LOG_EXTRA,
        )
        return summary


def train_ppo(adapter: DecisionEnvAdapter, config: PPOConfig) -> TrainingSummary:
    """Convenience helper to run PPO training with defaults."""

    trainer = PPOTrainer(adapter, config)
    return trainer.train()
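

# Example wiring (illustrative sketch only; the adapter construction shown here
# is assumed, not taken from .adapters, and may differ from the real signature):
#
#     adapter = DecisionEnvAdapter(decision_env)   # decision_env built elsewhere
#     config = PPOConfig(total_timesteps=8192, rollout_steps=512, seed=42)
#     summary = train_ppo(adapter, config)
#     LOGGER.info("episodes=%s", len(summary.episode_rewards), extra=LOG_EXTRA)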