"""Reinforcement learning utilities for DecisionEnv."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional, Sequence

from .adapters import DecisionEnvAdapter

TORCH_AVAILABLE = True

try:  # pragma: no cover - exercised via integration
    from .ppo import PPOConfig, PPOTrainer, train_ppo
except ModuleNotFoundError as exc:  # pragma: no cover - optional dependency guard
    if exc.name != "torch":
        raise
    TORCH_AVAILABLE = False
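
# A minimal consumer-side sketch of the guard above. The import path
# ``decision_env.rl`` and the ``env`` object are assumptions for illustration,
# and the real ``train_ppo`` signature lives in ``.ppo``, so the call shown is
# schematic rather than authoritative:
#
#     from decision_env.rl import TORCH_AVAILABLE, PPOConfig, train_ppo
#
#     if TORCH_AVAILABLE:
#         train_ppo(env, PPOConfig(total_timesteps=50_000))  # schematic call
#     else:
#         ...  # install the torch extra, or skip PPO-based features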

if not TORCH_AVAILABLE:
    # Bind lightweight placeholders so that importing this module (and the
    # names listed in ``__all__``) still works when torch is not installed.
    # Without this guard the placeholders would shadow the real
    # implementations imported from ``.ppo`` above.

    @dataclass
    class PPOConfig:
        """Placeholder PPOConfig used when torch is unavailable."""

        total_timesteps: int = 0
        rollout_steps: int = 0
        gamma: float = 0.0
        gae_lambda: float = 0.0
        clip_range: float = 0.0
        policy_lr: float = 0.0
        value_lr: float = 0.0
        epochs: int = 0
        minibatch_size: int = 0
        entropy_coef: float = 0.0
        value_coef: float = 0.0
        max_grad_norm: float = 0.0
        hidden_sizes: Sequence[int] = field(default_factory=tuple)
        seed: Optional[int] = None

    class PPOTrainer:  # pragma: no cover - simply raises when used
        """Placeholder PPOTrainer used when torch is unavailable."""

        def __init__(self, *args, **kwargs) -> None:
            raise ModuleNotFoundError(
                "torch is required for PPO training. Please install torch before using this feature."
            )

    def train_ppo(*_args, **_kwargs):  # pragma: no cover - simply raises when used
        """Placeholder train_ppo used when torch is unavailable."""
        raise ModuleNotFoundError(
            "torch is required for PPO training. Please install torch before using this feature."
        )
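
# Illustrative config construction (the values are arbitrary examples; the
# field names follow the placeholder above, which is presumed to mirror the
# real config in ``.ppo``). Construction works with or without torch; only
# ``PPOTrainer`` and ``train_ppo`` require it:
#
#     cfg = PPOConfig(total_timesteps=100_000, gamma=0.99, hidden_sizes=(64, 64))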

__all__ = [
    "DecisionEnvAdapter",
    "PPOConfig",
    "PPOTrainer",
    "train_ppo",
    "TORCH_AVAILABLE",
]