add PPO training UI and torch optional dependency handling

parent 6d3afcf555
commit bc60a6115f
@@ -1,11 +1,55 @@
 """Reinforcement learning utilities for DecisionEnv."""

 from __future__ import annotations

+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Sequence
+
 from .adapters import DecisionEnvAdapter
-from .ppo import PPOConfig, PPOTrainer, train_ppo
+
+TORCH_AVAILABLE = True
+
+try:  # pragma: no cover - exercised via integration
+    from .ppo import PPOConfig, PPOTrainer, train_ppo
+except ModuleNotFoundError as exc:  # pragma: no cover - optional dependency guard
+    if exc.name != "torch":
+        raise
+    TORCH_AVAILABLE = False
+
+    @dataclass
+    class PPOConfig:
+        """Placeholder PPOConfig used when torch is unavailable."""
+
+        total_timesteps: int = 0
+        rollout_steps: int = 0
+        gamma: float = 0.0
+        gae_lambda: float = 0.0
+        clip_range: float = 0.0
+        policy_lr: float = 0.0
+        value_lr: float = 0.0
+        epochs: int = 0
+        minibatch_size: int = 0
+        entropy_coef: float = 0.0
+        value_coef: float = 0.0
+        max_grad_norm: float = 0.0
+        hidden_sizes: Sequence[int] = field(default_factory=tuple)
+        seed: Optional[int] = None
+
+    class PPOTrainer:  # pragma: no cover - simply raises when used
+        def __init__(self, *args, **kwargs) -> None:
+            raise ModuleNotFoundError(
+                "torch is required for PPO training. Please install torch before using this feature."
+            )
+
+    def train_ppo(*_args, **_kwargs):  # pragma: no cover - simply raises when used
+        raise ModuleNotFoundError(
+            "torch is required for PPO training. Please install torch before using this feature."
+        )
+
 __all__ = [
     "DecisionEnvAdapter",
     "PPOConfig",
     "PPOTrainer",
     "train_ppo",
+    "TORCH_AVAILABLE",
 ]
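With this guard, callers can import the PPO symbols unconditionally and branch on TORCH_AVAILABLE; when torch is missing, only calling the placeholder PPOTrainer or train_ppo raises, never the import itself. A minimal sketch of the intended consumption pattern (the `app.rl` package path follows the import added in the next hunk; the concrete PPOConfig values are illustrative):

    # Sketch: consuming the optional-dependency guard exported by app.rl.
    # Importing never requires torch; only calling train_ppo does.
    from app.rl import TORCH_AVAILABLE, PPOConfig, train_ppo

    if TORCH_AVAILABLE:
        config = PPOConfig(total_timesteps=4096, rollout_steps=256)
        # summary = train_ppo(adapter, config)  # needs a DecisionEnvAdapter instance
    else:
        # The placeholder train_ppo would raise ModuleNotFoundError if invoked,
        # so surface an install hint instead of calling it.
        print("torch is required for PPO training. Please install torch before using this feature.")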
@@ -19,6 +19,7 @@ from app.agents.game import Decision
 from app.agents.registry import default_agents
 from app.backtest.decision_env import DecisionEnv, ParameterSpec
 from app.backtest.optimizer import BanditConfig, EpsilonGreedyBandit
+from app.rl import TORCH_AVAILABLE, DecisionEnvAdapter, PPOConfig, train_ppo
 from app.backtest.engine import BacktestEngine, PortfolioState, BtConfig, run_backtest
 from app.ingest.checker import run_boot_check
 from app.ingest.tushare import run_ingestion
@@ -38,6 +39,7 @@ from app.ui.views.dashboard import update_dashboard_sidebar
 _DECISION_ENV_SINGLE_RESULT_KEY = "decision_env_single_result"
 _DECISION_ENV_BATCH_RESULTS_KEY = "decision_env_batch_results"
 _DECISION_ENV_BANDIT_RESULTS_KEY = "decision_env_bandit_results"
+_DECISION_ENV_PPO_RESULTS_KEY = "decision_env_ppo_results"


 def render_backtest_review() -> None:
     """渲染回测执行、调参与结果复盘页面。"""
@@ -675,4 +677,251 @@ def render_backtest_review() -> None:
             st.session_state.pop(_DECISION_ENV_BANDIT_RESULTS_KEY, None)
             st.success("已清除自动探索结果。")

+    st.divider()
+    st.subheader("PPO 训练(逐日强化学习)")
+
+    if TORCH_AVAILABLE:
+        col_ts, col_rollout, col_epochs = st.columns(3)
+        ppo_timesteps = int(
+            col_ts.number_input(
+                "总时间步",
+                min_value=256,
+                max_value=200_000,
+                value=4096,
+                step=256,
+                key="decision_env_ppo_timesteps",
+            )
+        )
+        ppo_rollout = int(
+            col_rollout.number_input(
+                "每次收集步数",
+                min_value=32,
+                max_value=2048,
+                value=256,
+                step=32,
+                key="decision_env_ppo_rollout",
+            )
+        )
+        ppo_epochs = int(
+            col_epochs.number_input(
+                "每批优化轮数",
+                min_value=1,
+                max_value=30,
+                value=8,
+                step=1,
+                key="decision_env_ppo_epochs",
+            )
+        )
+
+        col_mb, col_gamma, col_lambda = st.columns(3)
+        ppo_minibatch = int(
+            col_mb.number_input(
+                "最小批次规模",
+                min_value=16,
+                max_value=1024,
+                value=128,
+                step=16,
+                key="decision_env_ppo_minibatch",
+            )
+        )
+        ppo_gamma = float(
+            col_gamma.number_input(
+                "折现系数 γ",
+                min_value=0.5,
+                max_value=0.999,
+                value=0.99,
+                step=0.01,
+                format="%.3f",
+                key="decision_env_ppo_gamma",
+            )
+        )
+        ppo_lambda = float(
+            col_lambda.number_input(
+                "GAE λ",
+                min_value=0.5,
+                max_value=0.999,
+                value=0.95,
+                step=0.01,
+                format="%.3f",
+                key="decision_env_ppo_lambda",
+            )
+        )
+
+        col_clip, col_entropy, col_value = st.columns(3)
+        ppo_clip = float(
+            col_clip.number_input(
+                "裁剪范围 ε",
+                min_value=0.05,
+                max_value=0.5,
+                value=0.2,
+                step=0.01,
+                format="%.2f",
+                key="decision_env_ppo_clip",
+            )
+        )
+        ppo_entropy = float(
+            col_entropy.number_input(
+                "熵系数",
+                min_value=0.0,
+                max_value=0.1,
+                value=0.01,
+                step=0.005,
+                format="%.3f",
+                key="decision_env_ppo_entropy",
+            )
+        )
+        ppo_value_coef = float(
+            col_value.number_input(
+                "价值损失系数",
+                min_value=0.0,
+                max_value=2.0,
+                value=0.5,
+                step=0.1,
+                format="%.2f",
+                key="decision_env_ppo_value_coef",
+            )
+        )
+
+        col_lr_p, col_lr_v, col_grad = st.columns(3)
+        ppo_policy_lr = float(
+            col_lr_p.number_input(
+                "策略学习率",
+                min_value=1e-5,
+                max_value=1e-2,
+                value=3e-4,
+                step=1e-5,
+                format="%.5f",
+                key="decision_env_ppo_policy_lr",
+            )
+        )
+        ppo_value_lr = float(
+            col_lr_v.number_input(
+                "价值学习率",
+                min_value=1e-5,
+                max_value=1e-2,
+                value=3e-4,
+                step=1e-5,
+                format="%.5f",
+                key="decision_env_ppo_value_lr",
+            )
+        )
+        ppo_max_grad_norm = float(
+            col_grad.number_input(
+                "梯度裁剪", value=0.5, min_value=0.0, max_value=5.0, step=0.1, format="%.1f",
+                key="decision_env_ppo_grad_norm",
+            )
+        )
+
+        col_hidden, col_seed, _ = st.columns(3)
+        ppo_hidden_text = col_hidden.text_input(
+            "隐藏层结构 (逗号分隔)", value="128,128", key="decision_env_ppo_hidden"
+        )
+        ppo_seed_text = col_seed.text_input(
+            "随机种子 (可选)", value="42", key="decision_env_ppo_seed"
+        )
+        try:
+            ppo_hidden = tuple(int(v.strip()) for v in ppo_hidden_text.split(",") if v.strip())
+        except ValueError:
+            ppo_hidden = ()
+        ppo_seed = None
+        if ppo_seed_text.strip():
+            try:
+                ppo_seed = int(ppo_seed_text.strip())
+            except ValueError:
+                st.warning("PPO 随机种子需为整数,已忽略该值。")
+                ppo_seed = None
+
+        if st.button("启动 PPO 训练", key="run_decision_env_ppo"):
+            if not specs:
+                st.warning("请先配置可调节参数,以构建动作空间。")
+            elif not ppo_hidden:
+                st.error("请提供合法的隐藏层结构,例如 128,128。")
+            else:
+                baseline_weights = app_cfg.agent_weights.as_dict()
+                for agent in agent_objects:
+                    baseline_weights.setdefault(agent.name, 1.0)
+
+                universe_env = [code.strip() for code in universe_text.split(',') if code.strip()]
+                if not universe_env:
+                    st.error("请先指定至少一个股票代码。")
+                else:
+                    bt_cfg_env = BtConfig(
+                        id="decision_env_ppo",
+                        name="DecisionEnv PPO",
+                        start_date=start_date,
+                        end_date=end_date,
+                        universe=universe_env,
+                        params={
+                            "target": target,
+                            "stop": stop,
+                            "hold_days": int(hold_days),
+                        },
+                        method=app_cfg.decision_method,
+                    )
+                    env = DecisionEnv(
+                        bt_config=bt_cfg_env,
+                        parameter_specs=specs,
+                        baseline_weights=baseline_weights,
+                        disable_departments=disable_departments,
+                    )
+                    adapter = DecisionEnvAdapter(env)
+                    config = PPOConfig(
+                        total_timesteps=ppo_timesteps,
+                        rollout_steps=ppo_rollout,
+                        gamma=ppo_gamma,
+                        gae_lambda=ppo_lambda,
+                        clip_range=ppo_clip,
+                        policy_lr=ppo_policy_lr,
+                        value_lr=ppo_value_lr,
+                        epochs=ppo_epochs,
+                        minibatch_size=ppo_minibatch,
+                        entropy_coef=ppo_entropy,
+                        value_coef=ppo_value_coef,
+                        max_grad_norm=ppo_max_grad_norm,
+                        hidden_sizes=ppo_hidden,
+                        seed=ppo_seed,
+                    )
+                    with st.spinner("PPO 训练进行中,请耐心等待..."):
+                        try:
+                            summary = train_ppo(adapter, config)
+                        except Exception as exc:  # noqa: BLE001
+                            LOGGER.exception(
+                                "PPO 训练失败",
+                                extra={**LOG_EXTRA, "error": str(exc)},
+                            )
+                            st.error(f"PPO 训练失败:{exc}")
+                        else:
+                            payload = {
+                                "timesteps": summary.timesteps,
+                                "episode_rewards": summary.episode_rewards,
+                                "episode_lengths": summary.episode_lengths,
+                                "diagnostics": summary.diagnostics[-25:],
+                                "observation_keys": adapter.keys(),
+                            }
+                            st.session_state[_DECISION_ENV_PPO_RESULTS_KEY] = payload
+                            st.success("PPO 训练完成。")
+
+        ppo_state = st.session_state.get(_DECISION_ENV_PPO_RESULTS_KEY)
+        if ppo_state:
+            st.caption(
+                f"最近一次 PPO 训练时间步:{ppo_state.get('timesteps')}"
+            )
+            rewards = ppo_state.get("episode_rewards") or []
+            if rewards:
+                st.line_chart(rewards, height=200)
+            lengths = ppo_state.get("episode_lengths") or []
+            if lengths:
+                st.bar_chart(lengths, height=200)
+            diagnostics = ppo_state.get("diagnostics") or []
+            if diagnostics:
+                st.dataframe(pd.DataFrame(diagnostics), hide_index=True, width='stretch')
+            st.download_button(
+                "下载 PPO 结果 (JSON)",
+                data=json.dumps(ppo_state, ensure_ascii=False, indent=2),
+                file_name="ppo_training_summary.json",
+                mime="application/json",
+                key="decision_env_ppo_json",
+            )
+            if st.button("清除 PPO 训练结果", key="clear_decision_env_ppo"):
+                st.session_state.pop(_DECISION_ENV_PPO_RESULTS_KEY, None)
+                st.success("已清除 PPO 训练结果。")
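For reference, the same training path can be driven without Streamlit. Below is a minimal headless sketch using only the calls that appear in this diff; the function parameters (specs, dates, universe, baseline weights, decision method) and the params dict values are placeholders you would supply yourself, and the PPOConfig values mirror the UI defaults.

    # Headless sketch of the flow the UI wires together above (illustrative, not part of this commit).
    from app.backtest.decision_env import DecisionEnv
    from app.backtest.engine import BtConfig
    from app.rl import TORCH_AVAILABLE, DecisionEnvAdapter, PPOConfig, train_ppo

    def run_headless(specs, start_date, end_date, universe, baseline_weights, decision_method):
        # `specs` is a sequence of ParameterSpec built the same way the review page does.
        if not TORCH_AVAILABLE:
            raise RuntimeError("torch is required for PPO training")
        bt_cfg = BtConfig(
            id="decision_env_ppo",
            name="DecisionEnv PPO",
            start_date=start_date,
            end_date=end_date,
            universe=universe,
            params={"target": 0.05, "stop": 0.03, "hold_days": 10},  # placeholder values
            method=decision_method,
        )
        env = DecisionEnv(
            bt_config=bt_cfg,
            parameter_specs=specs,
            baseline_weights=baseline_weights,
            disable_departments=[],  # placeholder: disable nothing
        )
        adapter = DecisionEnvAdapter(env)
        config = PPOConfig(
            total_timesteps=4096, rollout_steps=256, epochs=8, minibatch_size=128,
            gamma=0.99, gae_lambda=0.95, clip_range=0.2,
            policy_lr=3e-4, value_lr=3e-4, entropy_coef=0.01, value_coef=0.5,
            max_grad_norm=0.5, hidden_sizes=(128, 128), seed=42,
        )
        summary = train_ppo(adapter, config)
        return summary.episode_rewards, summary.diagnostics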