add PPO training UI and torch optional dependency handling
parent 6d3afcf555
commit bc60a6115f
@@ -1,11 +1,55 @@
"""Reinforcement learning utilities for DecisionEnv."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Sequence

from .adapters import DecisionEnvAdapter

TORCH_AVAILABLE = True

try:  # pragma: no cover - exercised via integration
    from .ppo import PPOConfig, PPOTrainer, train_ppo
except ModuleNotFoundError as exc:  # pragma: no cover - optional dependency guard
    if exc.name != "torch":
        raise
    TORCH_AVAILABLE = False

    @dataclass
    class PPOConfig:
        """Placeholder PPOConfig used when torch is unavailable."""

        total_timesteps: int = 0
        rollout_steps: int = 0
        gamma: float = 0.0
        gae_lambda: float = 0.0
        clip_range: float = 0.0
        policy_lr: float = 0.0
        value_lr: float = 0.0
        epochs: int = 0
        minibatch_size: int = 0
        entropy_coef: float = 0.0
        value_coef: float = 0.0
        max_grad_norm: float = 0.0
        hidden_sizes: Sequence[int] = field(default_factory=tuple)
        seed: Optional[int] = None

    class PPOTrainer:  # pragma: no cover - simply raises when used
        def __init__(self, *args, **kwargs) -> None:
            raise ModuleNotFoundError(
                "torch is required for PPO training. Please install torch before using this feature."
            )

    def train_ppo(*_args, **_kwargs):  # pragma: no cover - simply raises when used
        raise ModuleNotFoundError(
            "torch is required for PPO training. Please install torch before using this feature."
        )

__all__ = [
    "DecisionEnvAdapter",
    "PPOConfig",
    "PPOTrainer",
    "train_ppo",
    "TORCH_AVAILABLE",
]
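Because the fallback shims keep every public name importable, downstream code can import unconditionally and branch on TORCH_AVAILABLE instead of wrapping its own try/except around torch. A minimal sketch of that calling pattern, assuming a DecisionEnvAdapter built elsewhere (the maybe_train helper is illustrative, not part of this commit):

from app.rl import TORCH_AVAILABLE, PPOConfig, train_ppo

def maybe_train(adapter):
    """Run PPO when torch is present; otherwise skip gracefully."""
    if not TORCH_AVAILABLE:
        # The placeholder train_ppo would raise ModuleNotFoundError by design.
        return None
    config = PPOConfig(total_timesteps=4096, rollout_steps=256)
    return train_ppo(adapter, config)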
@@ -19,6 +19,7 @@ from app.agents.game import Decision
from app.agents.registry import default_agents
from app.backtest.decision_env import DecisionEnv, ParameterSpec
from app.backtest.optimizer import BanditConfig, EpsilonGreedyBandit
from app.rl import TORCH_AVAILABLE, DecisionEnvAdapter, PPOConfig, train_ppo
from app.backtest.engine import BacktestEngine, PortfolioState, BtConfig, run_backtest
from app.ingest.checker import run_boot_check
from app.ingest.tushare import run_ingestion
@@ -38,6 +39,7 @@ from app.ui.views.dashboard import update_dashboard_sidebar
_DECISION_ENV_SINGLE_RESULT_KEY = "decision_env_single_result"
_DECISION_ENV_BATCH_RESULTS_KEY = "decision_env_batch_results"
_DECISION_ENV_BANDIT_RESULTS_KEY = "decision_env_bandit_results"
_DECISION_ENV_PPO_RESULTS_KEY = "decision_env_ppo_results"


def render_backtest_review() -> None:
    """Render the backtest execution, parameter tuning, and result review page."""
@@ -675,4 +677,251 @@ def render_backtest_review() -> None:
            st.session_state.pop(_DECISION_ENV_BANDIT_RESULTS_KEY, None)
            st.success("Cleared automated exploration results.")

    st.divider()
    st.subheader("PPO Training (day-by-day reinforcement learning)")
    if TORCH_AVAILABLE:
        col_ts, col_rollout, col_epochs = st.columns(3)
        ppo_timesteps = int(
            col_ts.number_input(
                "Total timesteps",
                min_value=256,
                max_value=200_000,
                value=4096,
                step=256,
                key="decision_env_ppo_timesteps",
            )
        )
        ppo_rollout = int(
            col_rollout.number_input(
                "Steps per rollout",
                min_value=32,
                max_value=2048,
                value=256,
                step=32,
                key="decision_env_ppo_rollout",
            )
        )
        ppo_epochs = int(
            col_epochs.number_input(
                "Optimization epochs per batch",
                min_value=1,
                max_value=30,
                value=8,
                step=1,
                key="decision_env_ppo_epochs",
            )
        )
        col_mb, col_gamma, col_lambda = st.columns(3)
        ppo_minibatch = int(
            col_mb.number_input(
                "Minibatch size",
                min_value=16,
                max_value=1024,
                value=128,
                step=16,
                key="decision_env_ppo_minibatch",
            )
        )
        ppo_gamma = float(
            col_gamma.number_input(
                "Discount factor γ",
                min_value=0.5,
                max_value=0.999,
                value=0.99,
                step=0.01,
                format="%.3f",
                key="decision_env_ppo_gamma",
            )
        )
        ppo_lambda = float(
            col_lambda.number_input(
                "GAE λ",
                min_value=0.5,
                max_value=0.999,
                value=0.95,
                step=0.01,
                format="%.3f",
                key="decision_env_ppo_lambda",
            )
        )
        col_clip, col_entropy, col_value = st.columns(3)
        ppo_clip = float(
            col_clip.number_input(
                "Clip range ε",
                min_value=0.05,
                max_value=0.5,
                value=0.2,
                step=0.01,
                format="%.2f",
                key="decision_env_ppo_clip",
            )
        )
        ppo_entropy = float(
            col_entropy.number_input(
                "Entropy coefficient",
                min_value=0.0,
                max_value=0.1,
                value=0.01,
                step=0.005,
                format="%.3f",
                key="decision_env_ppo_entropy",
            )
        )
        ppo_value_coef = float(
            col_value.number_input(
                "Value loss coefficient",
                min_value=0.0,
                max_value=2.0,
                value=0.5,
                step=0.1,
                format="%.2f",
                key="decision_env_ppo_value_coef",
            )
        )
        col_lr_p, col_lr_v, col_grad = st.columns(3)
        ppo_policy_lr = float(
            col_lr_p.number_input(
                "Policy learning rate",
                min_value=1e-5,
                max_value=1e-2,
                value=3e-4,
                step=1e-5,
                format="%.5f",
                key="decision_env_ppo_policy_lr",
            )
        )
        ppo_value_lr = float(
            col_lr_v.number_input(
                "Value learning rate",
                min_value=1e-5,
                max_value=1e-2,
                value=3e-4,
                step=1e-5,
                format="%.5f",
                key="decision_env_ppo_value_lr",
            )
        )
        ppo_max_grad_norm = float(
            col_grad.number_input(
                "Gradient clipping",
                value=0.5,
                min_value=0.0,
                max_value=5.0,
                step=0.1,
                format="%.1f",
                key="decision_env_ppo_grad_norm",
            )
        )
        col_hidden, col_seed, _ = st.columns(3)
        ppo_hidden_text = col_hidden.text_input(
            "Hidden layer sizes (comma-separated)", value="128,128", key="decision_env_ppo_hidden"
        )
        ppo_seed_text = col_seed.text_input(
            "Random seed (optional)", value="42", key="decision_env_ppo_seed"
        )
        # Parse the hidden-layer spec; fall back to an empty tuple on invalid input.
        try:
            ppo_hidden = tuple(int(v.strip()) for v in ppo_hidden_text.split(",") if v.strip())
        except ValueError:
            ppo_hidden = ()
        ppo_seed = None
        if ppo_seed_text.strip():
            try:
                ppo_seed = int(ppo_seed_text.strip())
            except ValueError:
                st.warning("The PPO random seed must be an integer; the value was ignored.")
                ppo_seed = None
        if st.button("Start PPO training", key="run_decision_env_ppo"):
            if not specs:
                st.warning("Configure tunable parameters first to build the action space.")
            elif not ppo_hidden:
                st.error("Provide a valid hidden layer structure, e.g. 128,128.")
            else:
                baseline_weights = app_cfg.agent_weights.as_dict()
                for agent in agent_objects:
                    baseline_weights.setdefault(agent.name, 1.0)

                universe_env = [code.strip() for code in universe_text.split(",") if code.strip()]
                if not universe_env:
                    st.error("Specify at least one stock code first.")
                else:
                    bt_cfg_env = BtConfig(
                        id="decision_env_ppo",
                        name="DecisionEnv PPO",
                        start_date=start_date,
                        end_date=end_date,
                        universe=universe_env,
                        params={
                            "target": target,
                            "stop": stop,
                            "hold_days": int(hold_days),
                        },
                        method=app_cfg.decision_method,
                    )
                    env = DecisionEnv(
                        bt_config=bt_cfg_env,
                        parameter_specs=specs,
                        baseline_weights=baseline_weights,
                        disable_departments=disable_departments,
                    )
                    adapter = DecisionEnvAdapter(env)
                    config = PPOConfig(
                        total_timesteps=ppo_timesteps,
                        rollout_steps=ppo_rollout,
                        gamma=ppo_gamma,
                        gae_lambda=ppo_lambda,
                        clip_range=ppo_clip,
                        policy_lr=ppo_policy_lr,
                        value_lr=ppo_value_lr,
                        epochs=ppo_epochs,
                        minibatch_size=ppo_minibatch,
                        entropy_coef=ppo_entropy,
                        value_coef=ppo_value_coef,
                        max_grad_norm=ppo_max_grad_norm,
                        hidden_sizes=ppo_hidden,
                        seed=ppo_seed,
                    )
                    with st.spinner("PPO training in progress, please wait..."):
                        try:
                            summary = train_ppo(adapter, config)
                        except Exception as exc:  # noqa: BLE001
                            LOGGER.exception(
                                "PPO training failed",
                                extra={**LOG_EXTRA, "error": str(exc)},
                            )
                            st.error(f"PPO training failed: {exc}")
                        else:
                            payload = {
                                "timesteps": summary.timesteps,
                                "episode_rewards": summary.episode_rewards,
                                "episode_lengths": summary.episode_lengths,
                                "diagnostics": summary.diagnostics[-25:],
                                "observation_keys": adapter.keys(),
                            }
                            st.session_state[_DECISION_ENV_PPO_RESULTS_KEY] = payload
                            st.success("PPO training complete.")
        ppo_state = st.session_state.get(_DECISION_ENV_PPO_RESULTS_KEY)
        if ppo_state:
            st.caption(
                f"Most recent PPO training timesteps: {ppo_state.get('timesteps')}"
            )
            rewards = ppo_state.get("episode_rewards") or []
            if rewards:
                st.line_chart(rewards, height=200)
            lengths = ppo_state.get("episode_lengths") or []
            if lengths:
                st.bar_chart(lengths, height=200)
            diagnostics = ppo_state.get("diagnostics") or []
            if diagnostics:
                st.dataframe(pd.DataFrame(diagnostics), hide_index=True, width="stretch")
            st.download_button(
                "Download PPO results (JSON)",
                data=json.dumps(ppo_state, ensure_ascii=False, indent=2),
                file_name="ppo_training_summary.json",
                mime="application/json",
                key="decision_env_ppo_json",
            )
            if st.button("Clear PPO training results", key="clear_decision_env_ppo"):
                st.session_state.pop(_DECISION_ENV_PPO_RESULTS_KEY, None)
                st.success("Cleared PPO training results.")
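For reference, the same training path the UI drives can be exercised headlessly. A minimal sketch assuming a DecisionEnv has already been constructed as above (the run_headless_ppo helper and its literal values mirror the UI defaults and are illustrative, not part of this commit):

from app.rl import DecisionEnvAdapter, PPOConfig, TORCH_AVAILABLE, train_ppo

def run_headless_ppo(env):
    """Train PPO on an already-built DecisionEnv and return episode rewards."""
    if not TORCH_AVAILABLE:
        raise ModuleNotFoundError("torch is required for PPO training.")
    adapter = DecisionEnvAdapter(env)
    config = PPOConfig(
        total_timesteps=4096,   # UI default for "Total timesteps"
        rollout_steps=256,      # UI default for "Steps per rollout"
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        policy_lr=3e-4,
        value_lr=3e-4,
        epochs=8,
        minibatch_size=128,
        entropy_coef=0.01,
        value_coef=0.5,
        max_grad_norm=0.5,
        hidden_sizes=(128, 128),
        seed=42,
    )
    summary = train_ppo(adapter, config)
    return summary.episode_rewards  # same fields the UI charts above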