From bc60a6115f3406aefc189e2e7a15db1b1438ba0c Mon Sep 17 00:00:00 2001
From: sam
Date: Mon, 6 Oct 2025 22:00:24 +0800
Subject: [PATCH] add PPO training UI and torch optional dependency handling

---
 app/rl/__init__.py       |  46 +++++++-
 app/ui/views/backtest.py | 249 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 294 insertions(+), 1 deletion(-)

diff --git a/app/rl/__init__.py b/app/rl/__init__.py
index a0ae04c..1ee51af 100644
--- a/app/rl/__init__.py
+++ b/app/rl/__init__.py
@@ -1,11 +1,55 @@
 """Reinforcement learning utilities for DecisionEnv."""
 
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Optional, Sequence
+
 from .adapters import DecisionEnvAdapter
-from .ppo import PPOConfig, PPOTrainer, train_ppo
+
+TORCH_AVAILABLE = True
+
+try:  # pragma: no cover - exercised via integration
+    from .ppo import PPOConfig, PPOTrainer, train_ppo
+except ModuleNotFoundError as exc:  # pragma: no cover - optional dependency guard
+    if exc.name != "torch":
+        raise
+    TORCH_AVAILABLE = False
+
+    @dataclass
+    class PPOConfig:
+        """Placeholder PPOConfig used when torch is unavailable."""
+
+        total_timesteps: int = 0
+        rollout_steps: int = 0
+        gamma: float = 0.0
+        gae_lambda: float = 0.0
+        clip_range: float = 0.0
+        policy_lr: float = 0.0
+        value_lr: float = 0.0
+        epochs: int = 0
+        minibatch_size: int = 0
+        entropy_coef: float = 0.0
+        value_coef: float = 0.0
+        max_grad_norm: float = 0.0
+        hidden_sizes: Sequence[int] = field(default_factory=tuple)
+        seed: Optional[int] = None
+
+    class PPOTrainer:  # pragma: no cover - simply raises when used
+        def __init__(self, *args, **kwargs) -> None:
+            raise ModuleNotFoundError(
+                "torch is required for PPO training. Please install torch before using this feature."
+            )
+
+    def train_ppo(*_args, **_kwargs):  # pragma: no cover - simply raises when used
+        raise ModuleNotFoundError(
+            "torch is required for PPO training. Please install torch before using this feature."
+        )
 
 __all__ = [
     "DecisionEnvAdapter",
     "PPOConfig",
     "PPOTrainer",
     "train_ppo",
+    "TORCH_AVAILABLE",
 ]
diff --git a/app/ui/views/backtest.py b/app/ui/views/backtest.py
index afb8a2f..8290aac 100644
--- a/app/ui/views/backtest.py
+++ b/app/ui/views/backtest.py
@@ -19,6 +19,7 @@ from app.agents.game import Decision
 from app.agents.registry import default_agents
 from app.backtest.decision_env import DecisionEnv, ParameterSpec
 from app.backtest.optimizer import BanditConfig, EpsilonGreedyBandit
+from app.rl import TORCH_AVAILABLE, DecisionEnvAdapter, PPOConfig, train_ppo
 from app.backtest.engine import BacktestEngine, PortfolioState, BtConfig, run_backtest
 from app.ingest.checker import run_boot_check
 from app.ingest.tushare import run_ingestion
@@ -38,6 +39,7 @@ from app.ui.views.dashboard import update_dashboard_sidebar
 _DECISION_ENV_SINGLE_RESULT_KEY = "decision_env_single_result"
 _DECISION_ENV_BATCH_RESULTS_KEY = "decision_env_batch_results"
 _DECISION_ENV_BANDIT_RESULTS_KEY = "decision_env_bandit_results"
+_DECISION_ENV_PPO_RESULTS_KEY = "decision_env_ppo_results"
 
 def render_backtest_review() -> None:
     """渲染回测执行、调参与结果复盘页面。"""
@@ -675,4 +677,251 @@ def render_backtest_review() -> None:
         st.session_state.pop(_DECISION_ENV_BANDIT_RESULTS_KEY, None)
         st.success("已清除自动探索结果。")
 
+    st.divider()
+    st.subheader("PPO 训练(逐日强化学习)")
+    if TORCH_AVAILABLE:
+        col_ts, col_rollout, col_epochs = st.columns(3)
+        ppo_timesteps = int(
+            col_ts.number_input(
+                "总时间步",
+                min_value=256,
+                max_value=200_000,
+                value=4096,
+                step=256,
+                key="decision_env_ppo_timesteps",
+            )
+        )
+        ppo_rollout = int(
+            col_rollout.number_input(
+                "每次收集步数",
+                min_value=32,
+                max_value=2048,
+                value=256,
+                step=32,
+                key="decision_env_ppo_rollout",
+            )
+        )
+        ppo_epochs = int(
+            col_epochs.number_input(
+                "每批优化轮数",
+                min_value=1,
+                max_value=30,
+                value=8,
+                step=1,
+                key="decision_env_ppo_epochs",
+            )
+        )
+
+        col_mb, col_gamma, col_lambda = st.columns(3)
+        ppo_minibatch = int(
+            col_mb.number_input(
+                "最小批次规模",
+                min_value=16,
+                max_value=1024,
+                value=128,
+                step=16,
+                key="decision_env_ppo_minibatch",
+            )
+        )
+        ppo_gamma = float(
+            col_gamma.number_input(
+                "折现系数 γ",
+                min_value=0.5,
+                max_value=0.999,
+                value=0.99,
+                step=0.01,
+                format="%.3f",
+                key="decision_env_ppo_gamma",
+            )
+        )
+        ppo_lambda = float(
+            col_lambda.number_input(
+                "GAE λ",
+                min_value=0.5,
+                max_value=0.999,
+                value=0.95,
+                step=0.01,
+                format="%.3f",
+                key="decision_env_ppo_lambda",
+            )
+        )
+
+        col_clip, col_entropy, col_value = st.columns(3)
+        ppo_clip = float(
+            col_clip.number_input(
+                "裁剪范围 ε",
+                min_value=0.05,
+                max_value=0.5,
+                value=0.2,
+                step=0.01,
+                format="%.2f",
+                key="decision_env_ppo_clip",
+            )
+        )
+        ppo_entropy = float(
+            col_entropy.number_input(
+                "熵系数",
+                min_value=0.0,
+                max_value=0.1,
+                value=0.01,
+                step=0.005,
+                format="%.3f",
+                key="decision_env_ppo_entropy",
+            )
+        )
+        ppo_value_coef = float(
+            col_value.number_input(
+                "价值损失系数",
+                min_value=0.0,
+                max_value=2.0,
+                value=0.5,
+                step=0.1,
+                format="%.2f",
+                key="decision_env_ppo_value_coef",
+            )
+        )
+
+        col_lr_p, col_lr_v, col_grad = st.columns(3)
+        ppo_policy_lr = float(
+            col_lr_p.number_input(
+                "策略学习率",
+                min_value=1e-5,
+                max_value=1e-2,
+                value=3e-4,
+                step=1e-5,
+                format="%.5f",
+                key="decision_env_ppo_policy_lr",
+            )
+        )
+        ppo_value_lr = float(
+            col_lr_v.number_input(
+                "价值学习率",
+                min_value=1e-5,
+                max_value=1e-2,
+                value=3e-4,
+                step=1e-5,
+                format="%.5f",
+                key="decision_env_ppo_value_lr",
+            )
+        )
+        ppo_max_grad_norm = float(
+            col_grad.number_input(
+                "梯度裁剪",
+                value=0.5,
+                min_value=0.0,
+                max_value=5.0,
+                step=0.1,
+                format="%.1f",
+                key="decision_env_ppo_grad_norm",
+            )
+        )
+
+        col_hidden, col_seed, _ = st.columns(3)
+        ppo_hidden_text = col_hidden.text_input(
+            "隐藏层结构 (逗号分隔)", value="128,128", key="decision_env_ppo_hidden"
+        )
+        ppo_seed_text = col_seed.text_input(
+            "随机种子 (可选)", value="42", key="decision_env_ppo_seed"
+        )
+        try:
+            ppo_hidden = tuple(int(v.strip()) for v in ppo_hidden_text.split(",") if v.strip())
+        except ValueError:
+            ppo_hidden = ()
+        ppo_seed = None
+        if ppo_seed_text.strip():
+            try:
+                ppo_seed = int(ppo_seed_text.strip())
+            except ValueError:
+                st.warning("PPO 随机种子需为整数,已忽略该值。")
+                ppo_seed = None
+
+        if st.button("启动 PPO 训练", key="run_decision_env_ppo"):
+            if not specs:
+                st.warning("请先配置可调节参数,以构建动作空间。")
+            elif not ppo_hidden:
+                st.error("请提供合法的隐藏层结构,例如 128,128。")
+            else:
+                baseline_weights = app_cfg.agent_weights.as_dict()
+                for agent in agent_objects:
+                    baseline_weights.setdefault(agent.name, 1.0)
+
+                universe_env = [code.strip() for code in universe_text.split(',') if code.strip()]
+                if not universe_env:
+                    st.error("请先指定至少一个股票代码。")
+                else:
+                    bt_cfg_env = BtConfig(
+                        id="decision_env_ppo",
+                        name="DecisionEnv PPO",
+                        start_date=start_date,
+                        end_date=end_date,
+                        universe=universe_env,
+                        params={
+                            "target": target,
+                            "stop": stop,
+                            "hold_days": int(hold_days),
+                        },
+                        method=app_cfg.decision_method,
+                    )
+                    env = DecisionEnv(
+                        bt_config=bt_cfg_env,
+                        parameter_specs=specs,
+                        baseline_weights=baseline_weights,
+                        disable_departments=disable_departments,
+                    )
+                    adapter = DecisionEnvAdapter(env)
+                    config = PPOConfig(
+                        total_timesteps=ppo_timesteps,
+                        rollout_steps=ppo_rollout,
+                        gamma=ppo_gamma,
+                        gae_lambda=ppo_lambda,
+                        clip_range=ppo_clip,
+                        policy_lr=ppo_policy_lr,
+                        value_lr=ppo_value_lr,
+                        epochs=ppo_epochs,
+                        minibatch_size=ppo_minibatch,
+                        entropy_coef=ppo_entropy,
+                        value_coef=ppo_value_coef,
+                        max_grad_norm=ppo_max_grad_norm,
+                        hidden_sizes=ppo_hidden,
+                        seed=ppo_seed,
+                    )
+                    with st.spinner("PPO 训练进行中,请耐心等待..."):
+                        try:
+                            summary = train_ppo(adapter, config)
+                        except Exception as exc:  # noqa: BLE001
+                            LOGGER.exception(
+                                "PPO 训练失败",
+                                extra={**LOG_EXTRA, "error": str(exc)},
+                            )
+                            st.error(f"PPO 训练失败:{exc}")
+                        else:
+                            payload = {
+                                "timesteps": summary.timesteps,
+                                "episode_rewards": summary.episode_rewards,
+                                "episode_lengths": summary.episode_lengths,
+                                "diagnostics": summary.diagnostics[-25:],
+                                "observation_keys": adapter.keys(),
+                            }
+                            st.session_state[_DECISION_ENV_PPO_RESULTS_KEY] = payload
+                            st.success("PPO 训练完成。")
+
+    ppo_state = st.session_state.get(_DECISION_ENV_PPO_RESULTS_KEY)
+    if ppo_state:
+        st.caption(
+            f"最近一次 PPO 训练时间步:{ppo_state.get('timesteps')}"
+        )
+        rewards = ppo_state.get("episode_rewards") or []
+        if rewards:
+            st.line_chart(rewards, height=200)
+        lengths = ppo_state.get("episode_lengths") or []
+        if lengths:
+            st.bar_chart(lengths, height=200)
+        diagnostics = ppo_state.get("diagnostics") or []
+        if diagnostics:
+            st.dataframe(pd.DataFrame(diagnostics), hide_index=True, width='stretch')
+        st.download_button(
+            "下载 PPO 结果 (JSON)",
+            data=json.dumps(ppo_state, ensure_ascii=False, indent=2),
+            file_name="ppo_training_summary.json",
+            mime="application/json",
+            key="decision_env_ppo_json",
+        )
+        if st.button("清除 PPO 训练结果", key="clear_decision_env_ppo"):
+            st.session_state.pop(_DECISION_ENV_PPO_RESULTS_KEY, None)
+            st.success("已清除 PPO 训练结果。")
 
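
Reviewer notes on the patch above (commentary, not part of the diff):

1. `if TORCH_AVAILABLE:` guards the entire training UI but has no `else` branch, so when torch is absent the 「PPO 训练(逐日强化学习)」 subheader renders with an empty body. A one-line hint at the end of the guarded block would keep the degradation visible; the exact message wording below is only a suggestion:

    else:
        st.info("未检测到 torch,PPO 训练功能不可用,请先安装 torch。")

2. The fallback stubs in app/rl/__init__.py are designed to fail loudly rather than silently no-op. A minimal smoke test for that behaviour could look like the sketch below (hypothetical test module name; assumes pytest and an environment where torch is not installed, so the ModuleNotFoundError guard activates the stubs):

    # test_rl_torch_fallback.py -- hypothetical test, not part of this patch.
    import pytest

    import app.rl as rl


    def test_stubs_raise_without_torch():
        if rl.TORCH_AVAILABLE:
            pytest.skip("torch is installed; the fallback stubs are not active")
        # The placeholder PPOConfig is constructible, but both trainer entry
        # points must raise ModuleNotFoundError instead of silently returning.
        config = rl.PPOConfig(total_timesteps=128)
        with pytest.raises(ModuleNotFoundError):
            rl.train_ppo(None, config)
        with pytest.raises(ModuleNotFoundError):
            rl.PPOTrainer()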