update

parent d0a0340db6
commit fa46be501b
@@ -3,7 +3,7 @@ from __future__ import annotations
 
 import random
 from dataclasses import dataclass, field
-from typing import Dict, Iterable, List, Sequence, Tuple
+from typing import Any, Dict, Iterable, List, Mapping, Sequence, Tuple
 
 from app.backtest.decision_env import DecisionEnv, EpisodeMetrics
 from app.backtest.decision_env import ParameterSpec
@@ -28,9 +28,12 @@ class BanditConfig:
 @dataclass
 class BanditEpisode:
     action: Dict[str, float]
+    resolved_action: Dict[str, Any]
     reward: float
     metrics: EpisodeMetrics
     observation: Dict[str, float]
+    weights: Mapping[str, float] | None = None
+    department_controls: Mapping[str, Mapping[str, Any]] | None = None
 
 
 @dataclass
@@ -78,8 +81,13 @@ class EpsilonGreedyBandit:
         self._counts[key] = count
         self._value_estimates[key] = old_estimate + (reward - old_estimate) / count
 
-        action_payload = self._action_to_mapping(action)
+        action_payload = self._raw_action_mapping(action)
+        resolved_action = self._resolved_action_mapping(action)
         metrics_payload = _metrics_to_dict(metrics)
+        department_controls = info.get("department_controls")
+        if department_controls:
+            metrics_payload["department_controls"] = department_controls
+        metrics_payload["resolved_action"] = resolved_action
         try:
             log_tuning_result(
                 experiment_id=self.config.experiment_id,
@@ -94,9 +102,12 @@ class EpsilonGreedyBandit:
 
         episode_record = BanditEpisode(
             action=action_payload,
+            resolved_action=resolved_action,
             reward=reward,
             metrics=metrics,
             observation=obs,
+            weights=info.get("weights"),
+            department_controls=department_controls,
         )
         self._history.episodes.append(episode_record)
         LOGGER.info(
@@ -112,17 +123,28 @@ class EpsilonGreedyBandit:
         if self._value_estimates and self._random.random() > self.config.epsilon:
             best = max(self._value_estimates.items(), key=lambda item: item[1])[0]
             return list(best)
-        return [
-            self._random.uniform(spec.minimum, spec.maximum)
-            for spec in self._specs
-        ]
+        return [self._sample_value(spec) for spec in self._specs]
 
-    def _action_to_mapping(self, action: Sequence[float]) -> Dict[str, float]:
+    def _raw_action_mapping(self, action: Sequence[float]) -> Dict[str, float]:
         return {
             spec.name: float(value)
             for spec, value in zip(self._specs, action, strict=True)
         }
 
+    def _resolved_action_mapping(self, action: Sequence[float]) -> Dict[str, Any]:
+        return {
+            spec.name: spec.resolve(value)
+            for spec, value in zip(self._specs, action, strict=True)
+        }
+
+    def _sample_value(self, spec: ParameterSpec) -> float:
+        if spec.values:
+            if len(spec.values) <= 1:
+                return 0.0
+            index = self._random.randrange(len(spec.values))
+            return index / (len(spec.values) - 1)
+        return self._random.random()
+
 
 def _metrics_to_dict(metrics: EpisodeMetrics) -> Dict[str, float | Dict[str, int]]:
     payload: Dict[str, float | Dict[str, int]] = {
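
The last hunk above splits each action dimension into a raw 0-1 position (what the bandit stores and logs as `action`) and a concrete value obtained through `spec.resolve` (`resolved_action`). A minimal standalone sketch of that mechanic, assuming a single discrete parameter, is below. The helper names (`sample_position`, `resolve_position`, `update_estimate`), the toy reward, and the rounding inverse in `resolve_position` are illustrative assumptions rather than the repo's API; the real optimizer also keys its value estimates by the full action tuple, not a single scalar.

```python
import random
from typing import Dict, Sequence


def sample_position(options: Sequence[str], rng: random.Random) -> float:
    """Mirror _sample_value for a discrete spec: pick an index, normalise it into [0, 1]."""
    if len(options) <= 1:
        return 0.0
    index = rng.randrange(len(options))
    return index / (len(options) - 1)


def resolve_position(options: Sequence[str], position: float) -> str:
    """Assumed inverse of the normalisation (what ParameterSpec.resolve presumably does
    for discrete specs): map a 0-1 position back to the nearest concrete option."""
    index = round(position * (len(options) - 1))
    return options[index]


def update_estimate(old_estimate: float, reward: float, count: int) -> float:
    """Incremental mean, matching the optimizer's value-estimate update line."""
    return old_estimate + (reward - old_estimate) / count


# Toy epsilon-greedy loop with a synthetic reward, just to show the loop shape.
rng = random.Random(0)
options = ["conservative", "baseline", "aggressive"]
estimates: Dict[float, float] = {}
counts: Dict[float, int] = {}
epsilon = 0.2
for _ in range(10):
    if estimates and rng.random() > epsilon:
        position = max(estimates, key=estimates.get)  # exploit the best-known action
    else:
        position = sample_position(options, rng)  # explore a random action
    reward = 1.0 - abs(position - 0.5)  # synthetic reward that peaks at "baseline"
    counts[position] = counts.get(position, 0) + 1
    estimates[position] = update_estimate(estimates.get(position, 0.0), reward, counts[position])
print(resolve_position(options, max(estimates, key=estimates.get)))
```
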
@@ -18,6 +18,7 @@ from app.agents.base import AgentContext
 from app.agents.game import Decision
 from app.agents.registry import default_agents
 from app.backtest.decision_env import DecisionEnv, ParameterSpec
+from app.backtest.optimizer import BanditConfig, EpsilonGreedyBandit
 from app.backtest.engine import BacktestEngine, PortfolioState, BtConfig, run_backtest
 from app.ingest.checker import run_boot_check
 from app.ingest.tushare import run_ingestion
@@ -35,6 +36,7 @@ from app.ui.views.dashboard import update_dashboard_sidebar
 
 _DECISION_ENV_SINGLE_RESULT_KEY = "decision_env_single_result"
 _DECISION_ENV_BATCH_RESULTS_KEY = "decision_env_batch_results"
+_DECISION_ENV_BANDIT_RESULTS_KEY = "decision_env_bandit_results"
 
 def render_backtest_review() -> None:
     """Render the backtest execution, tuning, and results review page."""
@@ -675,6 +677,170 @@ def render_backtest_review() -> None:
         st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
         st.success("Single-run tuning result cache cleared.")
 
+    st.divider()
+    st.subheader("Automatic exploration (epsilon-greedy)")
+    col_ep, col_eps, col_seed = st.columns([1, 1, 1])
+    bandit_episodes = int(
+        col_ep.number_input(
+            "Iterations",
+            min_value=1,
+            max_value=200,
+            value=10,
+            step=1,
+            key="decision_env_bandit_episodes",
+            help="Number of exploration episodes; more is more thorough but slower.",
+        )
+    )
+    bandit_epsilon = float(
+        col_eps.slider(
+            "Exploration ratio ε",
+            min_value=0.0,
+            max_value=1.0,
+            value=0.2,
+            step=0.05,
+            key="decision_env_bandit_epsilon",
+            help="The larger ε is, the more often a random action is explored.",
+        )
+    )
+    seed_text = col_seed.text_input(
+        "Random seed (optional)",
+        value="",
+        key="decision_env_bandit_seed",
+        help="Enter an integer to make runs reproducible; leave blank for random.",
+    ).strip()
+    bandit_seed = None
+    if seed_text:
+        try:
+            bandit_seed = int(seed_text)
+        except ValueError:
+            st.warning("The random seed must be an integer; the value was ignored.")
+            bandit_seed = None
+
+    run_bandit = st.button("Run automatic exploration", key="run_decision_env_bandit")
+    if run_bandit:
+        if not specs:
+            st.warning("Configure at least one action dimension before running exploration.")
+        elif selected_agents and not range_valid:
+            st.error("Make sure every agent's maximum weight is greater than its minimum weight.")
+        elif not controls_valid:
+            st.error("Fix the value ranges of the department parameters.")
+        else:
+            baseline_weights = cfg.agent_weights.as_dict()
+            for agent in agent_objects:
+                baseline_weights.setdefault(agent.name, 1.0)
+
+            universe_env = [code.strip() for code in universe_text.split(',') if code.strip()]
+            if not universe_env:
+                st.error("Specify at least one stock code first.")
+            else:
+                bt_cfg_env = BtConfig(
+                    id="decision_env_bandit",
+                    name="DecisionEnv Bandit",
+                    start_date=start_date,
+                    end_date=end_date,
+                    universe=universe_env,
+                    params={
+                        "target": target,
+                        "stop": stop,
+                        "hold_days": int(hold_days),
+                    },
+                    method=cfg.decision_method,
+                )
+                env = DecisionEnv(
+                    bt_config=bt_cfg_env,
+                    parameter_specs=specs,
+                    baseline_weights=baseline_weights,
+                    disable_departments=disable_departments,
+                )
+                config = BanditConfig(
+                    experiment_id=experiment_id or f"bandit_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
+                    strategy=strategy_label or "DecisionEnv",
+                    episodes=bandit_episodes,
+                    epsilon=bandit_epsilon,
+                    seed=bandit_seed,
+                )
+                bandit = EpsilonGreedyBandit(env, config)
+                with st.spinner("Automatic exploration in progress, please wait..."):
+                    summary = bandit.run()
+
+                episodes_dump: List[Dict[str, object]] = []
+                for idx, episode in enumerate(summary.episodes, start=1):
+                    episodes_dump.append(
+                        {
+                            "No.": idx,
+                            "Reward": episode.reward,
+                            "Action (raw)": json.dumps(episode.action, ensure_ascii=False),
+                            "Parameter values": json.dumps(episode.resolved_action, ensure_ascii=False),
+                            "Total return": episode.metrics.total_return,
+                            "Max drawdown": episode.metrics.max_drawdown,
+                            "Volatility": episode.metrics.volatility,
+                            "Weights": json.dumps(episode.weights or {}, ensure_ascii=False),
+                            "Department controls": json.dumps(episode.department_controls or {}, ensure_ascii=False),
+                        }
+                    )
+
+                best_episode = summary.best_episode
+                best_index = summary.episodes.index(best_episode) + 1 if best_episode else None
+                st.session_state[_DECISION_ENV_BANDIT_RESULTS_KEY] = {
+                    "episodes": episodes_dump,
+                    "best_index": best_index,
+                    "best": {
+                        "reward": best_episode.reward if best_episode else None,
+                        "action": best_episode.action if best_episode else None,
+                        "resolved_action": best_episode.resolved_action if best_episode else None,
+                        "weights": best_episode.weights if best_episode else None,
+                        "department_controls": best_episode.department_controls if best_episode else None,
+                    },
+                    "experiment_id": config.experiment_id,
+                    "strategy": config.strategy,
+                }
+                st.success(f"Automatic exploration finished; {len(episodes_dump)} episodes were run.")
+
+    bandit_state = st.session_state.get(_DECISION_ENV_BANDIT_RESULTS_KEY)
+    if bandit_state:
+        st.caption(
+            f"Experiment ID: {bandit_state.get('experiment_id')} | Strategy: {bandit_state.get('strategy')}"
+        )
+        episodes_dump = bandit_state.get("episodes") or []
+        if episodes_dump:
+            st.dataframe(pd.DataFrame(episodes_dump), hide_index=True, width='stretch')
+        best_payload = bandit_state.get("best") or {}
+        if best_payload.get("reward") is not None:
+            st.success(
+                f"Best result: episode {bandit_state.get('best_index')}, reward {best_payload['reward']:+.4f}"
+            )
+            col_best1, col_best2 = st.columns(2)
+            col_best1.write("Action (raw):")
+            col_best1.json(best_payload.get("action") or {})
+            col_best2.write("Parameter values:")
+            col_best2.json(best_payload.get("resolved_action") or {})
+            weights_payload = best_payload.get("weights") or {}
+            if weights_payload:
+                st.write("Corresponding agent weights:")
+                st.json(weights_payload)
+                if st.button(
+                    "Write best weights to the default config",
+                    key="save_decision_env_bandit_weights",
+                ):
+                    try:
+                        cfg.agent_weights.update_from_dict(weights_payload)
+                        save_config(cfg)
+                    except Exception as exc:  # noqa: BLE001
+                        LOGGER.exception(
+                            "Failed to save bandit weights",
+                            extra={**LOG_EXTRA, "error": str(exc)},
+                        )
+                        st.error(f"Failed to write config: {exc}")
+                    else:
+                        st.success("Best weights written to config.json")
+            dept_ctrl = best_payload.get("department_controls") or {}
+            if dept_ctrl:
+                with st.expander("Best department control parameters", expanded=False):
+                    st.json(dept_ctrl)
+        if st.button("Clear automatic exploration results", key="clear_decision_env_bandit"):
+            st.session_state.pop(_DECISION_ENV_BANDIT_RESULTS_KEY, None)
+            st.success("Automatic exploration results cleared.")
+
 st.divider()
 st.caption("Batch tuning: enter one action set per line below, as comma-separated values between 0 and 1.")
 default_grid = "\n".join(
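
For reference when reading the rendering branch above: the exploration outcome is cached in `st.session_state` under `_DECISION_ENV_BANDIT_RESULTS_KEY` as a plain dict and re-rendered on later reruns. The sketch below writes that shape out as `TypedDict`s purely for documentation; the class names are hypothetical and not part of the commit, but every field mirrors the payload built in the hunk above.

```python
from typing import Any, Dict, List, Mapping, Optional, TypedDict


class BestPayload(TypedDict):
    """Shape of the "best" entry, taken from summary.best_episode (all None when no episode ran)."""

    reward: Optional[float]
    action: Optional[Dict[str, float]]  # raw 0-1 action vector, keyed by spec name
    resolved_action: Optional[Dict[str, Any]]  # concrete parameter values from spec.resolve
    weights: Optional[Mapping[str, float]]  # agent weights reported for the episode, if any
    department_controls: Optional[Mapping[str, Mapping[str, Any]]]


class BanditResultsState(TypedDict):
    """Payload cached under _DECISION_ENV_BANDIT_RESULTS_KEY and rendered on reruns."""

    episodes: List[Dict[str, object]]  # one display row per episode (the episodes_dump table)
    best_index: Optional[int]  # 1-based index of the best episode within the table
    best: BestPayload
    experiment_id: str
    strategy: str
```
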
@@ -60,6 +60,7 @@ class DummyEnv:
             "weights": {"A_mom": value},
             "risk_breakdown": metrics.risk_breakdown,
             "risk_events": [],
+            "department_controls": {"momentum": {"prompt": "baseline"}},
         }
         return obs, reward, True, info
 
@@ -92,3 +93,8 @@ def test_bandit_optimizer_runs_and_logs(patch_logging):
     payload = patch_logging[0]["metrics"]
     assert isinstance(payload, dict)
     assert "risk_breakdown" in payload
+    assert "department_controls" in payload
+
+    first_episode = summary.episodes[0]
+    assert first_episode.resolved_action
+    assert first_episode.department_controls == {"momentum": {"prompt": "baseline"}}