llm-quant/scripts/train_ppo.py
"""Command-line entrypoint for PPO training on DecisionEnv."""
from __future__ import annotations
import argparse
import json
from datetime import datetime
from pathlib import Path
from typing import List
import numpy as np
from app.agents.registry import default_agents
from app.backtest.decision_env import DecisionEnv, ParameterSpec
from app.backtest.engine import BtConfig
from app.rl import DecisionEnvAdapter, PPOConfig, train_ppo
from app.ui.shared import default_backtest_range
from app.utils.config import get_config


def _parse_universe(raw: str) -> List[str]:
    """Split a comma-separated ts_code string into a cleaned list, dropping empty entries."""
    return [item.strip() for item in raw.split(",") if item.strip()]


def build_env(args: argparse.Namespace) -> DecisionEnvAdapter:
    """Build the DecisionEnv from CLI arguments and wrap it for PPO training."""
    app_cfg = get_config()
    start = datetime.strptime(args.start_date, "%Y-%m-%d").date()
    end = datetime.strptime(args.end_date, "%Y-%m-%d").date()
    universe = _parse_universe(args.universe)
    if not universe:
        raise ValueError("universe must contain at least one ts_code")
    agents = default_agents()
    baseline_weights = app_cfg.agent_weights.as_dict()
    for agent in agents:
        baseline_weights.setdefault(agent.name, 1.0)
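    # One ParameterSpec per agent weight, bounded by --weight-min / --weight-max.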
    specs: List[ParameterSpec] = []
    for name in sorted(baseline_weights):
        specs.append(
            ParameterSpec(
                name=f"weight_{name}",
                target=f"agent_weights.{name}",
                minimum=args.weight_min,
                maximum=args.weight_max,
            )
        )
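    # Backtest window, universe and trade parameters (profit target, stop, holding days).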
    bt_cfg = BtConfig(
        id=args.experiment_id,
        name=f"PPO-{args.experiment_id}",
        start_date=start,
        end_date=end,
        universe=universe,
        params={
            "target": args.target,
            "stop": args.stop,
            "hold_days": args.hold_days,
        },
        method=app_cfg.decision_method,
    )
    env = DecisionEnv(
        bt_config=bt_cfg,
        parameter_specs=specs,
        baseline_weights=baseline_weights,
        disable_departments=args.disable_departments,
    )
    return DecisionEnvAdapter(env)


def main() -> None:
    parser = argparse.ArgumentParser(description="Train PPO policy on DecisionEnv")
    default_start, default_end = default_backtest_range(window_days=60)
    parser.add_argument("--start-date", default=str(default_start))
    parser.add_argument("--end-date", default=str(default_end))
    parser.add_argument("--universe", default="000001.SZ")
    parser.add_argument("--experiment-id", default=f"ppo_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
    parser.add_argument("--hold-days", type=int, default=10)
    parser.add_argument("--target", type=float, default=0.035)
    parser.add_argument("--stop", type=float, default=-0.015)
    parser.add_argument("--total-timesteps", type=int, default=4096)
    parser.add_argument("--rollout-steps", type=int, default=256)
    parser.add_argument("--epochs", type=int, default=8)
    parser.add_argument("--minibatch-size", type=int, default=128)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--gae-lambda", type=float, default=0.95)
    parser.add_argument("--clip-range", type=float, default=0.2)
    parser.add_argument("--policy-lr", type=float, default=3e-4)
    parser.add_argument("--value-lr", type=float, default=3e-4)
    parser.add_argument("--entropy-coef", type=float, default=0.01)
    parser.add_argument("--value-coef", type=float, default=0.5)
    parser.add_argument("--max-grad-norm", type=float, default=0.5)
    parser.add_argument("--hidden-sizes", default="128,128")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--weight-min", type=float, default=0.0)
    parser.add_argument("--weight-max", type=float, default=1.5)
    parser.add_argument("--disable-departments", action="store_true")
    parser.add_argument("--output", type=Path, default=Path("ppo_training_summary.json"))
    args = parser.parse_args()
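    # Parse hidden layer sizes, e.g. "128,128" -> (128, 128).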
    hidden_sizes = tuple(int(x) for x in args.hidden_sizes.split(",") if x.strip())
    adapter = build_env(args)
    config = PPOConfig(
        total_timesteps=args.total_timesteps,
        rollout_steps=args.rollout_steps,
        gamma=args.gamma,
        gae_lambda=args.gae_lambda,
        clip_range=args.clip_range,
        policy_lr=args.policy_lr,
        value_lr=args.value_lr,
        epochs=args.epochs,
        minibatch_size=args.minibatch_size,
        entropy_coef=args.entropy_coef,
        value_coef=args.value_coef,
        max_grad_norm=args.max_grad_norm,
        hidden_sizes=hidden_sizes,
        seed=args.seed,
    )
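    # Run training and persist a JSON summary: episode statistics plus the last ten diagnostics entries.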
    summary = train_ppo(adapter, config)
    payload = {
        "timesteps": summary.timesteps,
        "episode_rewards": summary.episode_rewards,
        "episode_lengths": summary.episode_lengths,
        "diagnostics_tail": summary.diagnostics[-10:],
        "observation_keys": adapter.keys(),
    }
    args.output.write_text(json.dumps(payload, indent=2, ensure_ascii=False))
    print(f"Training finished. Summary written to {args.output}")


if __name__ == "__main__":
    main()