"""Command-line entrypoint for PPO training on DecisionEnv."""

from __future__ import annotations

import argparse
import json
from datetime import datetime
from pathlib import Path
from typing import List

from app.agents.registry import default_agents
from app.backtest.decision_env import DecisionEnv, ParameterSpec
from app.backtest.engine import BtConfig
from app.rl import DecisionEnvAdapter, PPOConfig, train_ppo
from app.ui.shared import default_backtest_range
from app.utils.config import get_config


def _parse_universe(raw: str) -> List[str]:
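    """Split a comma-separated ts_code string into a list, dropping blanks."""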
    return [item.strip() for item in raw.split(",") if item.strip()]


def build_env(args: argparse.Namespace) -> DecisionEnvAdapter:
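    """Build a DecisionEnv from CLI arguments and wrap it for PPO training."""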
    app_cfg = get_config()
    start = datetime.strptime(args.start_date, "%Y-%m-%d").date()
    end = datetime.strptime(args.end_date, "%Y-%m-%d").date()
    universe = _parse_universe(args.universe)
    if not universe:
        raise ValueError("universe must contain at least one ts_code")

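    # Ensure every registered agent has a baseline weight so the loop below
    # emits a ParameterSpec for it even when the config file omits the agent.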
    agents = default_agents()
    baseline_weights = app_cfg.agent_weights.as_dict()
    for agent in agents:
        baseline_weights.setdefault(agent.name, 1.0)

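    # Expose each agent weight as one bounded, RL-tunable parameter; sorting
    # keeps the action layout deterministic across runs.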
    specs: List[ParameterSpec] = []
    for name in sorted(baseline_weights):
        specs.append(
            ParameterSpec(
                name=f"weight_{name}",
                target=f"agent_weights.{name}",
                minimum=args.weight_min,
                maximum=args.weight_max,
            )
        )

    bt_cfg = BtConfig(
        id=args.experiment_id,
        name=f"PPO-{args.experiment_id}",
        start_date=start,
        end_date=end,
        universe=universe,
        params={
            "target": args.target,
            "stop": args.stop,
            "hold_days": args.hold_days,
        },
        method=app_cfg.decision_method,
    )
    env = DecisionEnv(
        bt_config=bt_cfg,
        parameter_specs=specs,
        baseline_weights=baseline_weights,
        disable_departments=args.disable_departments,
    )
    return DecisionEnvAdapter(env)


def main() -> None:
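    """Parse CLI arguments, train a PPO policy, and write a JSON summary."""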
    parser = argparse.ArgumentParser(description="Train PPO policy on DecisionEnv")
    default_start, default_end = default_backtest_range(window_days=60)
    parser.add_argument("--start-date", default=str(default_start))
    parser.add_argument("--end-date", default=str(default_end))
    parser.add_argument("--universe", default="000001.SZ")
    parser.add_argument("--experiment-id", default=f"ppo_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
    parser.add_argument("--hold-days", type=int, default=10)
    parser.add_argument("--target", type=float, default=0.035)
    parser.add_argument("--stop", type=float, default=-0.015)
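    # PPO optimisation hyperparameters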
parser.add_argument("--total-timesteps", type=int, default=4096)
|
|
parser.add_argument("--rollout-steps", type=int, default=256)
|
|
parser.add_argument("--epochs", type=int, default=8)
|
|
parser.add_argument("--minibatch-size", type=int, default=128)
|
|
parser.add_argument("--gamma", type=float, default=0.99)
|
|
parser.add_argument("--gae-lambda", type=float, default=0.95)
|
|
parser.add_argument("--clip-range", type=float, default=0.2)
|
|
parser.add_argument("--policy-lr", type=float, default=3e-4)
|
|
parser.add_argument("--value-lr", type=float, default=3e-4)
|
|
parser.add_argument("--entropy-coef", type=float, default=0.01)
|
|
parser.add_argument("--value-coef", type=float, default=0.5)
|
|
parser.add_argument("--max-grad-norm", type=float, default=0.5)
|
|
parser.add_argument("--hidden-sizes", default="128,128")
|
|
parser.add_argument("--seed", type=int, default=42)
|
|
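    # Bounds for the agent-weight action space, env toggle, and output path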
parser.add_argument("--weight-min", type=float, default=0.0)
|
|
parser.add_argument("--weight-max", type=float, default=1.5)
|
|
parser.add_argument("--disable-departments", action="store_true")
|
|
parser.add_argument("--output", type=Path, default=Path("ppo_training_summary.json"))
|
|
|
|
    args = parser.parse_args()
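    # Parse "128,128" into (128, 128) for PPOConfig.hidden_sizes.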
    hidden_sizes = tuple(int(x) for x in args.hidden_sizes.split(",") if x.strip())
    adapter = build_env(args)

    config = PPOConfig(
        total_timesteps=args.total_timesteps,
        rollout_steps=args.rollout_steps,
        gamma=args.gamma,
        gae_lambda=args.gae_lambda,
        clip_range=args.clip_range,
        policy_lr=args.policy_lr,
        value_lr=args.value_lr,
        epochs=args.epochs,
        minibatch_size=args.minibatch_size,
        entropy_coef=args.entropy_coef,
        value_coef=args.value_coef,
        max_grad_norm=args.max_grad_norm,
        hidden_sizes=hidden_sizes,
        seed=args.seed,
    )

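    # Train, then persist a compact summary (only the last 10 diagnostics records).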
    summary = train_ppo(adapter, config)
    payload = {
        "timesteps": summary.timesteps,
        "episode_rewards": summary.episode_rewards,
        "episode_lengths": summary.episode_lengths,
        "diagnostics_tail": summary.diagnostics[-10:],
        "observation_keys": adapter.keys(),
    }
    args.output.write_text(json.dumps(payload, indent=2, ensure_ascii=False))
    print(f"Training finished. Summary written to {args.output}")


if __name__ == "__main__":
    main()