add PPO training UI and torch optional dependency handling

sam 2025-10-06 22:00:24 +08:00
parent 6d3afcf555
commit bc60a6115f
2 changed files with 294 additions and 1 deletion

View File

@@ -1,11 +1,55 @@
"""Reinforcement learning utilities for DecisionEnv."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Sequence

from .adapters import DecisionEnvAdapter

TORCH_AVAILABLE = True

try:  # pragma: no cover - exercised via integration
    from .ppo import PPOConfig, PPOTrainer, train_ppo
except ModuleNotFoundError as exc:  # pragma: no cover - optional dependency guard
    if exc.name != "torch":
        raise
    TORCH_AVAILABLE = False

    @dataclass
    class PPOConfig:
        """Placeholder PPOConfig used when torch is unavailable."""

        total_timesteps: int = 0
        rollout_steps: int = 0
        gamma: float = 0.0
        gae_lambda: float = 0.0
        clip_range: float = 0.0
        policy_lr: float = 0.0
        value_lr: float = 0.0
        epochs: int = 0
        minibatch_size: int = 0
        entropy_coef: float = 0.0
        value_coef: float = 0.0
        max_grad_norm: float = 0.0
        hidden_sizes: Sequence[int] = field(default_factory=tuple)
        seed: Optional[int] = None

    class PPOTrainer:  # pragma: no cover - simply raises when used
        def __init__(self, *args, **kwargs) -> None:
            raise ModuleNotFoundError(
                "torch is required for PPO training. Please install torch before using this feature."
            )

    def train_ppo(*_args, **_kwargs):  # pragma: no cover - simply raises when used
        raise ModuleNotFoundError(
            "torch is required for PPO training. Please install torch before using this feature."
        )


__all__ = [
    "DecisionEnvAdapter",
    "PPOConfig",
    "PPOTrainer",
    "train_ppo",
    "TORCH_AVAILABLE",
]
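
As a usage sketch (assuming only the exports above; run_training is a hypothetical helper and is not part of this commit), callers branch on TORCH_AVAILABLE rather than importing torch themselves, which is the pattern the UI change in the second file follows:

# Illustrative sketch, assuming the app.rl exports shown above.
from app.rl import TORCH_AVAILABLE, PPOConfig, train_ppo

def run_training(adapter, config: PPOConfig):
    if not TORCH_AVAILABLE:
        # With torch missing, train_ppo is the placeholder defined above,
        # so calling it would raise ModuleNotFoundError immediately.
        return None
    return train_ppo(adapter, config)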

View File

@@ -19,6 +19,7 @@ from app.agents.game import Decision
from app.agents.registry import default_agents
from app.backtest.decision_env import DecisionEnv, ParameterSpec
from app.backtest.optimizer import BanditConfig, EpsilonGreedyBandit
from app.rl import TORCH_AVAILABLE, DecisionEnvAdapter, PPOConfig, train_ppo
from app.backtest.engine import BacktestEngine, PortfolioState, BtConfig, run_backtest
from app.ingest.checker import run_boot_check
from app.ingest.tushare import run_ingestion
@@ -38,6 +39,7 @@ from app.ui.views.dashboard import update_dashboard_sidebar
_DECISION_ENV_SINGLE_RESULT_KEY = "decision_env_single_result"
_DECISION_ENV_BATCH_RESULTS_KEY = "decision_env_batch_results"
_DECISION_ENV_BANDIT_RESULTS_KEY = "decision_env_bandit_results"
_DECISION_ENV_PPO_RESULTS_KEY = "decision_env_ppo_results"

def render_backtest_review() -> None:
    """渲染回测执行、调参与结果复盘页面。"""
@@ -675,4 +677,251 @@
            st.session_state.pop(_DECISION_ENV_BANDIT_RESULTS_KEY, None)
            st.success("已清除自动探索结果。")

    st.divider()
    st.subheader("PPO 训练(逐日强化学习)")
    if TORCH_AVAILABLE:
        col_ts, col_rollout, col_epochs = st.columns(3)
        ppo_timesteps = int(
            col_ts.number_input(
                "总时间步",
                min_value=256,
                max_value=200_000,
                value=4096,
                step=256,
                key="decision_env_ppo_timesteps",
            )
        )
        ppo_rollout = int(
            col_rollout.number_input(
                "每次收集步数",
                min_value=32,
                max_value=2048,
                value=256,
                step=32,
                key="decision_env_ppo_rollout",
            )
        )
        ppo_epochs = int(
            col_epochs.number_input(
                "每批优化轮数",
                min_value=1,
                max_value=30,
                value=8,
                step=1,
                key="decision_env_ppo_epochs",
            )
        )

        col_mb, col_gamma, col_lambda = st.columns(3)
        ppo_minibatch = int(
            col_mb.number_input(
                "最小批次规模",
                min_value=16,
                max_value=1024,
                value=128,
                step=16,
                key="decision_env_ppo_minibatch",
            )
        )
        ppo_gamma = float(
            col_gamma.number_input(
                "折现系数 γ",
                min_value=0.5,
                max_value=0.999,
                value=0.99,
                step=0.01,
                format="%.3f",
                key="decision_env_ppo_gamma",
            )
        )
        ppo_lambda = float(
            col_lambda.number_input(
                "GAE λ",
                min_value=0.5,
                max_value=0.999,
                value=0.95,
                step=0.01,
                format="%.3f",
                key="decision_env_ppo_lambda",
            )
        )

        col_clip, col_entropy, col_value = st.columns(3)
        ppo_clip = float(
            col_clip.number_input(
                "裁剪范围 ε",
                min_value=0.05,
                max_value=0.5,
                value=0.2,
                step=0.01,
                format="%.2f",
                key="decision_env_ppo_clip",
            )
        )
        ppo_entropy = float(
            col_entropy.number_input(
                "熵系数",
                min_value=0.0,
                max_value=0.1,
                value=0.01,
                step=0.005,
                format="%.3f",
                key="decision_env_ppo_entropy",
            )
        )
        ppo_value_coef = float(
            col_value.number_input(
                "价值损失系数",
                min_value=0.0,
                max_value=2.0,
                value=0.5,
                step=0.1,
                format="%.2f",
                key="decision_env_ppo_value_coef",
            )
        )

        col_lr_p, col_lr_v, col_grad = st.columns(3)
        ppo_policy_lr = float(
            col_lr_p.number_input(
                "策略学习率",
                min_value=1e-5,
                max_value=1e-2,
                value=3e-4,
                step=1e-5,
                format="%.5f",
                key="decision_env_ppo_policy_lr",
            )
        )
        ppo_value_lr = float(
            col_lr_v.number_input(
                "价值学习率",
                min_value=1e-5,
                max_value=1e-2,
                value=3e-4,
                step=1e-5,
                format="%.5f",
                key="decision_env_ppo_value_lr",
            )
        )
        ppo_max_grad_norm = float(
            col_grad.number_input(
                "梯度裁剪", value=0.5, min_value=0.0, max_value=5.0, step=0.1, format="%.1f",
                key="decision_env_ppo_grad_norm",
            )
        )

        col_hidden, col_seed, _ = st.columns(3)
        ppo_hidden_text = col_hidden.text_input(
            "隐藏层结构 (逗号分隔)", value="128,128", key="decision_env_ppo_hidden"
        )
        ppo_seed_text = col_seed.text_input(
            "随机种子 (可选)", value="42", key="decision_env_ppo_seed"
        )
        try:
            ppo_hidden = tuple(int(v.strip()) for v in ppo_hidden_text.split(",") if v.strip())
        except ValueError:
            ppo_hidden = ()
        ppo_seed = None
        if ppo_seed_text.strip():
            try:
                ppo_seed = int(ppo_seed_text.strip())
            except ValueError:
                st.warning("PPO 随机种子需为整数,已忽略该值。")
                ppo_seed = None

        if st.button("启动 PPO 训练", key="run_decision_env_ppo"):
            if not specs:
                st.warning("请先配置可调节参数,以构建动作空间。")
            elif not ppo_hidden:
                st.error("请提供合法的隐藏层结构,例如 128,128。")
            else:
                baseline_weights = app_cfg.agent_weights.as_dict()
                for agent in agent_objects:
                    baseline_weights.setdefault(agent.name, 1.0)
                universe_env = [code.strip() for code in universe_text.split(',') if code.strip()]
                if not universe_env:
                    st.error("请先指定至少一个股票代码。")
                else:
                    bt_cfg_env = BtConfig(
                        id="decision_env_ppo",
                        name="DecisionEnv PPO",
                        start_date=start_date,
                        end_date=end_date,
                        universe=universe_env,
                        params={
                            "target": target,
                            "stop": stop,
                            "hold_days": int(hold_days),
                        },
                        method=app_cfg.decision_method,
                    )
                    env = DecisionEnv(
                        bt_config=bt_cfg_env,
                        parameter_specs=specs,
                        baseline_weights=baseline_weights,
                        disable_departments=disable_departments,
                    )
                    adapter = DecisionEnvAdapter(env)
                    config = PPOConfig(
                        total_timesteps=ppo_timesteps,
                        rollout_steps=ppo_rollout,
                        gamma=ppo_gamma,
                        gae_lambda=ppo_lambda,
                        clip_range=ppo_clip,
                        policy_lr=ppo_policy_lr,
                        value_lr=ppo_value_lr,
                        epochs=ppo_epochs,
                        minibatch_size=ppo_minibatch,
                        entropy_coef=ppo_entropy,
                        value_coef=ppo_value_coef,
                        max_grad_norm=ppo_max_grad_norm,
                        hidden_sizes=ppo_hidden,
                        seed=ppo_seed,
                    )
                    with st.spinner("PPO 训练进行中,请耐心等待..."):
                        try:
                            summary = train_ppo(adapter, config)
                        except Exception as exc:  # noqa: BLE001
                            LOGGER.exception(
                                "PPO 训练失败",
                                extra={**LOG_EXTRA, "error": str(exc)},
                            )
                            st.error(f"PPO 训练失败:{exc}")
                        else:
                            payload = {
                                "timesteps": summary.timesteps,
                                "episode_rewards": summary.episode_rewards,
                                "episode_lengths": summary.episode_lengths,
                                "diagnostics": summary.diagnostics[-25:],
                                "observation_keys": adapter.keys(),
                            }
                            st.session_state[_DECISION_ENV_PPO_RESULTS_KEY] = payload
                            st.success("PPO 训练完成。")

    ppo_state = st.session_state.get(_DECISION_ENV_PPO_RESULTS_KEY)
    if ppo_state:
        st.caption(
            f"最近一次 PPO 训练时间步:{ppo_state.get('timesteps')}"
        )
        rewards = ppo_state.get("episode_rewards") or []
        if rewards:
            st.line_chart(rewards, height=200)
        lengths = ppo_state.get("episode_lengths") or []
        if lengths:
            st.bar_chart(lengths, height=200)
        diagnostics = ppo_state.get("diagnostics") or []
        if diagnostics:
            st.dataframe(pd.DataFrame(diagnostics), hide_index=True, width='stretch')
        st.download_button(
            "下载 PPO 结果 (JSON)",
            data=json.dumps(ppo_state, ensure_ascii=False, indent=2),
            file_name="ppo_training_summary.json",
            mime="application/json",
            key="decision_env_ppo_json",
        )
        if st.button("清除 PPO 训练结果", key="clear_decision_env_ppo"):
            st.session_state.pop(_DECISION_ENV_PPO_RESULTS_KEY, None)
            st.success("已清除 PPO 训练结果。")