llm-quant/app/rl/__init__.py

"""Reinforcement learning utilities for DecisionEnv."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Sequence
from .adapters import DecisionEnvAdapter
TORCH_AVAILABLE = True
try: # pragma: no cover - exercised via integration
from .ppo import PPOConfig, PPOTrainer, train_ppo
except ModuleNotFoundError as exc: # pragma: no cover - optional dependency guard
if exc.name != "torch":
raise
TORCH_AVAILABLE = False
    @dataclass
    class PPOConfig:
        """Placeholder PPOConfig used when torch is unavailable."""

        total_timesteps: int = 0
        rollout_steps: int = 0
        gamma: float = 0.0
        gae_lambda: float = 0.0
        clip_range: float = 0.0
        policy_lr: float = 0.0
        value_lr: float = 0.0
        epochs: int = 0
        minibatch_size: int = 0
        entropy_coef: float = 0.0
        value_coef: float = 0.0
        max_grad_norm: float = 0.0
        hidden_sizes: Sequence[int] = field(default_factory=tuple)
        seed: Optional[int] = None
    class PPOTrainer:  # pragma: no cover - simply raises when used
        """Placeholder PPOTrainer used when torch is unavailable."""

        def __init__(self, *args, **kwargs) -> None:
            raise ModuleNotFoundError(
                "torch is required for PPO training. Please install torch before using this feature."
            )

    def train_ppo(*_args, **_kwargs):  # pragma: no cover - simply raises when used
        raise ModuleNotFoundError(
            "torch is required for PPO training. Please install torch before using this feature."
        )

__all__ = [
"DecisionEnvAdapter",
"PPOConfig",
"PPOTrainer",
"train_ppo",
"TORCH_AVAILABLE",
]
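
# Usage sketch (illustrative only, not part of this module): callers can branch
# on TORCH_AVAILABLE so importing this package stays safe without torch.
# The `app.rl` import path, the `env` variable, and the argument shapes of
# DecisionEnvAdapter and train_ppo are assumptions about the surrounding
# project, not confirmed signatures.
#
#     from app.rl import TORCH_AVAILABLE, DecisionEnvAdapter, PPOConfig, train_ppo
#
#     if TORCH_AVAILABLE:
#         config = PPOConfig(total_timesteps=100_000)
#         train_ppo(DecisionEnvAdapter(env), config)  # env: a DecisionEnv (assumed)
#     else:
#         ...  # fall back to a non-RL decision policy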