update

parent dc2d82f685
commit 6d3afcf555
app/backtest/decision_env.py

@@ -7,8 +7,9 @@ import copy
from dataclasses import dataclass, replace
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple

from .engine import BacktestEngine, BacktestResult, BtConfig
from app.agents.game import Decision
from datetime import date

from .engine import BacktestEngine, BacktestResult, BacktestSession, BtConfig
from app.agents.registry import weight_map
from app.utils.db import db_session
from app.utils.logging import get_logger
@@ -82,6 +83,10 @@ class DecisionEnv:
        self._last_department_controls: Optional[Dict[str, Dict[str, Any]]] = None
        self._episode = 0
        self._disable_departments = bool(disable_departments)
        self._engine: Optional[BacktestEngine] = None
        self._session: Optional[BacktestSession] = None
        self._cumulative_reward = 0.0
        self._day_index = 0

    @property
    def action_dim(self) -> int:
@@ -96,9 +101,30 @@ class DecisionEnv:
        self._last_metrics = None
        self._last_action = None
        self._last_department_controls = None
        self._cumulative_reward = 0.0
        self._day_index = 0

        cfg = replace(self._template_cfg)
        self._engine = BacktestEngine(cfg)
        self._engine.weights = weight_map(self._baseline_weights)
        if self._disable_departments:
            self._engine.department_manager = None

        self._clear_portfolio_records()

        self._session = self._engine.start_session()
        return {
            "episode": float(self._episode),
            "baseline_return": 0.0,
            "day_index": 0.0,
            "date_ord": float(self._template_cfg.start_date.toordinal()),
            "nav": float(self._session.state.cash),
            "total_return": 0.0,
            "max_drawdown": 0.0,
            "volatility": 0.0,
            "turnover": 0.0,
            "sharpe_like": 0.0,
            "trade_count": 0.0,
            "risk_count": 0.0,
        }

    def step(self, action: Sequence[float]) -> Tuple[Dict[str, float], float, bool, Dict[str, object]]:
@@ -117,55 +143,56 @@ class DecisionEnv:
                extra=LOG_EXTRA,
            )

        cfg = replace(self._template_cfg)
        engine = BacktestEngine(cfg)
        engine = self._engine
        session = self._session
        if engine is None or session is None:
            raise RuntimeError("environment not initialised; call reset() before step()")

        engine.weights = weight_map(weights)
        if self._disable_departments:
            applied_controls = {}
            engine.department_manager = None
            applied_controls: Dict[str, Dict[str, Any]] = {}
        else:
            applied_controls = self._apply_department_controls(engine, department_controls)

        self._clear_portfolio_records()

        try:
            result = engine.run()
            records, done = engine.step_session(session)
        except Exception as exc:  # noqa: BLE001
            LOGGER.exception("backtest failed under action", extra={**LOG_EXTRA, "error": str(exc)})
            info = {"error": str(exc)}
            return {"failure": 1.0}, -1.0, True, info

        records_list = list(records) if records is not None else []

        snapshots, trades_override = self._fetch_portfolio_records()
        metrics = self._compute_metrics(
            result,
            session.result,
            nav_override=snapshots if snapshots else None,
            trades_override=trades_override if trades_override else None,
        )
        reward = float(self._reward_fn(metrics))

        total_reward = float(self._reward_fn(metrics))
        reward = total_reward - self._cumulative_reward
        self._cumulative_reward = total_reward
        self._last_metrics = metrics

        observation = {
            "total_return": metrics.total_return,
            "max_drawdown": metrics.max_drawdown,
            "volatility": metrics.volatility,
            "sharpe_like": metrics.sharpe_like,
            "turnover": metrics.turnover,
            "turnover_value": metrics.turnover_value,
            "trade_count": float(metrics.trade_count),
            "risk_count": float(metrics.risk_count),
        }
        observation = self._build_observation(metrics, records_list, done)
        observation["turnover_value"] = metrics.turnover_value
        info = {
            "nav_series": metrics.nav_series,
            "trades": metrics.trades,
            "weights": weights,
            "risk_breakdown": metrics.risk_breakdown,
            "risk_events": getattr(result, "risk_events", []),
            "risk_events": getattr(session.result, "risk_events", []),
            "portfolio_snapshots": snapshots,
            "portfolio_trades": trades_override,
            "department_controls": applied_controls,
            "session_done": done,
            "raw_records": records_list,
        }
        self._last_department_controls = applied_controls
        return observation, reward, True, info
        self._day_index += 1
        return observation, reward, done, info

    def _prepare_actions(
        self,
@@ -408,6 +435,51 @@ class DecisionEnv:
        penalty = 0.5 * metrics.max_drawdown + risk_penalty + turnover_penalty
        return metrics.total_return - penalty

    def _build_observation(
        self,
        metrics: EpisodeMetrics,
        records: Sequence[Dict[str, Any]] | None,
        done: bool,
    ) -> Dict[str, float]:
        observation: Dict[str, float] = {
            "day_index": float(self._day_index + 1),
            "total_return": metrics.total_return,
            "max_drawdown": metrics.max_drawdown,
            "volatility": metrics.volatility,
            "sharpe_like": metrics.sharpe_like,
            "turnover": metrics.turnover,
            "trade_count": float(metrics.trade_count),
            "risk_count": float(metrics.risk_count),
            "done": 1.0 if done else 0.0,
        }

        latest_snapshot = metrics.nav_series[-1] if metrics.nav_series else None
        if latest_snapshot:
            observation["nav"] = float(latest_snapshot.get("nav", 0.0) or 0.0)
            observation["cash"] = float(latest_snapshot.get("cash", 0.0) or 0.0)
            observation["market_value"] = float(latest_snapshot.get("market_value", 0.0) or 0.0)
            trade_date = latest_snapshot.get("trade_date")
            if isinstance(trade_date, date):
                observation["date_ord"] = float(trade_date.toordinal())
            elif isinstance(trade_date, str):
                try:
                    parsed = date.fromisoformat(trade_date)
                except ValueError:
                    parsed = None
                if parsed:
                    observation["date_ord"] = float(parsed.toordinal())
            if "turnover_ratio" in latest_snapshot and latest_snapshot["turnover_ratio"] is not None:
                try:
                    observation["turnover_ratio"] = float(latest_snapshot["turnover_ratio"])
                except (TypeError, ValueError):
                    observation["turnover_ratio"] = 0.0

        # Include a simple proxy for action effect size when available
        if records:
            observation["record_count"] = float(len(records))

        return observation

    @property
    def last_metrics(self) -> Optional[EpisodeMetrics]:
        return self._last_metrics
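For reviewers, a minimal sketch of how the reworked environment is meant to be driven after this change (assuming `env` is a constructed DecisionEnv and `action` a valid action vector; illustrative only, not part of the commit):

    # One multi-step episode: reset() builds a fresh engine/session, each step()
    # advances a single trade date until the configured end_date is reached.
    obs = env.reset()
    done = False
    episode_reward = 0.0
    while not done:
        obs, reward, done, info = env.step(action)
        episode_reward += reward
    # Because each step returns reward_fn(metrics) minus the previous cumulative
    # value, the per-step rewards telescope: episode_reward ends up equal to the
    # final reward_fn(metrics) for the whole episode.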
app/backtest/engine.py

@@ -4,7 +4,7 @@ from __future__ import annotations
import json
from dataclasses import dataclass, field
from datetime import date
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple

from app.agents.base import AgentAction, AgentContext
from app.agents.departments import DepartmentManager
@@ -72,6 +72,15 @@ class BacktestResult:
    risk_events: List[Dict[str, object]] = field(default_factory=list)


@dataclass
class BacktestSession:
    """Holds the mutable state for incremental backtest execution."""

    state: PortfolioState
    result: BacktestResult
    current_date: date


class BacktestEngine:
    """Runs the multi-agent game inside a daily event-driven loop."""

@@ -892,18 +901,48 @@ class BacktestEngine:
            ],
        )

    def start_session(self) -> BacktestSession:
        """Initialise a new incremental backtest session."""

        return BacktestSession(
            state=PortfolioState(),
            result=BacktestResult(),
            current_date=self.cfg.start_date,
        )

    def step_session(
        self,
        session: BacktestSession,
        decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None,
    ) -> Tuple[Iterable[Dict[str, Any]], bool]:
        """Advance the session by a single trade date.

        Returns ``(records, done)`` where ``records`` is the raw output of
        :meth:`simulate_day` and ``done`` indicates whether the session
        reached the end date after this step.
        """

        if session.current_date > self.cfg.end_date:
            return [], True

        trade_date = session.current_date
        records = self.simulate_day(trade_date, session.state, decision_callback)
        self._apply_portfolio_updates(trade_date, session.state, records, session.result)
        session.current_date = date.fromordinal(trade_date.toordinal() + 1)
        done = session.current_date > self.cfg.end_date
        return records, done

    def run(
        self,
        decision_callback: Optional[Callable[[str, date, AgentContext, Decision], None]] = None,
    ) -> BacktestResult:
        state = PortfolioState()
        result = BacktestResult()
        current = self.cfg.start_date
        while current <= self.cfg.end_date:
            records = self.simulate_day(current, state, decision_callback)
            self._apply_portfolio_updates(current, state, records, result)
            current = date.fromordinal(current.toordinal() + 1)
        return result
        session = self.start_session()
        if session.current_date > self.cfg.end_date:
            return session.result

        while session.current_date <= self.cfg.end_date:
            self.step_session(session, decision_callback)
        return session.result


def run_backtest(
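A quick sketch of the incremental API added above, for readers comparing it with the rewritten run() (assumes a prepared BtConfig named `cfg`; not part of the commit):

    # Drive the backtest one trade date at a time instead of calling run().
    engine = BacktestEngine(cfg)
    session = engine.start_session()   # fresh PortfolioState + BacktestResult
    done = session.current_date > cfg.end_date
    while not done:
        records, done = engine.step_session(session)
        # `records` is the raw simulate_day output for that date;
        # `session.result` accumulates nav_series, trades and risk_events.
    final_result = session.result      # equivalent to what run() returns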
@@ -71,7 +71,14 @@ class EpsilonGreedyBandit:
        for episode in range(1, self.config.episodes + 1):
            action = self._select_action()
            self.env.reset()
            done = False
            cumulative_reward = 0.0
            obs = {}
            info: Dict[str, Any] = {}
            while not done:
                obs, reward, done, info = self.env.step(action)
                cumulative_reward += reward

            metrics = self.env.last_metrics
            if metrics is None:
                raise RuntimeError("DecisionEnv did not populate last_metrics")
@@ -79,7 +86,7 @@ class EpsilonGreedyBandit:
            old_estimate = self._value_estimates.get(key, 0.0)
            count = self._counts.get(key, 0) + 1
            self._counts[key] = count
            self._value_estimates[key] = old_estimate + (reward - old_estimate) / count
            self._value_estimates[key] = old_estimate + (cumulative_reward - old_estimate) / count

            action_payload = self._raw_action_mapping(action)
            resolved_action = self._resolved_action_mapping(action)
@@ -93,7 +100,7 @@ class EpsilonGreedyBandit:
                experiment_id=self.config.experiment_id,
                strategy=self.config.strategy,
                action=action_payload,
                reward=reward,
                reward=cumulative_reward,
                metrics=metrics_payload,
                weights=info.get("weights"),
            )
@@ -103,7 +110,7 @@ class EpsilonGreedyBandit:
            episode_record = BanditEpisode(
                action=action_payload,
                resolved_action=resolved_action,
                reward=reward,
                reward=cumulative_reward,
                metrics=metrics,
                observation=obs,
                weights=info.get("weights"),
@@ -113,7 +120,7 @@ class EpsilonGreedyBandit:
            LOGGER.info(
                "Bandit episode=%s reward=%.4f action=%s",
                episode,
                reward,
                cumulative_reward,
                action_payload,
                extra=LOG_EXTRA,
            )
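The switch from `reward` to `cumulative_reward` means each arm is now credited with the summed per-step rewards of its multi-step episode; the update rule itself is the usual incremental mean, illustrated below (illustrative only, not project code):

    # est_n = est_{n-1} + (x_n - est_{n-1}) / n keeps the arithmetic mean of
    # all episode returns observed so far for an action key.
    def running_mean(values):
        estimate = 0.0
        for count, x in enumerate(values, start=1):
            estimate += (x - estimate) / count
        return estimate

    assert abs(running_mean([1.0, 2.0, 6.0]) - 3.0) < 1e-9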
app/rl/__init__.py (new file, 11 lines)

@@ -0,0 +1,11 @@
"""Reinforcement learning utilities for DecisionEnv."""
|
||||
|
||||
from .adapters import DecisionEnvAdapter
|
||||
from .ppo import PPOConfig, PPOTrainer, train_ppo
|
||||
|
||||
__all__ = [
|
||||
"DecisionEnvAdapter",
|
||||
"PPOConfig",
|
||||
"PPOTrainer",
|
||||
"train_ppo",
|
||||
]
|
||||
app/rl/adapters.py (new file, 57 lines)

@@ -0,0 +1,57 @@
"""Environment adapters bridging DecisionEnv to tensor-friendly interfaces."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Iterable, List, Mapping, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from app.backtest.decision_env import DecisionEnv
|
||||
|
||||
|
||||
@dataclass
|
||||
class DecisionEnvAdapter:
|
||||
"""Wraps :class:`DecisionEnv` to emit numpy arrays for RL algorithms."""
|
||||
|
||||
env: DecisionEnv
|
||||
observation_keys: Sequence[str] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.observation_keys is None:
|
||||
reset_obs = self.env.reset()
|
||||
# Exclude bookkeeping fields not useful for learning policy values
|
||||
exclude = {"episode"}
|
||||
self._keys = [key for key in sorted(reset_obs.keys()) if key not in exclude]
|
||||
self._last_reset_obs = reset_obs
|
||||
else:
|
||||
self._keys = list(self.observation_keys)
|
||||
self._last_reset_obs = None
|
||||
|
||||
@property
|
||||
def action_dim(self) -> int:
|
||||
return self.env.action_dim
|
||||
|
||||
@property
|
||||
def observation_dim(self) -> int:
|
||||
return len(self._keys)
|
||||
|
||||
def reset(self) -> Tuple[np.ndarray, Dict[str, float]]:
|
||||
raw = self.env.reset()
|
||||
self._last_reset_obs = raw
|
||||
return self._to_array(raw), raw
|
||||
|
||||
def step(
|
||||
self, action: Sequence[float]
|
||||
) -> Tuple[np.ndarray, float, bool, Mapping[str, object], Mapping[str, float]]:
|
||||
obs_dict, reward, done, info = self.env.step(action)
|
||||
return self._to_array(obs_dict), reward, done, info, obs_dict
|
||||
|
||||
def _to_array(self, payload: Mapping[str, float]) -> np.ndarray:
|
||||
buffer = np.zeros(len(self._keys), dtype=np.float32)
|
||||
for idx, key in enumerate(self._keys):
|
||||
value = payload.get(key)
|
||||
buffer[idx] = float(value) if value is not None else 0.0
|
||||
return buffer
|
||||
|
||||
def keys(self) -> List[str]:
|
||||
return list(self._keys)
|
||||
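A short usage sketch of the adapter (assumes `env` is an already-configured DecisionEnv; this mirrors how the PPO trainer below consumes it and is not part of the commit):

    adapter = DecisionEnvAdapter(env)
    obs_array, raw_obs = adapter.reset()        # float32 vector plus the raw dict
    action = [0.5] * adapter.action_dim         # one value per tunable parameter
    obs_array, reward, done, info, obs_dict = adapter.step(action)
    print(adapter.keys())                       # key order of the flattened vector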
app/rl/ppo.py (new file, 265 lines)

@@ -0,0 +1,265 @@
"""Lightweight PPO trainer tailored for DecisionEnv."""
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.distributions import Beta
|
||||
|
||||
from .adapters import DecisionEnvAdapter
|
||||
|
||||
|
||||
def _init_layer(layer: nn.Module, std: float = 1.0) -> nn.Module:
|
||||
if isinstance(layer, nn.Linear):
|
||||
nn.init.orthogonal_(layer.weight, gain=std)
|
||||
nn.init.zeros_(layer.bias)
|
||||
return layer
|
||||
|
||||
|
||||
class ActorNetwork(nn.Module):
|
||||
def __init__(self, obs_dim: int, action_dim: int, hidden_sizes: Sequence[int]) -> None:
|
||||
super().__init__()
|
||||
layers: List[nn.Module] = []
|
||||
last_dim = obs_dim
|
||||
for size in hidden_sizes:
|
||||
layers.append(_init_layer(nn.Linear(last_dim, size), std=math.sqrt(2)))
|
||||
layers.append(nn.Tanh())
|
||||
last_dim = size
|
||||
self.body = nn.Sequential(*layers)
|
||||
self.alpha_head = _init_layer(nn.Linear(last_dim, action_dim), std=0.01)
|
||||
self.beta_head = _init_layer(nn.Linear(last_dim, action_dim), std=0.01)
|
||||
|
||||
def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
hidden = self.body(obs)
|
||||
alpha = torch.nn.functional.softplus(self.alpha_head(hidden)) + 1.0
|
||||
beta = torch.nn.functional.softplus(self.beta_head(hidden)) + 1.0
|
||||
return alpha, beta
|
||||
|
||||
|
||||
class CriticNetwork(nn.Module):
|
||||
def __init__(self, obs_dim: int, hidden_sizes: Sequence[int]) -> None:
|
||||
super().__init__()
|
||||
layers: List[nn.Module] = []
|
||||
last_dim = obs_dim
|
||||
for size in hidden_sizes:
|
||||
layers.append(_init_layer(nn.Linear(last_dim, size), std=math.sqrt(2)))
|
||||
layers.append(nn.Tanh())
|
||||
last_dim = size
|
||||
layers.append(_init_layer(nn.Linear(last_dim, 1), std=1.0))
|
||||
self.model = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, obs: torch.Tensor) -> torch.Tensor:
|
||||
return self.model(obs).squeeze(-1)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PPOConfig:
|
||||
total_timesteps: int = 4096
|
||||
rollout_steps: int = 256
|
||||
gamma: float = 0.99
|
||||
gae_lambda: float = 0.95
|
||||
clip_range: float = 0.2
|
||||
policy_lr: float = 3e-4
|
||||
value_lr: float = 3e-4
|
||||
epochs: int = 8
|
||||
minibatch_size: int = 128
|
||||
entropy_coef: float = 0.01
|
||||
value_coef: float = 0.5
|
||||
max_grad_norm: float = 0.5
|
||||
hidden_sizes: Sequence[int] = (128, 128)
|
||||
device: str = "cpu"
|
||||
seed: Optional[int] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingSummary:
|
||||
timesteps: int
|
||||
episode_rewards: List[float] = field(default_factory=list)
|
||||
episode_lengths: List[int] = field(default_factory=list)
|
||||
diagnostics: List[Dict[str, float]] = field(default_factory=list)
|
||||
|
||||
|
||||
class RolloutBuffer:
|
||||
def __init__(self, size: int, obs_dim: int, action_dim: int, device: torch.device) -> None:
|
||||
self.size = size
|
||||
self.device = device
|
||||
self.obs = torch.zeros((size, obs_dim), dtype=torch.float32, device=device)
|
||||
self.actions = torch.zeros((size, action_dim), dtype=torch.float32, device=device)
|
||||
self.log_probs = torch.zeros(size, dtype=torch.float32, device=device)
|
||||
self.rewards = torch.zeros(size, dtype=torch.float32, device=device)
|
||||
self.dones = torch.zeros(size, dtype=torch.float32, device=device)
|
||||
self.values = torch.zeros(size, dtype=torch.float32, device=device)
|
||||
self.advantages = torch.zeros(size, dtype=torch.float32, device=device)
|
||||
self.returns = torch.zeros(size, dtype=torch.float32, device=device)
|
||||
self._pos = 0
|
||||
|
||||
def add(
|
||||
self,
|
||||
obs: torch.Tensor,
|
||||
action: torch.Tensor,
|
||||
log_prob: torch.Tensor,
|
||||
reward: float,
|
||||
done: bool,
|
||||
value: torch.Tensor,
|
||||
) -> None:
|
||||
if self._pos >= self.size:
|
||||
raise RuntimeError("rollout buffer overflow; check rollout_steps")
|
||||
self.obs[self._pos].copy_(obs)
|
||||
self.actions[self._pos].copy_(action)
|
||||
self.log_probs[self._pos] = log_prob
|
||||
self.rewards[self._pos] = reward
|
||||
self.dones[self._pos] = float(done)
|
||||
self.values[self._pos] = value
|
||||
self._pos += 1
|
||||
|
||||
def finish(self, last_value: float, gamma: float, gae_lambda: float) -> None:
|
||||
last_advantage = 0.0
|
||||
for idx in reversed(range(self._pos)):
|
||||
mask = 1.0 - float(self.dones[idx])
|
||||
value = float(self.values[idx])
|
||||
delta = float(self.rewards[idx]) + gamma * last_value * mask - value
|
||||
last_advantage = delta + gamma * gae_lambda * mask * last_advantage
|
||||
self.advantages[idx] = last_advantage
|
||||
self.returns[idx] = last_advantage + value
|
||||
last_value = value
|
||||
|
||||
if self._pos:
|
||||
adv_view = self.advantages[: self._pos]
|
||||
adv_mean = adv_view.mean()
|
||||
adv_std = adv_view.std(unbiased=False) + 1e-8
|
||||
adv_view.sub_(adv_mean).div_(adv_std)
|
||||
|
||||
def get_minibatches(self, batch_size: int) -> Iterable[Tuple[torch.Tensor, ...]]:
|
||||
if self._pos == 0:
|
||||
return []
|
||||
indices = torch.randperm(self._pos, device=self.device)
|
||||
for start in range(0, self._pos, batch_size):
|
||||
end = min(start + batch_size, self._pos)
|
||||
batch_idx = indices[start:end]
|
||||
yield (
|
||||
self.obs[batch_idx],
|
||||
self.actions[batch_idx],
|
||||
self.log_probs[batch_idx],
|
||||
self.advantages[batch_idx],
|
||||
self.returns[batch_idx],
|
||||
self.values[batch_idx],
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
self._pos = 0
|
||||
|
||||
|
||||
class PPOTrainer:
|
||||
def __init__(self, adapter: DecisionEnvAdapter, config: PPOConfig) -> None:
|
||||
self.adapter = adapter
|
||||
self.config = config
|
||||
device = torch.device(config.device)
|
||||
obs_dim = adapter.observation_dim
|
||||
action_dim = adapter.action_dim
|
||||
self.actor = ActorNetwork(obs_dim, action_dim, config.hidden_sizes).to(device)
|
||||
self.critic = CriticNetwork(obs_dim, config.hidden_sizes).to(device)
|
||||
self.policy_optimizer = torch.optim.Adam(self.actor.parameters(), lr=config.policy_lr)
|
||||
self.value_optimizer = torch.optim.Adam(self.critic.parameters(), lr=config.value_lr)
|
||||
self.device = device
|
||||
if config.seed is not None:
|
||||
torch.manual_seed(config.seed)
|
||||
np.random.seed(config.seed)
|
||||
|
||||
def train(self) -> TrainingSummary:
|
||||
cfg = self.config
|
||||
obs_array, _ = self.adapter.reset()
|
||||
obs = torch.from_numpy(obs_array).to(self.device)
|
||||
rollout = RolloutBuffer(cfg.rollout_steps, self.adapter.observation_dim, self.adapter.action_dim, self.device)
|
||||
timesteps = 0
|
||||
episode_rewards: List[float] = []
|
||||
episode_lengths: List[int] = []
|
||||
diagnostics: List[Dict[str, float]] = []
|
||||
current_return = 0.0
|
||||
current_length = 0
|
||||
|
||||
while timesteps < cfg.total_timesteps:
|
||||
rollout.reset()
|
||||
steps_to_collect = min(cfg.rollout_steps, cfg.total_timesteps - timesteps)
|
||||
for _ in range(steps_to_collect):
|
||||
with torch.no_grad():
|
||||
alpha, beta = self.actor(obs.unsqueeze(0))
|
||||
dist = Beta(alpha, beta)
|
||||
action = dist.rsample().squeeze(0)
|
||||
log_prob = dist.log_prob(action).sum()
|
||||
value = self.critic(obs.unsqueeze(0)).squeeze(0)
|
||||
action_np = action.cpu().numpy()
|
||||
next_obs_array, reward, done, info, _ = self.adapter.step(action_np)
|
||||
next_obs = torch.from_numpy(next_obs_array).to(self.device)
|
||||
|
||||
rollout.add(obs, action, log_prob, reward, done, value)
|
||||
timesteps += 1
|
||||
current_return += reward
|
||||
current_length += 1
|
||||
|
||||
if done:
|
||||
episode_rewards.append(current_return)
|
||||
episode_lengths.append(current_length)
|
||||
current_return = 0.0
|
||||
current_length = 0
|
||||
next_obs_array, _ = self.adapter.reset()
|
||||
next_obs = torch.from_numpy(next_obs_array).to(self.device)
|
||||
|
||||
obs = next_obs
|
||||
|
||||
if timesteps >= cfg.total_timesteps or rollout._pos >= steps_to_collect:
|
||||
break
|
||||
|
||||
with torch.no_grad():
|
||||
next_value = self.critic(obs.unsqueeze(0)).squeeze(0).item()
|
||||
rollout.finish(last_value=next_value, gamma=cfg.gamma, gae_lambda=cfg.gae_lambda)
|
||||
|
||||
for _ in range(cfg.epochs):
|
||||
for (mb_obs, mb_actions, mb_log_probs, mb_adv, mb_returns, _) in rollout.get_minibatches(
|
||||
cfg.minibatch_size
|
||||
):
|
||||
alpha, beta = self.actor(mb_obs)
|
||||
dist = Beta(alpha, beta)
|
||||
new_log_probs = dist.log_prob(mb_actions).sum(dim=-1)
|
||||
entropy = dist.entropy().sum(dim=-1)
|
||||
ratios = torch.exp(new_log_probs - mb_log_probs)
|
||||
surrogate1 = ratios * mb_adv
|
||||
surrogate2 = torch.clamp(ratios, 1.0 - cfg.clip_range, 1.0 + cfg.clip_range) * mb_adv
|
||||
policy_loss = -torch.min(surrogate1, surrogate2).mean() - cfg.entropy_coef * entropy.mean()
|
||||
|
||||
self.policy_optimizer.zero_grad()
|
||||
policy_loss.backward()
|
||||
nn.utils.clip_grad_norm_(self.actor.parameters(), cfg.max_grad_norm)
|
||||
self.policy_optimizer.step()
|
||||
|
||||
values = self.critic(mb_obs)
|
||||
value_loss = torch.nn.functional.mse_loss(values, mb_returns)
|
||||
self.value_optimizer.zero_grad()
|
||||
value_loss.backward()
|
||||
nn.utils.clip_grad_norm_(self.critic.parameters(), cfg.max_grad_norm)
|
||||
self.value_optimizer.step()
|
||||
|
||||
diagnostics.append(
|
||||
{
|
||||
"policy_loss": float(policy_loss.detach().cpu()),
|
||||
"value_loss": float(value_loss.detach().cpu()),
|
||||
"entropy": float(entropy.mean().detach().cpu()),
|
||||
}
|
||||
)
|
||||
|
||||
return TrainingSummary(
|
||||
timesteps=timesteps,
|
||||
episode_rewards=episode_rewards,
|
||||
episode_lengths=episode_lengths,
|
||||
diagnostics=diagnostics,
|
||||
)
|
||||
|
||||
|
||||
def train_ppo(adapter: DecisionEnvAdapter, config: PPOConfig) -> TrainingSummary:
|
||||
"""Convenience helper to run PPO training with defaults."""
|
||||
|
||||
trainer = PPOTrainer(adapter, config)
|
||||
return trainer.train()
|
||||
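For reference, RolloutBuffer.finish and the policy update above implement the standard GAE recursion and PPO clipped objective; in the usual notation, with m_t = 1 - done_t, epsilon = clip_range and c_ent = entropy_coef (advantages are additionally normalised after the recursion):

    \delta_t = r_t + \gamma\, m_t\, V(s_{t+1}) - V(s_t)
    \hat{A}_t = \delta_t + \gamma \lambda\, m_t\, \hat{A}_{t+1}, \qquad R_t = \hat{A}_t + V(s_t)
    r_t(\theta) = \exp\big(\log \pi_\theta(a_t \mid s_t) - \log \pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)\big)
    L_{\mathrm{policy}} = -\,\mathbb{E}\big[\min\big(r_t(\theta)\hat{A}_t,\ \mathrm{clip}(r_t(\theta),\,1-\epsilon,\,1+\epsilon)\,\hat{A}_t\big)\big] - c_{\mathrm{ent}}\,\mathbb{E}\big[\mathcal{H}[\pi_\theta(\cdot \mid s_t)]\big]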
docs/TODO.md
@@ -20,10 +20,21 @@
## 3. Decision optimisation and reinforcement learning
- ✅ Extend the `DecisionEnv` action space (prompt version, department temperature, function-calling strategy, etc.) beyond agent-weight tuning alone.
- Introduce Bandit / Bayesian optimisation or RL algorithms to explore the action space, and fold the `portfolio_snapshots` and `portfolio_trades` metrics into the reward constraints.
- Rework `DecisionEnv` into a multi-step episode that emits state (market features, positions, risk events) and actions day by day, so the full historical sequence can be used to train RL policies.
- ✅ Plug continuous-action algorithms such as PPO / SAC into the multi-step environment, defining rewards and constraints from returns, drawdown, and trading costs to stabilise return maximisation.
- Introduce parallel Bayesian optimisation (TPE/BOHB) or other global search at the whole-backtest level to supply high-return initial weights and parameter candidates for RL.
- Build an offline validation and rolling walk-forward testing pipeline that compares new strategies' backtests against live trading, to keep return-maximising policies from overfitting history.
- Build a real-time position/trade data ingestion path so online monitoring and offline tuning share a single data source.
- Borrow from TradingAgents-CN: separate environment from policy, ship training scripts/configs, and report rich evaluation metrics (e.g. Sharpe, Sortino, benchmark comparison).
- Refine `BacktestEngine` order matching, risk thresholds, and metric outputs so backtest signals connect directly to execution, closing an unattended automation loop.

### 3.1 Implementation steps (suggested order)
1. Environment refactor: extend `DecisionEnv` to support per-day state/action/reward, complete the `BacktestEngine` state save/restore interfaces, and add the necessary database read/write hooks.
2. Training baseline: implement a PPO (or SAC) training script on the multi-step environment, defining the network architecture, reward terms (return/drawdown/trading cost), and hyperparameters, and validate convergence on a small universe first.
3. Global search: run TPE/BOHB-style Bayesian optimisation in parallel over whole-backtest mode to produce high-return parameters as RL initialisation weights or candidate strategies.
4. Validation loop: build a rolling walk-forward pipeline that automatically records trained policies' backtest performance and near-real-time comparisons, wired into the monitoring dashboard with risk/return metrics.
5. Launch preparation: combine with the real-time position/trade pipeline, finish rollback and safety-threshold mechanisms, and prepare A/B or shadow-trading experiments to confirm the robustness of return-maximising strategies.

## 4. Testing and validation
- Add unit / integration tests for core paths such as department context construction, multi-model invocation, and backtest metric generation.
- Establish regression test cases for the decision pipeline so behaviour stays reproducible after prompt-template or configuration changes.
@@ -8,3 +8,4 @@ pytest>=7.0
feedparser>=6.0
arch>=6.1.0
scipy>=1.11.0
torch>=2.3.0
scripts/train_ppo.py (new file, 133 lines)

@@ -0,0 +1,133 @@
"""Command-line entrypoint for PPO training on DecisionEnv."""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from app.agents.registry import default_agents
|
||||
from app.backtest.decision_env import DecisionEnv, ParameterSpec
|
||||
from app.backtest.engine import BtConfig
|
||||
from app.rl import DecisionEnvAdapter, PPOConfig, train_ppo
|
||||
from app.ui.shared import default_backtest_range
|
||||
from app.utils.config import get_config
|
||||
|
||||
|
||||
def _parse_universe(raw: str) -> List[str]:
|
||||
return [item.strip() for item in raw.split(",") if item.strip()]
|
||||
|
||||
|
||||
def build_env(args: argparse.Namespace) -> DecisionEnvAdapter:
|
||||
app_cfg = get_config()
|
||||
start = datetime.strptime(args.start_date, "%Y-%m-%d").date()
|
||||
end = datetime.strptime(args.end_date, "%Y-%m-%d").date()
|
||||
universe = _parse_universe(args.universe)
|
||||
if not universe:
|
||||
raise ValueError("universe must contain at least one ts_code")
|
||||
|
||||
agents = default_agents()
|
||||
baseline_weights = app_cfg.agent_weights.as_dict()
|
||||
for agent in agents:
|
||||
baseline_weights.setdefault(agent.name, 1.0)
|
||||
|
||||
specs: List[ParameterSpec] = []
|
||||
for name in sorted(baseline_weights):
|
||||
specs.append(
|
||||
ParameterSpec(
|
||||
name=f"weight_{name}",
|
||||
target=f"agent_weights.{name}",
|
||||
minimum=args.weight_min,
|
||||
maximum=args.weight_max,
|
||||
)
|
||||
)
|
||||
|
||||
bt_cfg = BtConfig(
|
||||
id=args.experiment_id,
|
||||
name=f"PPO-{args.experiment_id}",
|
||||
start_date=start,
|
||||
end_date=end,
|
||||
universe=universe,
|
||||
params={
|
||||
"target": args.target,
|
||||
"stop": args.stop,
|
||||
"hold_days": args.hold_days,
|
||||
},
|
||||
method=app_cfg.decision_method,
|
||||
)
|
||||
env = DecisionEnv(
|
||||
bt_config=bt_cfg,
|
||||
parameter_specs=specs,
|
||||
baseline_weights=baseline_weights,
|
||||
disable_departments=args.disable_departments,
|
||||
)
|
||||
return DecisionEnvAdapter(env)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Train PPO policy on DecisionEnv")
|
||||
default_start, default_end = default_backtest_range(window_days=60)
|
||||
parser.add_argument("--start-date", default=str(default_start))
|
||||
parser.add_argument("--end-date", default=str(default_end))
|
||||
parser.add_argument("--universe", default="000001.SZ")
|
||||
parser.add_argument("--experiment-id", default=f"ppo_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
|
||||
parser.add_argument("--hold-days", type=int, default=10)
|
||||
parser.add_argument("--target", type=float, default=0.035)
|
||||
parser.add_argument("--stop", type=float, default=-0.015)
|
||||
parser.add_argument("--total-timesteps", type=int, default=4096)
|
||||
parser.add_argument("--rollout-steps", type=int, default=256)
|
||||
parser.add_argument("--epochs", type=int, default=8)
|
||||
parser.add_argument("--minibatch-size", type=int, default=128)
|
||||
parser.add_argument("--gamma", type=float, default=0.99)
|
||||
parser.add_argument("--gae-lambda", type=float, default=0.95)
|
||||
parser.add_argument("--clip-range", type=float, default=0.2)
|
||||
parser.add_argument("--policy-lr", type=float, default=3e-4)
|
||||
parser.add_argument("--value-lr", type=float, default=3e-4)
|
||||
parser.add_argument("--entropy-coef", type=float, default=0.01)
|
||||
parser.add_argument("--value-coef", type=float, default=0.5)
|
||||
parser.add_argument("--max-grad-norm", type=float, default=0.5)
|
||||
parser.add_argument("--hidden-sizes", default="128,128")
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
parser.add_argument("--weight-min", type=float, default=0.0)
|
||||
parser.add_argument("--weight-max", type=float, default=1.5)
|
||||
parser.add_argument("--disable-departments", action="store_true")
|
||||
parser.add_argument("--output", type=Path, default=Path("ppo_training_summary.json"))
|
||||
|
||||
args = parser.parse_args()
|
||||
hidden_sizes = tuple(int(x) for x in args.hidden_sizes.split(",") if x.strip())
|
||||
adapter = build_env(args)
|
||||
|
||||
config = PPOConfig(
|
||||
total_timesteps=args.total_timesteps,
|
||||
rollout_steps=args.rollout_steps,
|
||||
gamma=args.gamma,
|
||||
gae_lambda=args.gae_lambda,
|
||||
clip_range=args.clip_range,
|
||||
policy_lr=args.policy_lr,
|
||||
value_lr=args.value_lr,
|
||||
epochs=args.epochs,
|
||||
minibatch_size=args.minibatch_size,
|
||||
entropy_coef=args.entropy_coef,
|
||||
value_coef=args.value_coef,
|
||||
max_grad_norm=args.max_grad_norm,
|
||||
hidden_sizes=hidden_sizes,
|
||||
seed=args.seed,
|
||||
)
|
||||
|
||||
summary = train_ppo(adapter, config)
|
||||
payload = {
|
||||
"timesteps": summary.timesteps,
|
||||
"episode_rewards": summary.episode_rewards,
|
||||
"episode_lengths": summary.episode_lengths,
|
||||
"diagnostics_tail": summary.diagnostics[-10:],
|
||||
"observation_keys": adapter.keys(),
|
||||
}
|
||||
args.output.write_text(json.dumps(payload, indent=2, ensure_ascii=False))
|
||||
print(f"Training finished. Summary written to {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
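An example invocation of the new script (dates and the ts_code universe below are placeholders; the flags match the argparse definitions above):

    python scripts/train_ppo.py \
      --universe 000001.SZ,000002.SZ \
      --start-date 2025-01-02 --end-date 2025-03-31 \
      --total-timesteps 4096 --rollout-steps 256 \
      --hidden-sizes 128,128 --seed 42 \
      --disable-departments \
      --output ppo_training_summary.json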
@@ -6,7 +6,7 @@ from datetime import date
import pytest

from app.backtest.decision_env import DecisionEnv, EpisodeMetrics, ParameterSpec
from app.backtest.engine import BacktestResult, BtConfig
from app.backtest.engine import BacktestResult, BacktestSession, BtConfig, PortfolioState
from app.utils.config import DepartmentSettings, LLMConfig, LLMEndpoint


@@ -60,11 +60,19 @@ class _StubEngine:
        self.department_manager = _StubManager()
        _StubEngine.last_instance = self

    def run(self) -> BacktestResult:
        result = BacktestResult()
        result.nav_series = [
            {
                "trade_date": "2025-01-10",
    def start_session(self) -> BacktestSession:
        return BacktestSession(
            state=PortfolioState(),
            result=BacktestResult(),
            current_date=self.cfg.start_date,
        )

    def step_session(self, session: BacktestSession, *_args, **_kwargs):
        if session.current_date > self.cfg.end_date:
            return [], True

        payload_nav = {
            "trade_date": session.current_date.isoformat(),
            "nav": 102.0,
            "cash": 50.0,
            "market_value": 52.0,
@@ -73,10 +81,8 @@ class _StubEngine:
            "turnover": 20000.0,
            "turnover_ratio": 0.2,
        }
        ]
        result.trades = [
            {
                "trade_date": "2025-01-10",
        payload_trade = {
            "trade_date": session.current_date.isoformat(),
            "ts_code": "000001.SZ",
            "action": "buy",
            "quantity": 100.0,
@@ -84,18 +90,20 @@ class _StubEngine:
            "value": 10000.0,
            "fee": 5.0,
        }
        ]
        result.risk_events = [
            {
                "trade_date": "2025-01-10",
        payload_risk = {
            "trade_date": session.current_date.isoformat(),
            "ts_code": "000002.SZ",
            "reason": "limit_up",
            "action": "buy_l",
            "confidence": 0.7,
            "target_weight": 0.2,
        }
        ]
        return result
        session.result.nav_series.append(payload_nav)
        session.result.trades.append(payload_trade)
        session.result.risk_events.append(payload_risk)
        session.current_date = date.fromordinal(session.current_date.toordinal() + 1)
        done = session.current_date > self.cfg.end_date
        return [payload_nav], done


_StubEngine.last_instance: _StubEngine | None = None
@@ -117,6 +125,7 @@ def test_decision_env_returns_risk_metrics(monkeypatch):
    monkeypatch.setattr(DecisionEnv, "_clear_portfolio_records", lambda self: None)
    monkeypatch.setattr(DecisionEnv, "_fetch_portfolio_records", lambda self: ([], []))

    env.reset()
    obs, reward, done, info = env.step([0.8])

    assert done is True
@@ -187,6 +196,7 @@ def test_decision_env_department_controls(monkeypatch):
    monkeypatch.setattr(DecisionEnv, "_clear_portfolio_records", lambda self: None)
    monkeypatch.setattr(DecisionEnv, "_fetch_portfolio_records", lambda self: ([], []))

    env.reset()
    obs, reward, done, info = env.step([0.3, 1.0, 0.75, 0.0, 1.0])

    assert done is True
@@ -198,14 +208,33 @@ def test_decision_env_department_controls(monkeypatch):
    assert momentum_ctrl["prompt"] == "aggressive"
    assert momentum_ctrl["temperature"] == pytest.approx(0.7, abs=1e-6)
    assert momentum_ctrl["tool_choice"] == "none"
    assert momentum_ctrl["max_rounds"] == 5

    assert env.last_department_controls == controls

    engine = _StubEngine.last_instance
    assert engine is not None
    agent = engine.department_manager.agents["momentum"]
    assert agent.settings.prompt == "aggressive"
    assert agent.settings.llm.primary.temperature == pytest.approx(0.7, abs=1e-6)
    assert agent.tool_choice == "none"
    assert agent.max_rounds == 5
def test_decision_env_multistep_session(monkeypatch):
    cfg = BtConfig(
        id="stub",
        name="stub",
        start_date=date(2025, 1, 10),
        end_date=date(2025, 1, 12),
        universe=["000001.SZ"],
        params={},
    )
    specs = [ParameterSpec(name="w_mom", target="agent_weights.A_mom", minimum=0.0, maximum=1.0)]
    env = DecisionEnv(bt_config=cfg, parameter_specs=specs, baseline_weights={"A_mom": 0.5})

    monkeypatch.setattr("app.backtest.decision_env.BacktestEngine", _StubEngine)
    monkeypatch.setattr(DecisionEnv, "_clear_portfolio_records", lambda self: None)
    monkeypatch.setattr(DecisionEnv, "_fetch_portfolio_records", lambda self: ([], []))

    env.reset()
    obs, reward, done, info = env.step([0.6])
    assert done is False
    assert obs["day_index"] == pytest.approx(1.0)

    obs2, reward2, done2, _ = env.step([0.6])
    assert done2 is False
    assert obs2["day_index"] == pytest.approx(2.0)

    obs3, reward3, done3, _ = env.step([0.6])
    assert done3 is True
    assert obs3["day_index"] == pytest.approx(3.0)
tests/test_ppo_trainer.py (new file, 72 lines)

@@ -0,0 +1,72 @@
from __future__ import annotations

import math

from app.rl.adapters import DecisionEnvAdapter
from app.rl.ppo import PPOConfig, train_ppo


class _DummyDecisionEnv:
    action_dim = 1

    def __init__(self) -> None:
        self._step = 0
        self._episode = 0

    def reset(self):
        self._step = 0
        self._episode += 1
        return {
            "day_index": 0.0,
            "total_return": 0.0,
            "max_drawdown": 0.0,
            "volatility": 0.0,
            "turnover": 0.0,
            "sharpe_like": 0.0,
            "trade_count": 0.0,
            "risk_count": 0.0,
            "nav": 1.0,
            "cash": 1.0,
            "market_value": 0.0,
            "done": 0.0,
        }

    def step(self, action):
        value = float(action[0])
        reward = 1.0 - abs(value - 0.8)
        self._step += 1
        done = self._step >= 3
        obs = {
            "day_index": float(self._step),
            "total_return": reward,
            "max_drawdown": 0.1,
            "volatility": 0.05,
            "turnover": 0.1,
            "sharpe_like": reward / 0.05,
            "trade_count": float(self._step),
            "risk_count": 0.0,
            "nav": 1.0 + 0.01 * self._step,
            "cash": 1.0,
            "market_value": 0.0,
            "done": 1.0 if done else 0.0,
        }
        info = {}
        return obs, reward, done, info


def test_train_ppo_completes_with_dummy_env():
    adapter = DecisionEnvAdapter(_DummyDecisionEnv())
    config = PPOConfig(
        total_timesteps=64,
        rollout_steps=16,
        epochs=2,
        minibatch_size=8,
        hidden_sizes=(32, 32),
        seed=123,
    )
    summary = train_ppo(adapter, config)

    assert summary.timesteps == config.total_timesteps
    assert summary.episode_rewards
    assert not math.isnan(summary.episode_rewards[-1])
    assert summary.diagnostics