diff --git a/app/backtest/decision_env.py b/app/backtest/decision_env.py new file mode 100644 index 0000000..6577e3a --- /dev/null +++ b/app/backtest/decision_env.py @@ -0,0 +1,180 @@ +"""Reinforcement-learning style environment wrapping the backtest engine.""" +from __future__ import annotations + +from dataclasses import dataclass, replace +from typing import Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple + +import math + +from .engine import BacktestEngine, BacktestResult, BtConfig +from app.agents.game import Decision +from app.agents.registry import weight_map +from app.utils.logging import get_logger + +LOGGER = get_logger(__name__) +LOG_EXTRA = {"stage": "decision_env"} + + +@dataclass(frozen=True) +class ParameterSpec: + """Defines how a scalar action dimension maps to strategy parameters.""" + + name: str + target: str + minimum: float = 0.0 + maximum: float = 1.0 + + def clamp(self, value: float) -> float: + clipped = max(0.0, min(1.0, float(value))) + return self.minimum + clipped * (self.maximum - self.minimum) + + +@dataclass +class EpisodeMetrics: + total_return: float + max_drawdown: float + volatility: float + nav_series: List[Dict[str, float]] + trades: List[Dict[str, object]] + + @property + def sharpe_like(self) -> float: + if self.volatility <= 1e-9: + return 0.0 + return self.total_return / self.volatility + + +class DecisionEnv: + """Thin RL-friendly wrapper that evaluates parameter actions via backtest.""" + + def __init__( + self, + *, + bt_config: BtConfig, + parameter_specs: Sequence[ParameterSpec], + baseline_weights: Mapping[str, float], + reward_fn: Optional[Callable[[EpisodeMetrics], float]] = None, + disable_departments: bool = False, + ) -> None: + self._template_cfg = bt_config + self._specs = list(parameter_specs) + self._baseline_weights = dict(baseline_weights) + self._reward_fn = reward_fn or self._default_reward + self._last_metrics: Optional[EpisodeMetrics] = None + self._last_action: Optional[Tuple[float, ...]] = None + self._episode = 0 + self._disable_departments = bool(disable_departments) + + @property + def action_dim(self) -> int: + return len(self._specs) + + def reset(self) -> Dict[str, float]: + self._episode += 1 + self._last_metrics = None + self._last_action = None + return { + "episode": float(self._episode), + "baseline_return": 0.0, + } + + def step(self, action: Sequence[float]) -> Tuple[Dict[str, float], float, bool, Dict[str, object]]: + if len(action) != self.action_dim: + raise ValueError(f"expected action length {self.action_dim}, got {len(action)}") + action_array = [float(val) for val in action] + self._last_action = tuple(action_array) + + weights = self._build_weights(action_array) + LOGGER.info("episode=%s action=%s weights=%s", self._episode, action_array, weights, extra=LOG_EXTRA) + + cfg = replace(self._template_cfg) + engine = BacktestEngine(cfg) + engine.weights = weight_map(weights) + if self._disable_departments: + engine.department_manager = None + + try: + result = engine.run() + except Exception as exc: # noqa: BLE001 + LOGGER.exception("backtest failed under action", extra={**LOG_EXTRA, "error": str(exc)}) + info = {"error": str(exc)} + return {"failure": 1.0}, -1.0, True, info + + metrics = self._compute_metrics(result) + reward = float(self._reward_fn(metrics)) + self._last_metrics = metrics + + observation = { + "total_return": metrics.total_return, + "max_drawdown": metrics.max_drawdown, + "volatility": metrics.volatility, + "sharpe_like": metrics.sharpe_like, + } + info = { + "nav_series": 
metrics.nav_series,
+            "trades": metrics.trades,
+            "weights": weights,
+        }
+        return observation, reward, True, info
+
+    def _build_weights(self, action: Sequence[float]) -> Dict[str, float]:
+        weights = dict(self._baseline_weights)
+        for idx, spec in enumerate(self._specs):
+            value = spec.clamp(action[idx])
+            if spec.target.startswith("agent_weights."):
+                agent_name = spec.target.split(".", 1)[1]
+                weights[agent_name] = value
+            else:
+                LOGGER.debug("暂未支持的参数目标:%s", spec.target, extra=LOG_EXTRA)
+        return weights
+
+    def _compute_metrics(self, result: BacktestResult) -> EpisodeMetrics:
+        nav_series = result.nav_series or []
+        if not nav_series:
+            return EpisodeMetrics(0.0, 0.0, 0.0, [], result.trades)
+
+        nav_values = [row.get("nav", 0.0) for row in nav_series]
+        # Guard against a zero/missing first NAV so the return calculation below cannot divide by zero.
+        base_nav = nav_values[0] if nav_values[0] else 1.0
+
+        returns = [(nav / base_nav) - 1.0 for nav in nav_values]
+        total_return = returns[-1]
+
+        peak = nav_values[0]
+        max_drawdown = 0.0
+        for nav in nav_values:
+            if nav > peak:
+                peak = nav
+            drawdown = (peak - nav) / peak if peak else 0.0
+            max_drawdown = max(max_drawdown, drawdown)
+
+        diffs = [nav_values[idx] - nav_values[idx - 1] for idx in range(1, len(nav_values))]
+        if diffs:
+            mean_diff = sum(diffs) / len(diffs)
+            variance = sum((diff - mean_diff) ** 2 for diff in diffs) / len(diffs)
+            volatility = math.sqrt(variance) / base_nav
+        else:
+            volatility = 0.0
+
+        return EpisodeMetrics(
+            total_return=float(total_return),
+            max_drawdown=float(max_drawdown),
+            volatility=volatility,
+            nav_series=nav_series,
+            trades=result.trades,
+        )
+
+    @staticmethod
+    def _default_reward(metrics: EpisodeMetrics) -> float:
+        penalty = 0.5 * metrics.max_drawdown
+        return metrics.total_return - penalty
+
+    @property
+    def last_metrics(self) -> Optional[EpisodeMetrics]:
+        return self._last_metrics
+
+    @property
+    def last_action(self) -> Optional[Tuple[float, ...]]:
+        return self._last_action
diff --git a/app/data/schema.py b/app/data/schema.py
index ecccf15..56ffcba 100644
--- a/app/data/schema.py
+++ b/app/data/schema.py
@@ -423,6 +423,18 @@ SCHEMA_STATEMENTS: Iterable[str] = (
         notes TEXT,
         metadata TEXT
     );
+    """,
+    """
+    CREATE TABLE IF NOT EXISTS tuning_results (
+        id INTEGER PRIMARY KEY AUTOINCREMENT,
+        experiment_id TEXT,
+        strategy TEXT,
+        action TEXT,
+        weights TEXT,
+        reward REAL,
+        metrics TEXT,
+        created_at TEXT DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
+    );
     """
 )
@@ -456,6 +468,7 @@ REQUIRED_TABLES = (
     "portfolio_positions",
     "portfolio_trades",
     "portfolio_snapshots",
+    "tuning_results",
 )
diff --git a/app/ui/streamlit_app.py b/app/ui/streamlit_app.py
index 0641b15..8ada8ad 100644
--- a/app/ui/streamlit_app.py
+++ b/app/ui/streamlit_app.py
@@ -13,6 +13,8 @@ if str(ROOT) not in sys.path:
     sys.path.insert(0, str(ROOT))

 import json
+from datetime import datetime
+import uuid

 import pandas as pd
 import plotly.express as px
@@ -24,6 +26,7 @@ import streamlit as st
 from app.agents.base import AgentContext
 from app.agents.game import Decision
 from app.backtest.engine import BtConfig, run_backtest
+from app.backtest.decision_env import DecisionEnv, ParameterSpec
 from app.data.schema import initialize_database
 from app.ingest.checker import run_boot_check
 from app.ingest.tushare import FetchJob, run_ingestion
@@ -53,6 +56,8 @@ from app.utils.portfolio import (
     list_positions,
     list_recent_trades,
 )
+from app.agents.registry import default_agents
+from app.utils.tuning import log_tuning_result

 LOGGER = 
get_logger(__name__) @@ -623,6 +628,7 @@ def render_backtest() -> None: st.header("回测与复盘") st.write("在此运行回测、展示净值曲线与代理贡献。") + cfg = get_config() default_start, default_end = _default_backtest_range(window_days=60) LOGGER.debug( "回测默认参数:start=%s end=%s universe=%s target=%s stop=%s hold_days=%s", @@ -746,6 +752,347 @@ def render_backtest() -> None: status_box.update(label="回测执行失败", state="error") st.error(f"回测执行失败:{exc}") + with st.expander("离线调参实验 (DecisionEnv)", expanded=False): + st.caption( + "使用 DecisionEnv 对代理权重做离线调参。请选择需要优化的代理并设定权重范围," + "系统会运行一次回测并返回收益、回撤等指标。若 LLM 网络不可用,将返回失败标记。" + ) + + disable_departments = st.checkbox( + "禁用部门 LLM(仅规则代理,适合离线快速评估)", + value=True, + help="关闭部门调用后不依赖外部 LLM 网络,仅根据规则代理权重模拟。", + ) + + default_experiment_id = f"streamlit_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + experiment_id = st.text_input( + "实验 ID", + value=default_experiment_id, + help="用于在 tuning_results 表中区分不同实验。", + ) + strategy_label = st.text_input( + "策略说明", + value="DecisionEnv", + help="可选:为本次调参记录一个策略名称或备注。", + ) + + agent_objects = default_agents() + agent_names = [agent.name for agent in agent_objects] + if not agent_names: + st.info("暂无可调整的代理。") + else: + selected_agents = st.multiselect( + "选择调参的代理权重", + agent_names, + default=agent_names[:2], + key="decision_env_agents", + ) + + specs: List[ParameterSpec] = [] + action_values: List[float] = [] + range_valid = True + for idx, agent_name in enumerate(selected_agents): + col_min, col_max, col_action = st.columns([1, 1, 2]) + min_key = f"decision_env_min_{agent_name}" + max_key = f"decision_env_max_{agent_name}" + action_key = f"decision_env_action_{agent_name}" + default_min = 0.0 + default_max = 1.0 + min_val = col_min.number_input( + f"{agent_name} 最小权重", + min_value=0.0, + max_value=1.0, + value=default_min, + step=0.05, + key=min_key, + ) + max_val = col_max.number_input( + f"{agent_name} 最大权重", + min_value=0.0, + max_value=1.0, + value=default_max, + step=0.05, + key=max_key, + ) + if max_val <= min_val: + range_valid = False + action_val = col_action.slider( + f"{agent_name} 动作 (0-1)", + min_value=0.0, + max_value=1.0, + value=0.5, + step=0.01, + key=action_key, + ) + specs.append( + ParameterSpec( + name=f"weight_{agent_name}", + target=f"agent_weights.{agent_name}", + minimum=min_val, + maximum=max_val, + ) + ) + action_values.append(action_val) + + run_decision_env = st.button("执行单次调参", key="run_decision_env_button") + if run_decision_env: + if not selected_agents: + st.warning("请至少选择一个代理进行调参。") + elif not range_valid: + st.error("请确保所有代理的最大权重大于最小权重。") + else: + baseline_weights = cfg.agent_weights.as_dict() + for agent in agent_objects: + baseline_weights.setdefault(agent.name, 1.0) + + universe_env = [code.strip() for code in universe_text.split(',') if code.strip()] + if not universe_env: + st.error("请先指定至少一个股票代码。") + else: + bt_cfg_env = BtConfig( + id="decision_env_streamlit", + name="DecisionEnv Streamlit", + start_date=start_date, + end_date=end_date, + universe=universe_env, + params={ + "target": target, + "stop": stop, + "hold_days": int(hold_days), + }, + method=cfg.decision_method, + ) + env = DecisionEnv( + bt_config=bt_cfg_env, + parameter_specs=specs, + baseline_weights=baseline_weights, + disable_departments=disable_departments, + ) + env.reset() + with st.spinner("正在执行离线调参……"): + try: + observation, reward, done, info = env.step(action_values) + except Exception as exc: # noqa: BLE001 + LOGGER.exception("DecisionEnv 调用失败", extra=LOG_EXTRA) + st.error(f"离线调参失败:{exc}") + else: + if observation.get("failure"): + 
st.error("调参失败:回测执行未完成,可能是 LLM 网络不可用或参数异常。") + st.json(observation) + else: + st.success("离线调参完成") + col_metrics = st.columns(4) + col_metrics[0].metric("总收益", f"{observation.get('total_return', 0.0):+.2%}") + col_metrics[1].metric("最大回撤", f"{observation.get('max_drawdown', 0.0):+.2%}") + col_metrics[2].metric("波动率", f"{observation.get('volatility', 0.0):+.2%}") + col_metrics[3].metric("奖励", f"{reward:+.4f}") + + st.write("调参后权重:") + weights_dict = info.get("weights", {}) + st.json(weights_dict) + action_payload = { + name: value + for name, value in zip(selected_agents, action_values) + } + metrics_payload = dict(observation) + metrics_payload["reward"] = reward + try: + log_tuning_result( + experiment_id=experiment_id or str(uuid.uuid4()), + strategy=strategy_label or "DecisionEnv", + action=action_payload, + reward=reward, + metrics=metrics_payload, + weights=weights_dict, + ) + st.caption("调参结果已写入 tuning_results 表。") + except Exception: # noqa: BLE001 + LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA) + + if weights_dict: + if st.button( + "保存这些权重为默认配置", + key="save_decision_env_weights_single", + ): + cfg.agent_weights.update_from_dict(weights_dict) + save_config(cfg) + st.success("代理权重已写入 config.json") + + nav_series = info.get("nav_series") + if nav_series: + try: + nav_df = pd.DataFrame(nav_series) + if {"trade_date", "nav"}.issubset(nav_df.columns): + nav_df = nav_df.sort_values("trade_date") + nav_df["trade_date"] = pd.to_datetime(nav_df["trade_date"]) + st.line_chart(nav_df.set_index("trade_date")["nav"], height=220) + except Exception: # noqa: BLE001 + LOGGER.debug("导航曲线绘制失败", extra=LOG_EXTRA) + trades = info.get("trades") + if trades: + st.write("成交记录:") + st.dataframe(pd.DataFrame(trades), hide_index=True, width='stretch') + + st.divider() + st.caption("批量调参:在下方输入多组动作,每行表示一组 0-1 之间的值,用逗号分隔。") + default_grid = "\n".join( + [ + ",".join(["0.2" for _ in specs]), + ",".join(["0.5" for _ in specs]), + ",".join(["0.8" for _ in specs]), + ] + ) if specs else "" + action_grid_raw = st.text_area( + "动作列表", + value=default_grid, + height=120, + key="decision_env_batch_actions", + ) + run_batch = st.button("批量执行调参", key="run_decision_env_batch") + if run_batch: + if not selected_agents: + st.warning("请先选择调参代理。") + elif not range_valid: + st.error("请确保所有代理的最大权重大于最小权重。") + else: + lines = [line.strip() for line in action_grid_raw.splitlines() if line.strip()] + if not lines: + st.warning("请在文本框中输入至少一组动作。") + else: + parsed_actions: List[List[float]] = [] + for line in lines: + try: + values = [float(val.strip()) for val in line.split(',') if val.strip()] + except ValueError: + st.error(f"无法解析动作行:{line}") + parsed_actions = [] + break + if len(values) != len(specs): + st.error(f"动作维度不匹配(期望 {len(specs)} 个值):{line}") + parsed_actions = [] + break + parsed_actions.append(values) + if parsed_actions: + baseline_weights = cfg.agent_weights.as_dict() + for agent in agent_objects: + baseline_weights.setdefault(agent.name, 1.0) + + universe_env = [code.strip() for code in universe_text.split(',') if code.strip()] + if not universe_env: + st.error("请先指定至少一个股票代码。") + else: + bt_cfg_env = BtConfig( + id="decision_env_streamlit_batch", + name="DecisionEnv Batch", + start_date=start_date, + end_date=end_date, + universe=universe_env, + params={ + "target": target, + "stop": stop, + "hold_days": int(hold_days), + }, + method=cfg.decision_method, + ) + env = DecisionEnv( + bt_config=bt_cfg_env, + parameter_specs=specs, + baseline_weights=baseline_weights, + disable_departments=disable_departments, + ) + 
results: List[Dict[str, object]] = [] + with st.spinner("正在批量执行调参……"): + for idx, action_vals in enumerate(parsed_actions, start=1): + env.reset() + try: + observation, reward, done, info = env.step(action_vals) + except Exception as exc: # noqa: BLE001 + LOGGER.exception("批量调参失败", extra=LOG_EXTRA) + results.append( + { + "序号": idx, + "动作": action_vals, + "状态": "error", + "错误": str(exc), + } + ) + continue + if observation.get("failure"): + results.append( + { + "序号": idx, + "动作": action_vals, + "状态": "failure", + "奖励": -1.0, + } + ) + else: + action_payload = { + name: value + for name, value in zip(selected_agents, action_vals) + } + metrics_payload = dict(observation) + metrics_payload["reward"] = reward + weights_payload = info.get("weights", {}) + try: + log_tuning_result( + experiment_id=experiment_id or str(uuid.uuid4()), + strategy=strategy_label or "DecisionEnv", + action=action_payload, + reward=reward, + metrics=metrics_payload, + weights=weights_payload, + ) + except Exception: # noqa: BLE001 + LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA) + results.append( + { + "序号": idx, + "动作": action_vals, + "状态": "ok", + "总收益": observation.get("total_return", 0.0), + "最大回撤": observation.get("max_drawdown", 0.0), + "波动率": observation.get("volatility", 0.0), + "奖励": reward, + "权重": weights_payload, + } + ) + if results: + st.write("批量调参结果:") + results_df = pd.DataFrame(results) + st.dataframe(results_df, hide_index=True, width='stretch') + selectable = [ + row + for row in results + if row.get("状态") == "ok" and row.get("权重") + ] + if selectable: + option_labels = [ + f"序号 {row['序号']} | 奖励 {row.get('奖励', 0.0):+.4f}" + for row in selectable + ] + selected_label = st.selectbox( + "选择要保存的记录", + option_labels, + key="decision_env_batch_select", + ) + selected_row = None + for label, row in zip(option_labels, selectable): + if label == selected_label: + selected_row = row + break + if selected_row and st.button( + "保存所选权重为默认配置", + key="save_decision_env_weights_batch", + ): + cfg.agent_weights.update_from_dict(selected_row.get("权重", {})) + save_config(cfg) + st.success( + f"已将序号 {selected_row['序号']} 的权重写入 config.json" + ) + else: + st.caption("暂无成功的结果可供保存。") + def render_settings() -> None: LOGGER.info("渲染设置页面", extra=LOG_EXTRA) diff --git a/app/utils/config.py b/app/utils/config.py index bad481b..653cc14 100644 --- a/app/utils/config.py +++ b/app/utils/config.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, field import json import os from pathlib import Path -from typing import Dict, Iterable, List, Optional +from typing import Dict, Iterable, List, Mapping, Optional def _default_root() -> Path: @@ -48,6 +48,32 @@ class AgentWeights: "A_macro": self.macro, } + def update_from_dict(self, data: Mapping[str, float]) -> None: + mapping = { + "A_mom": "momentum", + "momentum": "momentum", + "A_val": "value", + "value": "value", + "A_news": "news", + "news": "news", + "A_liq": "liquidity", + "liquidity": "liquidity", + "A_macro": "macro", + "macro": "macro", + } + for key, attr in mapping.items(): + if key in data and data[key] is not None: + try: + setattr(self, attr, float(data[key])) + except (TypeError, ValueError): + continue + + @classmethod + def from_dict(cls, data: Mapping[str, float]) -> "AgentWeights": + inst = cls() + inst.update_from_dict(data) + return inst + DEFAULT_LLM_MODEL_OPTIONS: Dict[str, Dict[str, object]] = { "ollama": { "models": ["llama3", "phi3", "qwen2"], @@ -357,6 +383,10 @@ def _load_from_file(cfg: AppConfig) -> None: if "decision_method" in payload: 
cfg.decision_method = str(payload.get("decision_method") or cfg.decision_method) + weights_payload = payload.get("agent_weights") + if isinstance(weights_payload, dict): + cfg.agent_weights.update_from_dict(weights_payload) + legacy_profiles: Dict[str, Dict[str, object]] = {} legacy_routes: Dict[str, Dict[str, object]] = {} @@ -523,6 +553,7 @@ def save_config(cfg: AppConfig | None = None) -> None: "tushare_token": cfg.tushare_token, "force_refresh": cfg.force_refresh, "decision_method": cfg.decision_method, + "agent_weights": cfg.agent_weights.as_dict(), "llm": { "strategy": cfg.llm.strategy if cfg.llm.strategy in ALLOWED_LLM_STRATEGIES else "single", "majority_threshold": cfg.llm.majority_threshold, diff --git a/app/utils/tuning.py b/app/utils/tuning.py new file mode 100644 index 0000000..c4bd4c3 --- /dev/null +++ b/app/utils/tuning.py @@ -0,0 +1,42 @@ +"""Helpers for logging decision tuning experiments.""" +from __future__ import annotations + +import json +from typing import Any, Dict, Optional + +from .db import db_session +from .logging import get_logger + +LOGGER = get_logger(__name__) +LOG_EXTRA = {"stage": "tuning"} + + +def log_tuning_result( + *, + experiment_id: str, + strategy: str, + action: Dict[str, Any], + reward: float, + metrics: Dict[str, Any], + weights: Optional[Dict[str, float]] = None, +) -> None: + """Persist a tuning result into the SQLite table.""" + + try: + with db_session() as conn: + conn.execute( + """ + INSERT INTO tuning_results (experiment_id, strategy, action, weights, reward, metrics) + VALUES (?, ?, ?, ?, ?, ?) + """, + ( + experiment_id, + strategy, + json.dumps(action, ensure_ascii=False), + json.dumps(weights or {}, ensure_ascii=False), + float(reward), + json.dumps(metrics, ensure_ascii=False), + ), + ) + except Exception: # noqa: BLE001 + LOGGER.exception("记录调参结果失败", extra=LOG_EXTRA) diff --git a/docs/decision_optimization_notes.md b/docs/decision_optimization_notes.md index 216bef1..84b54f4 100644 --- a/docs/decision_optimization_notes.md +++ b/docs/decision_optimization_notes.md @@ -36,8 +36,11 @@ - Streamlit 侧边栏监听 `llm.metrics` 的实时事件,并以 ~0.75 秒节流频率刷新“系统监控”,既保证日志到达后快速更新,也避免刷屏造成 UI 闪烁。 - 新增投资管理数据层:SQLite 中创建 `investment_pool`、`portfolio_positions`、`portfolio_trades`、`portfolio_snapshots` 四张表;`app/utils/portfolio.py` 提供访问接口,今日计划页可实时展示候选池、持仓与成交。 - 回测引擎 `record_agent_state()` 现同步写入 `investment_pool`,将每日全局决策的置信度、部门标签与目标权重落库,作为后续提示参数调优与候选池管理的基础数据。 +- `app/backtest/decision_env.py` 引入 `DecisionEnv`,用单步 RL/Gym 风格接口封装回测:动作 → 权重映射 → 回测 → 奖励(收益 - 0.5×回撤),同时输出 NAV、交易与行动权重,方便与 Bandit/PPO 等算法对接。 +- Streamlit “回测与复盘” 页新增离线调参模块,可即点即用 DecisionEnv 对代理权重进行实验,并可视化收益、回撤、成交与权重结果,支持一键写入 `config.json` 成为新的默认权重。 +- 所有离线调参实验(单次/批量)都会存入 SQLite `tuning_results`,包含实验 ID、动作、奖励、指标与权重,便于后续分析与对比。 ## 下一阶段路线图 -- 将 `BacktestEngine` 封装为 `DecisionEnv`,让一次策略配置跑完整个回测周期并输出奖励、约束违例等指标。 -- 接入 Bandit/贝叶斯优化,对 Prompt 版本、部门权重、温度范围做离线搜索,利用新增的 snapshot/positions 数据衡量风险与收益。 -- 构建持仓/成交写入流程(回测与实时),确保 RL 训练能复原资金曲线、资金占用与调仓成本。 +- 在 `DecisionEnv` 中扩展动作映射(Prompt 版本、部门温度、function 调用策略等),把当前权重型动作升级为多参数协同调整。 +- 接入 Bandit/贝叶斯优化,对动作空间进行探索,并把 `portfolio_snapshots`、`portfolio_trades` 输出纳入奖励约束(收益、回撤、换手率)。 +- 构建持仓/成交写入流程的实时入口,使线上监控与离线调参共用同一数据源,支撑增量训练与策略回放。 diff --git a/scripts/run_decision_env_example.py b/scripts/run_decision_env_example.py new file mode 100644 index 0000000..42e2315 --- /dev/null +++ b/scripts/run_decision_env_example.py @@ -0,0 +1,52 @@ +"""Quick example of using DecisionEnv for weight tuning experiments.""" +from __future__ import annotations + +import json +from datetime import date, 
timedelta + +import sys +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + +from app.backtest.decision_env import DecisionEnv, ParameterSpec +from app.backtest.engine import BtConfig +from app.agents.registry import default_agents +from app.utils.config import get_config + + +def main() -> None: + cfg = get_config() + agents = default_agents() + baseline_weights = {agent.name: cfg.agent_weights.as_dict().get(agent.name, 1.0) for agent in agents} + + today = date.today() + bt_cfg = BtConfig( + id="decision_env_example", + name="Decision Env Demo", + start_date=today - timedelta(days=60), + end_date=today, + universe=["000001.SZ"], + params={}, + method=cfg.decision_method, + ) + + specs = [ + ParameterSpec(name="momentum_weight", target="agent_weights.A_mom", minimum=0.1, maximum=0.6), + ParameterSpec(name="value_weight", target="agent_weights.A_val", minimum=0.1, maximum=0.4), + ] + + env = DecisionEnv(bt_config=bt_cfg, parameter_specs=specs, baseline_weights=baseline_weights) + env.reset() + observation, reward, done, info = env.step([0.5, 0.2]) + + print("Observation:", json.dumps(observation, ensure_ascii=False, indent=2)) + print("Reward:", reward) + print("Done:", done) + print("Weights:", json.dumps(info.get("weights", {}), ensure_ascii=False, indent=2)) + + +if __name__ == "__main__": + main()
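
The roadmap above mentions driving the DecisionEnv action space with Bandit/Bayesian search. As a minimal sketch of how the single-step `reset()`/`step()` interface from this patch composes with `log_tuning_result`, the snippet below runs a naive uniform random search. It is illustrative only: the script location (assumed under `scripts/`), the `n_trials` value, the experiment id, and the two `ParameterSpec` ranges are assumptions, not APIs introduced by this change.

```python
"""Illustrative sketch: naive random search over DecisionEnv actions."""
from __future__ import annotations

import random
import sys
from datetime import date, timedelta
from pathlib import Path
from typing import List, Optional, Sequence, Tuple

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from app.agents.registry import default_agents
from app.backtest.decision_env import DecisionEnv, ParameterSpec
from app.backtest.engine import BtConfig
from app.utils.config import get_config
from app.utils.tuning import log_tuning_result


def random_search(
    env: DecisionEnv,
    specs: Sequence[ParameterSpec],
    *,
    n_trials: int = 20,
    experiment_id: str = "random_search_demo",
) -> Optional[Tuple[List[float], float, dict]]:
    """Sample uniform actions in [0, 1]^action_dim, log each trial, return the best (action, reward, weights)."""
    best: Optional[Tuple[List[float], float, dict]] = None
    for _ in range(n_trials):
        env.reset()
        action = [random.random() for _ in range(env.action_dim)]
        observation, reward, _done, info = env.step(action)
        if observation.get("failure"):
            # The backtest did not complete under this action; skip the sample.
            continue
        metrics = dict(observation)
        metrics["reward"] = reward
        log_tuning_result(
            experiment_id=experiment_id,
            strategy="random_search",
            action={spec.name: value for spec, value in zip(specs, action)},
            reward=reward,
            metrics=metrics,
            weights=info.get("weights", {}),
        )
        if best is None or reward > best[1]:
            best = (action, reward, info.get("weights", {}))
    return best


def main() -> None:
    cfg = get_config()
    # Baseline weights: every registered agent defaults to 1.0, overridden by config values when present.
    baseline = {agent.name: 1.0 for agent in default_agents()}
    baseline.update(cfg.agent_weights.as_dict())

    specs = [
        ParameterSpec(name="momentum_weight", target="agent_weights.A_mom", minimum=0.1, maximum=0.6),
        ParameterSpec(name="value_weight", target="agent_weights.A_val", minimum=0.1, maximum=0.4),
    ]
    bt_cfg = BtConfig(
        id="decision_env_random_search",
        name="DecisionEnv Random Search",
        start_date=date.today() - timedelta(days=60),
        end_date=date.today(),
        universe=["000001.SZ"],
        params={},
        method=cfg.decision_method,
    )
    env = DecisionEnv(
        bt_config=bt_cfg,
        parameter_specs=specs,
        baseline_weights=baseline,
        disable_departments=True,  # rule agents only; avoids the LLM network dependency
    )
    print(random_search(env, specs, n_trials=5))


if __name__ == "__main__":
    main()
```

Each `step()` call runs one full backtest, so sweeps should stay small; `disable_departments=True` keeps the evaluation on rule agents and avoids the LLM network dependency noted in the Streamlit help text.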