# llm-quant/app/rl/ppo.py
"""Lightweight PPO trainer tailored for DecisionEnv."""
from __future__ import annotations
import math
from dataclasses import dataclass, field
from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
import numpy as np
import torch
from torch import nn
from torch.distributions import Beta
from app.utils.logging import get_logger
from .adapters import DecisionEnvAdapter
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "rl_ppo"}


def _init_layer(layer: nn.Module, std: float = 1.0) -> nn.Module:
    """Orthogonally initialise linear layers and zero their biases."""
    if isinstance(layer, nn.Linear):
        nn.init.orthogonal_(layer.weight, gain=std)
        nn.init.zeros_(layer.bias)
    return layer


class ActorNetwork(nn.Module):
    """Policy network emitting Beta-distribution concentration parameters per action dim."""

    def __init__(self, obs_dim: int, action_dim: int, hidden_sizes: Sequence[int]) -> None:
        super().__init__()
        layers: List[nn.Module] = []
        last_dim = obs_dim
        for size in hidden_sizes:
            layers.append(_init_layer(nn.Linear(last_dim, size), std=math.sqrt(2)))
            layers.append(nn.Tanh())
            last_dim = size
        self.body = nn.Sequential(*layers)
        self.alpha_head = _init_layer(nn.Linear(last_dim, action_dim), std=0.01)
        self.beta_head = _init_layer(nn.Linear(last_dim, action_dim), std=0.01)

    def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        hidden = self.body(obs)
        # softplus + 1 keeps both concentrations above 1, so the Beta density is unimodal on (0, 1).
        alpha = torch.nn.functional.softplus(self.alpha_head(hidden)) + 1.0
        beta = torch.nn.functional.softplus(self.beta_head(hidden)) + 1.0
        return alpha, beta


class CriticNetwork(nn.Module):
    """State-value network with a tanh MLP body and a scalar output head."""

    def __init__(self, obs_dim: int, hidden_sizes: Sequence[int]) -> None:
        super().__init__()
        layers: List[nn.Module] = []
        last_dim = obs_dim
        for size in hidden_sizes:
            layers.append(_init_layer(nn.Linear(last_dim, size), std=math.sqrt(2)))
            layers.append(nn.Tanh())
            last_dim = size
        layers.append(_init_layer(nn.Linear(last_dim, 1), std=1.0))
        self.model = nn.Sequential(*layers)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.model(obs).squeeze(-1)


@dataclass
class PPOConfig:
    """Hyperparameters for the PPO training loop."""

    total_timesteps: int = 4096
    rollout_steps: int = 256
    gamma: float = 0.99
    gae_lambda: float = 0.95
    clip_range: float = 0.2
    policy_lr: float = 3e-4
    value_lr: float = 3e-4
    epochs: int = 8
    minibatch_size: int = 128
    entropy_coef: float = 0.01
    value_coef: float = 0.5  # currently unused: the critic is optimized separately with an unscaled MSE loss
    max_grad_norm: float = 0.5
    hidden_sizes: Sequence[int] = (128, 128)
    device: str = "cpu"
    seed: Optional[int] = None


@dataclass
class TrainingSummary:
    """Aggregate results returned by PPOTrainer.train()."""

    timesteps: int
    episode_rewards: List[float] = field(default_factory=list)
    episode_lengths: List[int] = field(default_factory=list)
    diagnostics: List[Dict[str, float]] = field(default_factory=list)


class RolloutBuffer:
    """Fixed-size on-policy buffer storing a single rollout of transitions."""

    def __init__(self, size: int, obs_dim: int, action_dim: int, device: torch.device) -> None:
        self.size = size
        self.device = device
        self.obs = torch.zeros((size, obs_dim), dtype=torch.float32, device=device)
        self.actions = torch.zeros((size, action_dim), dtype=torch.float32, device=device)
        self.log_probs = torch.zeros(size, dtype=torch.float32, device=device)
        self.rewards = torch.zeros(size, dtype=torch.float32, device=device)
        self.dones = torch.zeros(size, dtype=torch.float32, device=device)
        self.values = torch.zeros(size, dtype=torch.float32, device=device)
        self.advantages = torch.zeros(size, dtype=torch.float32, device=device)
        self.returns = torch.zeros(size, dtype=torch.float32, device=device)
        self._pos = 0

    def add(
        self,
        obs: torch.Tensor,
        action: torch.Tensor,
        log_prob: torch.Tensor,
        reward: float,
        done: bool,
        value: torch.Tensor,
    ) -> None:
        if self._pos >= self.size:
            raise RuntimeError("rollout buffer overflow; check rollout_steps")
        self.obs[self._pos].copy_(obs)
        self.actions[self._pos].copy_(action)
        self.log_probs[self._pos] = log_prob
        self.rewards[self._pos] = reward
        self.dones[self._pos] = float(done)
        self.values[self._pos] = value
        self._pos += 1

    def finish(self, last_value: float, gamma: float, gae_lambda: float) -> None:
        """Compute GAE advantages and returns, then normalise the advantages in place."""
        last_advantage = 0.0
        for idx in reversed(range(self._pos)):
            # GAE recursion: delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
            #                A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
            mask = 1.0 - float(self.dones[idx])
            value = float(self.values[idx])
            delta = float(self.rewards[idx]) + gamma * last_value * mask - value
            last_advantage = delta + gamma * gae_lambda * mask * last_advantage
            self.advantages[idx] = last_advantage
            self.returns[idx] = last_advantage + value
            last_value = value
        if self._pos:
            adv_view = self.advantages[: self._pos]
            adv_mean = adv_view.mean()
            adv_std = adv_view.std(unbiased=False) + 1e-8
            adv_view.sub_(adv_mean).div_(adv_std)

    def get_minibatches(self, batch_size: int) -> Iterable[Tuple[torch.Tensor, ...]]:
        """Yield shuffled minibatches of (obs, actions, log_probs, advantages, returns, values)."""
        if self._pos == 0:
            return
        indices = torch.randperm(self._pos, device=self.device)
        for start in range(0, self._pos, batch_size):
            end = min(start + batch_size, self._pos)
            batch_idx = indices[start:end]
            yield (
                self.obs[batch_idx],
                self.actions[batch_idx],
                self.log_probs[batch_idx],
                self.advantages[batch_idx],
                self.returns[batch_idx],
                self.values[batch_idx],
            )

    def reset(self) -> None:
        self._pos = 0


class PPOTrainer:
    """Actor-critic PPO trainer driven by a DecisionEnvAdapter."""

    def __init__(self, adapter: DecisionEnvAdapter, config: PPOConfig) -> None:
        self.adapter = adapter
        self.config = config
        device = torch.device(config.device)
        obs_dim = adapter.observation_dim
        action_dim = adapter.action_dim
        self.actor = ActorNetwork(obs_dim, action_dim, config.hidden_sizes).to(device)
        self.critic = CriticNetwork(obs_dim, config.hidden_sizes).to(device)
        self.policy_optimizer = torch.optim.Adam(self.actor.parameters(), lr=config.policy_lr)
        self.value_optimizer = torch.optim.Adam(self.critic.parameters(), lr=config.value_lr)
        self.device = device
        if config.seed is not None:
            torch.manual_seed(config.seed)
            np.random.seed(config.seed)
        LOGGER.info(
            "Initialized PPOTrainer obs_dim=%s action_dim=%s total_timesteps=%s rollout=%s device=%s",
            obs_dim,
            action_dim,
            config.total_timesteps,
            config.rollout_steps,
            config.device,
            extra=LOG_EXTRA,
        )

    def train(self) -> TrainingSummary:
        """Run the PPO loop until config.total_timesteps environment steps have been collected."""
        cfg = self.config
        obs_array, _ = self.adapter.reset()
        obs = torch.from_numpy(obs_array).to(self.device)
        rollout = RolloutBuffer(cfg.rollout_steps, self.adapter.observation_dim, self.adapter.action_dim, self.device)
        timesteps = 0
        episode_rewards: List[float] = []
        episode_lengths: List[int] = []
        diagnostics: List[Dict[str, float]] = []
        current_return = 0.0
        current_length = 0
        LOGGER.info(
            "Starting PPO training total_timesteps=%s rollout_steps=%s epochs=%s minibatch=%s",
            cfg.total_timesteps,
            cfg.rollout_steps,
            cfg.epochs,
            cfg.minibatch_size,
            extra=LOG_EXTRA,
        )
        while timesteps < cfg.total_timesteps:
            # Collect one on-policy rollout (possibly truncated near the timestep budget).
            rollout.reset()
            steps_to_collect = min(cfg.rollout_steps, cfg.total_timesteps - timesteps)
            for _ in range(steps_to_collect):
                with torch.no_grad():
                    alpha, beta = self.actor(obs.unsqueeze(0))
                    dist = Beta(alpha, beta)
                    action = dist.rsample().squeeze(0)
                    log_prob = dist.log_prob(action).sum()
                    value = self.critic(obs.unsqueeze(0)).squeeze(0)
                action_np = action.cpu().numpy()
                next_obs_array, reward, done, info, _ = self.adapter.step(action_np)
                next_obs = torch.from_numpy(next_obs_array).to(self.device)
                rollout.add(obs, action, log_prob, reward, done, value)
                timesteps += 1
                current_return += reward
                current_length += 1
                if done:
                    episode_rewards.append(current_return)
                    episode_lengths.append(current_length)
                    LOGGER.info(
                        "Episode finished reward=%.4f length=%s episodes=%s timesteps=%s",
                        episode_rewards[-1],
                        episode_lengths[-1],
                        len(episode_rewards),
                        timesteps,
                        extra=LOG_EXTRA,
                    )
                    current_return = 0.0
                    current_length = 0
                    next_obs_array, _ = self.adapter.reset()
                    next_obs = torch.from_numpy(next_obs_array).to(self.device)
                obs = next_obs
                if timesteps >= cfg.total_timesteps or rollout._pos >= steps_to_collect:
                    break
            # Bootstrap from the critic's value of the final observation, then compute GAE.
            with torch.no_grad():
                next_value = self.critic(obs.unsqueeze(0)).squeeze(0).item()
            rollout.finish(last_value=next_value, gamma=cfg.gamma, gae_lambda=cfg.gae_lambda)
            LOGGER.debug(
                "Rollout collection finished batch_size=%s timesteps=%s remaining=%s",
                rollout._pos,
                timesteps,
                cfg.total_timesteps - timesteps,
                extra=LOG_EXTRA,
            )
            last_policy_loss = None
            last_value_loss = None
            last_entropy = None
            for _ in range(cfg.epochs):
                for (mb_obs, mb_actions, mb_log_probs, mb_adv, mb_returns, _) in rollout.get_minibatches(
                    cfg.minibatch_size
                ):
                    alpha, beta = self.actor(mb_obs)
                    dist = Beta(alpha, beta)
                    new_log_probs = dist.log_prob(mb_actions).sum(dim=-1)
                    entropy = dist.entropy().sum(dim=-1)
                    # Clipped surrogate objective with an entropy bonus.
                    ratios = torch.exp(new_log_probs - mb_log_probs)
                    surrogate1 = ratios * mb_adv
                    surrogate2 = torch.clamp(ratios, 1.0 - cfg.clip_range, 1.0 + cfg.clip_range) * mb_adv
                    policy_loss = -torch.min(surrogate1, surrogate2).mean() - cfg.entropy_coef * entropy.mean()
                    self.policy_optimizer.zero_grad()
                    policy_loss.backward()
                    nn.utils.clip_grad_norm_(self.actor.parameters(), cfg.max_grad_norm)
                    self.policy_optimizer.step()
                    values = self.critic(mb_obs)
                    value_loss = torch.nn.functional.mse_loss(values, mb_returns)
                    self.value_optimizer.zero_grad()
                    value_loss.backward()
                    nn.utils.clip_grad_norm_(self.critic.parameters(), cfg.max_grad_norm)
                    self.value_optimizer.step()
                    last_policy_loss = float(policy_loss.detach().cpu())
                    last_value_loss = float(value_loss.detach().cpu())
                    last_entropy = float(entropy.mean().detach().cpu())
                    diagnostics.append(
                        {
                            "policy_loss": last_policy_loss,
                            "value_loss": last_value_loss,
                            "entropy": last_entropy,
                        }
                    )
            LOGGER.info(
                "Optimization pass finished timesteps=%s/%s policy_loss=%.4f value_loss=%.4f entropy=%.4f",
                timesteps,
                cfg.total_timesteps,
                last_policy_loss if last_policy_loss is not None else 0.0,
                last_value_loss if last_value_loss is not None else 0.0,
                last_entropy if last_entropy is not None else 0.0,
                extra=LOG_EXTRA,
            )
        summary = TrainingSummary(
            timesteps=timesteps,
            episode_rewards=episode_rewards,
            episode_lengths=episode_lengths,
            diagnostics=diagnostics,
        )
        LOGGER.info(
            "PPO training finished timesteps=%s episodes=%s mean_reward=%.4f",
            summary.timesteps,
            len(summary.episode_rewards),
            float(np.mean(summary.episode_rewards)) if summary.episode_rewards else 0.0,
            extra=LOG_EXTRA,
        )
        return summary


def train_ppo(adapter: DecisionEnvAdapter, config: PPOConfig) -> TrainingSummary:
    """Convenience helper to run PPO training with defaults."""
    trainer = PPOTrainer(adapter, config)
    return trainer.train()
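

# Usage sketch: constructing a DecisionEnvAdapter depends on environment wiring outside this
# module, so the `env` below is hypothetical; only PPOConfig and train_ppo come from this file.
#
#     from app.rl.adapters import DecisionEnvAdapter
#     adapter = DecisionEnvAdapter(env)  # hypothetical: wrap an existing DecisionEnv instance
#     config = PPOConfig(total_timesteps=2048, rollout_steps=256, seed=7)
#     summary = train_ppo(adapter, config)
#     mean_reward = sum(summary.episode_rewards) / max(len(summary.episode_rewards), 1)
#     print(f"episodes={len(summary.episode_rewards)} mean_reward={mean_reward:.4f}")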