sam 2025-10-06 13:28:49 +08:00
parent d0a0340db6
commit fa46be501b
3 changed files with 201 additions and 7 deletions

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import random
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Mapping, Sequence, Tuple
from app.backtest.decision_env import DecisionEnv, EpisodeMetrics
from app.backtest.decision_env import ParameterSpec
@@ -28,9 +28,12 @@ class BanditConfig:
@dataclass
class BanditEpisode:
action: Dict[str, float]
resolved_action: Dict[str, Any]
reward: float
metrics: EpisodeMetrics
observation: Dict[str, float]
weights: Mapping[str, float] | None = None
department_controls: Mapping[str, Mapping[str, Any]] | None = None
@dataclass
@@ -78,8 +81,13 @@ class EpsilonGreedyBandit:
self._counts[key] = count
self._value_estimates[key] = old_estimate + (reward - old_estimate) / count
action_payload = self._raw_action_mapping(action)
resolved_action = self._resolved_action_mapping(action)
metrics_payload = _metrics_to_dict(metrics)
department_controls = info.get("department_controls")
if department_controls:
metrics_payload["department_controls"] = department_controls
metrics_payload["resolved_action"] = resolved_action
try:
log_tuning_result(
experiment_id=self.config.experiment_id,
@@ -94,9 +102,12 @@ class EpsilonGreedyBandit:
episode_record = BanditEpisode(
action=action_payload,
resolved_action=resolved_action,
reward=reward,
metrics=metrics,
observation=obs,
weights=info.get("weights"),
department_controls=department_controls,
)
self._history.episodes.append(episode_record)
LOGGER.info(
@@ -112,17 +123,28 @@ class EpsilonGreedyBandit:
if self._value_estimates and self._random.random() > self.config.epsilon:
best = max(self._value_estimates.items(), key=lambda item: item[1])[0]
return list(best)
return [self._sample_value(spec) for spec in self._specs]
def _raw_action_mapping(self, action: Sequence[float]) -> Dict[str, float]:
return {
spec.name: float(value)
for spec, value in zip(self._specs, action, strict=True)
}
def _resolved_action_mapping(self, action: Sequence[float]) -> Dict[str, Any]:
return {
spec.name: spec.resolve(value)
for spec, value in zip(self._specs, action, strict=True)
}
def _sample_value(self, spec: ParameterSpec) -> float:
if spec.values:
if len(spec.values) <= 1:
return 0.0
index = self._random.randrange(len(spec.values))
return index / (len(spec.values) - 1)
return self._random.random()
def _metrics_to_dict(metrics: EpisodeMetrics) -> Dict[str, float | Dict[str, int]]:
payload: Dict[str, float | Dict[str, int]] = {

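The selection logic above is a plain epsilon-greedy loop: with probability 1 − ε it replays the best-known action (highest incremental-mean value estimate, updated as Q ← Q + (r − Q)/n), otherwise it samples a fresh action. Sampling now stays in normalized [0, 1] space instead of drawing uniform(spec.minimum, spec.maximum) directly, and spec.resolve() maps the raw value back to a concrete parameter for logging. A minimal sketch of that round trip, using a simplified stand-in for ParameterSpec (the real class lives in app.backtest.decision_env and its resolve() semantics are assumed here):

# Sketch only (not part of the commit): a simplified stand-in for
# app.backtest.decision_env.ParameterSpec, to illustrate the new sampling.
from dataclasses import dataclass
from typing import Any, List, Optional
import random

@dataclass
class SpecSketch:
    name: str
    minimum: float = 0.0
    maximum: float = 1.0
    values: Optional[List[Any]] = None  # discrete choices, if any

    def resolve(self, raw: float) -> Any:
        # Assumed behaviour: map a normalized raw value back to a parameter.
        if self.values:
            if len(self.values) <= 1:
                return self.values[0]
            return self.values[round(raw * (len(self.values) - 1))]
        return self.minimum + raw * (self.maximum - self.minimum)

def sample_value(spec: SpecSketch, rng: random.Random) -> float:
    # Mirrors _sample_value: discrete specs yield a normalized index,
    # continuous specs a uniform draw in [0, 1].
    if spec.values:
        if len(spec.values) <= 1:
            return 0.0
        return rng.randrange(len(spec.values)) / (len(spec.values) - 1)
    return rng.random()

rng = random.Random(7)
hold_days = SpecSketch(name="hold_days", values=[3, 5, 10])
target = SpecSketch(name="target", minimum=0.02, maximum=0.10)
for spec in (hold_days, target):
    raw = sample_value(spec, rng)
    print(spec.name, raw, spec.resolve(raw))

Under these assumptions, discrete specs are encoded as index / (len(values) − 1), which keeps raw actions comparable across continuous and categorical dimensions while the resolved values go into the logged payload.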
View File

@@ -18,6 +18,7 @@ from app.agents.base import AgentContext
from app.agents.game import Decision
from app.agents.registry import default_agents
from app.backtest.decision_env import DecisionEnv, ParameterSpec
from app.backtest.optimizer import BanditConfig, EpsilonGreedyBandit
from app.backtest.engine import BacktestEngine, PortfolioState, BtConfig, run_backtest
from app.ingest.checker import run_boot_check
from app.ingest.tushare import run_ingestion
@@ -35,6 +36,7 @@ from app.ui.views.dashboard import update_dashboard_sidebar
_DECISION_ENV_SINGLE_RESULT_KEY = "decision_env_single_result"
_DECISION_ENV_BATCH_RESULTS_KEY = "decision_env_batch_results"
_DECISION_ENV_BANDIT_RESULTS_KEY = "decision_env_bandit_results"
def render_backtest_review() -> None:
"""渲染回测执行、调参与结果复盘页面。"""
@@ -675,6 +677,170 @@ def render_backtest_review() -> None:
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
st.success("已清除单次调参结果缓存。")
st.divider()
st.subheader("自动探索(epsilon-greedy)")
col_ep, col_eps, col_seed = st.columns([1, 1, 1])
bandit_episodes = int(
col_ep.number_input(
"迭代次数",
min_value=1,
max_value=200,
value=10,
step=1,
key="decision_env_bandit_episodes",
help="探索的回合数,越大越充分但耗时越久。",
)
)
bandit_epsilon = float(
col_eps.slider(
"探索比例 ε",
min_value=0.0,
max_value=1.0,
value=0.2,
step=0.05,
key="decision_env_bandit_epsilon",
help="ε 越大,随机探索概率越高。",
)
)
seed_text = col_seed.text_input(
"随机种子(可选)",
value="",
key="decision_env_bandit_seed",
help="填写整数可复现实验,不填写则随机。",
).strip()
bandit_seed = None
if seed_text:
try:
bandit_seed = int(seed_text)
except ValueError:
st.warning("随机种子需为整数,已忽略该值。")
bandit_seed = None
run_bandit = st.button("执行自动探索", key="run_decision_env_bandit")
if run_bandit:
if not specs:
st.warning("请至少配置一个动作维度再执行探索。")
elif selected_agents and not range_valid:
st.error("请确保所有代理的最大权重大于最小权重。")
elif not controls_valid:
st.error("请修正部门参数的取值范围。")
else:
baseline_weights = cfg.agent_weights.as_dict()
for agent in agent_objects:
baseline_weights.setdefault(agent.name, 1.0)
universe_env = [code.strip() for code in universe_text.split(',') if code.strip()]
if not universe_env:
st.error("请先指定至少一个股票代码。")
else:
bt_cfg_env = BtConfig(
id="decision_env_bandit",
name="DecisionEnv Bandit",
start_date=start_date,
end_date=end_date,
universe=universe_env,
params={
"target": target,
"stop": stop,
"hold_days": int(hold_days),
},
method=cfg.decision_method,
)
env = DecisionEnv(
bt_config=bt_cfg_env,
parameter_specs=specs,
baseline_weights=baseline_weights,
disable_departments=disable_departments,
)
config = BanditConfig(
experiment_id=experiment_id or f"bandit_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
strategy=strategy_label or "DecisionEnv",
episodes=bandit_episodes,
epsilon=bandit_epsilon,
seed=bandit_seed,
)
bandit = EpsilonGreedyBandit(env, config)
with st.spinner("自动探索进行中,请稍候..."):
summary = bandit.run()
episodes_dump: List[Dict[str, object]] = []
for idx, episode in enumerate(summary.episodes, start=1):
episodes_dump.append(
{
"序号": idx,
"奖励": episode.reward,
"动作(raw)": json.dumps(episode.action, ensure_ascii=False),
"参数值": json.dumps(episode.resolved_action, ensure_ascii=False),
"总收益": episode.metrics.total_return,
"最大回撤": episode.metrics.max_drawdown,
"波动率": episode.metrics.volatility,
"权重": json.dumps(episode.weights or {}, ensure_ascii=False),
"部门控制": json.dumps(episode.department_controls or {}, ensure_ascii=False),
}
)
best_episode = summary.best_episode
best_index = summary.episodes.index(best_episode) + 1 if best_episode else None
st.session_state[_DECISION_ENV_BANDIT_RESULTS_KEY] = {
"episodes": episodes_dump,
"best_index": best_index,
"best": {
"reward": best_episode.reward if best_episode else None,
"action": best_episode.action if best_episode else None,
"resolved_action": best_episode.resolved_action if best_episode else None,
"weights": best_episode.weights if best_episode else None,
"department_controls": best_episode.department_controls if best_episode else None,
},
"experiment_id": config.experiment_id,
"strategy": config.strategy,
}
st.success(f"自动探索完成,共执行 {len(episodes_dump)} 轮。")
bandit_state = st.session_state.get(_DECISION_ENV_BANDIT_RESULTS_KEY)
if bandit_state:
st.caption(
f"实验 ID:{bandit_state.get('experiment_id')} | 策略:{bandit_state.get('strategy')}"
)
episodes_dump = bandit_state.get("episodes") or []
if episodes_dump:
st.dataframe(pd.DataFrame(episodes_dump), hide_index=True, width='stretch')
best_payload = bandit_state.get("best") or {}
if best_payload.get("reward") is not None:
st.success(
f"最佳结果:第 {bandit_state.get('best_index')} 轮,奖励 {best_payload['reward']:+.4f}"
)
col_best1, col_best2 = st.columns(2)
col_best1.write("动作(raw)")
col_best1.json(best_payload.get("action") or {})
col_best2.write("参数值:")
col_best2.json(best_payload.get("resolved_action") or {})
weights_payload = best_payload.get("weights") or {}
if weights_payload:
st.write("对应代理权重:")
st.json(weights_payload)
if st.button(
"将最佳权重写入默认配置",
key="save_decision_env_bandit_weights",
):
try:
cfg.agent_weights.update_from_dict(weights_payload)
save_config(cfg)
except Exception as exc: # noqa: BLE001
LOGGER.exception(
"保存 bandit 权重失败",
extra={**LOG_EXTRA, "error": str(exc)},
)
st.error(f"写入配置失败:{exc}")
else:
st.success("最佳权重已写入 config.json")
dept_ctrl = best_payload.get("department_controls") or {}
if dept_ctrl:
with st.expander("最佳部门控制参数", expanded=False):
st.json(dept_ctrl)
if st.button("清除自动探索结果", key="clear_decision_env_bandit"):
st.session_state.pop(_DECISION_ENV_BANDIT_RESULTS_KEY, None)
st.success("已清除自动探索结果。")
st.divider()
st.caption("批量调参:在下方输入多组动作,每行表示一组 0-1 之间的值,用逗号分隔。")
default_grid = "\n".join(

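The Streamlit block above is UI plumbing around the optimizer; the same exploration can be driven from a script. A minimal sketch under the constructor signatures visible in this diff — the dates, universe, method string and ParameterSpec arguments below are illustrative assumptions, not values from the repository:

# Sketch only: wiring DecisionEnv + EpsilonGreedyBandit outside the UI.
# Field values are illustrative; only the keyword names mirror this diff.
from datetime import date

from app.backtest.decision_env import DecisionEnv, ParameterSpec
from app.backtest.engine import BtConfig
from app.backtest.optimizer import BanditConfig, EpsilonGreedyBandit

bt_cfg = BtConfig(
    id="decision_env_bandit",
    name="DecisionEnv Bandit",
    start_date=date(2024, 1, 1),   # assumed date type
    end_date=date(2024, 6, 30),
    universe=["000001.SZ"],
    params={"target": 0.05, "stop": 0.03, "hold_days": 5},
    method="weighted",             # assumed; the UI passes cfg.decision_method
)
env = DecisionEnv(
    bt_config=bt_cfg,
    parameter_specs=[ParameterSpec(name="A_mom", minimum=0.5, maximum=2.0)],  # assumed spec args
    baseline_weights={"A_mom": 1.0},
    disable_departments=False,
)
config = BanditConfig(
    experiment_id="bandit_manual_run",
    strategy="DecisionEnv",
    episodes=10,
    epsilon=0.2,
    seed=42,
)
summary = EpsilonGreedyBandit(env, config).run()
best = summary.best_episode
if best is not None:
    print(best.reward, best.resolved_action, best.weights, best.department_controls)

In the page itself, the summary is flattened into a per-episode table, cached under _DECISION_ENV_BANDIT_RESULTS_KEY, and the best episode's weights can be written back into config.json via cfg.agent_weights.update_from_dict.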
View File

@@ -60,6 +60,7 @@ class DummyEnv:
"weights": {"A_mom": value},
"risk_breakdown": metrics.risk_breakdown,
"risk_events": [],
"department_controls": {"momentum": {"prompt": "baseline"}},
}
return obs, reward, True, info
@@ -92,3 +93,8 @@ def test_bandit_optimizer_runs_and_logs(patch_logging):
payload = patch_logging[0]["metrics"]
assert isinstance(payload, dict)
assert "risk_breakdown" in payload
assert "department_controls" in payload
first_episode = summary.episodes[0]
assert first_episode.resolved_action
assert first_episode.department_controls == {"momentum": {"prompt": "baseline"}}
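The patch_logging fixture itself is outside this hunk; the assertions imply it captures the keyword arguments of each log_tuning_result call as a list of dicts. A capture fixture in that spirit might look like the sketch below — the monkeypatch target path is an assumption, since the import of log_tuning_result is not shown in this diff:

# Sketch only: a hypothetical capture fixture in the spirit of patch_logging.
# Assumes log_tuning_result is referenced via the app.backtest.optimizer namespace.
from typing import Any, Dict, List

import pytest

@pytest.fixture
def patch_logging_sketch(monkeypatch) -> List[Dict[str, Any]]:
    calls: List[Dict[str, Any]] = []

    def _capture(**kwargs: Any) -> None:
        # Record every keyword argument so tests can inspect metrics payloads.
        calls.append(kwargs)

    monkeypatch.setattr("app.backtest.optimizer.log_tuning_result", _capture)
    return calls

With a fixture of this shape, patch_logging[0]["metrics"] exposes the merged payload that the test above checks for risk_breakdown and department_controls.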