From d0a0340db65ceaf9a7d477a835c03b3c2cd593ec Mon Sep 17 00:00:00 2001
From: sam
Date: Mon, 6 Oct 2025 13:21:43 +0800
Subject: [PATCH] extend DecisionEnv action space to department LLM controls

Extend offline tuning beyond agent-weight scaling:

- DepartmentAgent gains validated max_rounds / tool_choice setters.
- ParameterSpec supports discrete value lists; DecisionEnv maps
  "department.*" targets onto per-episode prompt, temperature, tool-choice
  and max-rounds overrides and reports them in step() info.
- TemplateRegistry ships presets for the value/news/liquidity/macro/risk
  departments alongside the reworked base and momentum templates.
- The backtest review UI exposes the new action dimensions and logs
  department_controls with every tuning result.
---
 app/agents/departments.py    |  34 +++-
 app/backtest/decision_env.py | 175 ++++++++++++++++-
 app/llm/templates.py         | 368 +++++++++++++++++++++++++++++------
 app/ui/views/backtest.py     | 137 +++++++++++--
 docs/TODO.md                 |   2 +-
 tests/test_decision_env.py   | 115 ++++++++++-
 tests/test_llm_templates.py  |   6 +
 7 files changed, 758 insertions(+), 79 deletions(-)

diff --git a/app/agents/departments.py b/app/agents/departments.py
index 22e671b..39d7024 100644
--- a/app/agents/departments.py
+++ b/app/agents/departments.py
@@ -93,6 +93,38 @@ class DepartmentAgent:
         self._resolver = resolver
         self._broker = DataBroker()
         self._max_rounds = 3
+        self._tool_choice = "auto"
+
+    @property
+    def max_rounds(self) -> int:
+        return self._max_rounds
+
+    @max_rounds.setter
+    def max_rounds(self, value: Any) -> None:
+        try:
+            numeric = int(round(float(value)))
+        except (TypeError, ValueError):
+            raise ValueError("max_rounds must be numeric") from None
+        if numeric < 1:
+            numeric = 1
+        if numeric > 6:
+            numeric = 6
+        self._max_rounds = numeric
+
+    @property
+    def tool_choice(self) -> str:
+        return self._tool_choice
+
+    @tool_choice.setter
+    def tool_choice(self, value: Any) -> None:
+        if value is None:
+            self._tool_choice = "auto"
+            return
+        normalized = str(value).strip().lower()
+        allowed = {"auto", "none", "required"}
+        if normalized not in allowed:
+            raise ValueError(f"Unsupported tool choice: {value}")
+        self._tool_choice = normalized
 
     def _get_llm_config(self) -> LLMConfig:
         if self._resolver:
@@ -159,7 +191,7 @@
                 primary_endpoint,
                 messages,
                 tools=tools,
-                tool_choice="auto",
+                tool_choice=self._tool_choice,
             )
         except LLMError as exc:
             LOGGER.warning(
diff --git a/app/backtest/decision_env.py b/app/backtest/decision_env.py
index 06c3a21..fb44906 100644
--- a/app/backtest/decision_env.py
+++ b/app/backtest/decision_env.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 
 import json
 import math
+import copy
 from dataclasses import dataclass, replace
 from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
 
@@ -18,17 +19,27 @@ LOG_EXTRA = {"stage": "decision_env"}
 
 @dataclass(frozen=True)
 class ParameterSpec:
-    """Defines how a scalar action dimension maps to strategy parameters."""
+    """Defines how an action dimension maps to strategy parameters or behaviors."""
 
     name: str
     target: str
     minimum: float = 0.0
     maximum: float = 1.0
+    values: Optional[Sequence[Any]] = None
 
     def clamp(self, value: float) -> float:
         clipped = max(0.0, min(1.0, float(value)))
         return self.minimum + clipped * (self.maximum - self.minimum)
 
+    def resolve(self, value: float) -> Any:
+        if self.values is not None:
+            if not self.values:
+                raise ValueError(f"ParameterSpec {self.name} configured with empty values list")
+            clipped = max(0.0, min(1.0, float(value)))
+            index = int(round(clipped * (len(self.values) - 1)))
+            return self.values[index]
+        return self.clamp(value)
+
 
 @dataclass
 class EpisodeMetrics:
@@ -68,6 +79,7 @@ class DecisionEnv:
         self._reward_fn = reward_fn or self._default_reward
         self._last_metrics: Optional[EpisodeMetrics] = None
         self._last_action: Optional[Tuple[float, ...]] = None
+        self._last_department_controls: Optional[Dict[str, Dict[str, Any]]] = None
         self._episode = 0
         self._disable_departments = bool(disable_departments)
 
@@ -75,10 +87,15 @@
     def action_dim(self) -> int:
         return len(self._specs)
 
+    @property
+    def 
last_department_controls(self) -> Optional[Dict[str, Dict[str, Any]]]: + return self._last_department_controls + def reset(self) -> Dict[str, float]: self._episode += 1 self._last_metrics = None self._last_action = None + self._last_department_controls = None return { "episode": float(self._episode), "baseline_return": 0.0, @@ -90,14 +107,24 @@ class DecisionEnv: action_array = [float(val) for val in action] self._last_action = tuple(action_array) - weights = self._build_weights(action_array) - LOGGER.info("episode=%s action=%s weights=%s", self._episode, action_array, weights, extra=LOG_EXTRA) + weights, department_controls = self._prepare_actions(action_array) + LOGGER.info( + "episode=%s action=%s weights=%s controls=%s", + self._episode, + action_array, + weights, + department_controls, + extra=LOG_EXTRA, + ) cfg = replace(self._template_cfg) engine = BacktestEngine(cfg) engine.weights = weight_map(weights) if self._disable_departments: engine.department_manager = None + applied_controls: Dict[str, Dict[str, Any]] = {} + else: + applied_controls = self._apply_department_controls(engine, department_controls) self._clear_portfolio_records() @@ -135,19 +162,153 @@ class DecisionEnv: "risk_events": getattr(result, "risk_events", []), "portfolio_snapshots": snapshots, "portfolio_trades": trades_override, + "department_controls": applied_controls, } + self._last_department_controls = applied_controls return observation, reward, True, info - def _build_weights(self, action: Sequence[float]) -> Dict[str, float]: + def _prepare_actions( + self, + action: Sequence[float], + ) -> Tuple[Dict[str, float], Dict[str, Dict[str, Any]]]: weights = dict(self._baseline_weights) + department_controls: Dict[str, Dict[str, Any]] = {} for idx, spec in enumerate(self._specs): - value = spec.clamp(action[idx]) + try: + resolved = spec.resolve(action[idx]) + except ValueError as exc: + LOGGER.warning("参数 %s 解析失败:%s", spec.name, exc, extra=LOG_EXTRA) + continue if spec.target.startswith("agent_weights."): agent_name = spec.target.split(".", 1)[1] - weights[agent_name] = value + try: + weights[agent_name] = float(resolved) + except (TypeError, ValueError): + LOGGER.debug( + "spec %s produced non-numeric weight %s; skipping", + spec.name, + resolved, + extra=LOG_EXTRA, + ) + continue + if spec.target.startswith("department."): + target_path = spec.target.split(".")[1:] + if len(target_path) < 2: + LOGGER.debug("未识别的部门目标:%s", spec.target, extra=LOG_EXTRA) + continue + dept_code = target_path[0] + field = ".".join(target_path[1:]) + dept_controls = department_controls.setdefault(dept_code, {}) + dept_controls[field] = resolved + continue else: LOGGER.debug("暂未支持的参数目标:%s", spec.target, extra=LOG_EXTRA) - return weights + return weights, department_controls + + def _apply_department_controls( + self, + engine: BacktestEngine, + controls: Mapping[str, Mapping[str, Any]], + ) -> Dict[str, Dict[str, Any]]: + manager = getattr(engine, "department_manager", None) + if not manager or not getattr(manager, "agents", None): + return {} + + applied: Dict[str, Dict[str, Any]] = {} + for dept_code, payload in controls.items(): + agent = manager.agents.get(dept_code) + if not agent or not isinstance(payload, Mapping): + continue + + applied_fields: Dict[str, Any] = {} + + # Ensure mutable settings clone to avoid global side-effects + try: + original_settings = agent.settings + cloned_settings = replace(original_settings) + cloned_settings.llm = copy.deepcopy(original_settings.llm) + agent.settings = cloned_settings + except 
Exception as exc: # noqa: BLE001 + LOGGER.warning( + "复制部门 %s 配置失败:%s", + dept_code, + exc, + extra=LOG_EXTRA, + ) + continue + + for raw_field, value in payload.items(): + field = raw_field.lower() + if field == "function_policy": + field = "tool_choice" + if field in {"prompt", "instruction"}: + agent.settings.prompt = str(value) + applied_fields[field] = agent.settings.prompt + continue + if field == "description": + agent.settings.description = str(value) + applied_fields[field] = agent.settings.description + continue + if field in {"prompt_template_id", "prompt_template"}: + agent.settings.prompt_template_id = str(value) + applied_fields["prompt_template_id"] = agent.settings.prompt_template_id + continue + if field == "prompt_template_version": + agent.settings.prompt_template_version = str(value) + applied_fields["prompt_template_version"] = agent.settings.prompt_template_version + continue + if field in {"temperature", "llm.temperature"}: + try: + temperature = max(0.0, min(2.0, float(value))) + agent.settings.llm.primary.temperature = temperature + applied_fields["temperature"] = temperature + except (TypeError, ValueError): + LOGGER.debug( + "无效的温度值 %s for %s", + value, + dept_code, + extra=LOG_EXTRA, + ) + continue + if field in {"tool_choice", "tool_strategy"}: + try: + agent.tool_choice = value + applied_fields["tool_choice"] = agent.tool_choice + except ValueError: + LOGGER.debug( + "部门 %s 工具策略 %s 无效", + dept_code, + value, + extra=LOG_EXTRA, + ) + continue + if field == "max_rounds": + try: + agent.max_rounds = value + applied_fields["max_rounds"] = agent.max_rounds + except ValueError: + LOGGER.debug( + "部门 %s max_rounds %s 无效", + dept_code, + value, + extra=LOG_EXTRA, + ) + continue + if field == "prompt_template_override": + agent.settings.prompt = str(value) + applied_fields["prompt"] = agent.settings.prompt + continue + LOGGER.debug( + "部门 %s 未识别的控制项 %s", + dept_code, + raw_field, + extra=LOG_EXTRA, + ) + + if applied_fields: + applied[dept_code] = applied_fields + + return applied def _compute_metrics( self, diff --git a/app/llm/templates.py b/app/llm/templates.py index b008e97..e83d7d1 100644 --- a/app/llm/templates.py +++ b/app/llm/templates.py @@ -252,98 +252,352 @@ class TemplateRegistry: DEFAULT_TEMPLATES = { "department_base": { "name": "部门基础模板", - "description": "通用的部门分析提示模板", + "description": "所有部门通用的审慎分析提示词骨架", "template": """ -部门名称:{title} +部门:{title} 股票代码:{ts_code} 交易日:{trade_date} -角色说明:{description} -职责指令:{instruction} +【角色定位】 +- 角色说明:{description} +- 行动守则:{instruction} -【可用数据范围】 +【数据边界】 +- 可用字段: {data_scope} - -【核心特征】 +- 核心特征: {features} - -【市场背景】 +- 市场背景: {market_snapshot} - -【追加数据】 +- 追加数据: {supplements} -请基于以上数据给出该部门对当前股票的操作建议。输出必须是 JSON,字段如下: +【分析步骤】 +1. 判断信息是否充分,如不充分,请说明缺口并优先调用工具 `fetch_data`(仅限 `daily`、`daily_basic`)。 +2. 梳理 2-3 个关键支撑信号与潜在风险,确保基于提供的数据。 +3. 
结合量化证据与限制条件,给出操作建议和信心来源,避免主观臆测。 + +【输出要求】 +仅返回一个 JSON 对象,不要添加额外文本: {{ "action": "BUY|BUY_S|BUY_M|BUY_L|SELL|HOLD", - "confidence": 0-1 之间的小数,表示信心, - "summary": "一句话概括理由", - "signals": ["详细要点", "..."], - "risks": ["风险点", "..."] + "confidence": 0-1 之间的小数, + "summary": "一句话结论", + "signals": ["关键支撑要点", "..."], + "risks": ["关键风险要点", "..."] }} - -如需额外数据,请调用工具 `fetch_data`,仅支持请求 `daily` 或 `daily_basic` 表。 -请严格返回单个 JSON 对象,不要添加额外文本。 +如需说明未完成的数据请求,请在 `risks` 或 `signals` 中明确。 """, "variables": [ - "title", "ts_code", "trade_date", "description", "instruction", - "data_scope", "features", "market_snapshot", "supplements" + "title", + "ts_code", + "trade_date", + "description", + "instruction", + "data_scope", + "features", + "market_snapshot", + "supplements", ], "required_context": [ - "ts_code", "trade_date", "features", "market_snapshot" + "ts_code", + "trade_date", + "features", + "market_snapshot", ], - "validation_rules": [ - "len(features) > 0", - "len(market_snapshot) > 0" - ] + "metadata": { + "category": "department", + "preset": "base", + }, }, "momentum_dept": { - "name": "动量研究部门", - "description": "专注于动量因子分析的部门模板", + "name": "动量研究部门模板", + "description": "围绕价格与量能动量的决策提示", "template": """ -部门名称:动量研究部门 +部门:动量研究部门 股票代码:{ts_code} 交易日:{trade_date} -角色说明:专注于分析股票价格动量、成交量动量和技术指标动量 -职责指令:重点关注以下方面: -1. 价格趋势强度和持续性 -2. 成交量配合度 -3. 技术指标背离 +【角色定位】 +- 专注价格动量、成交量共振与技术指标背离。 +- 保持纪律,识别趋势延续与反转风险。 -【可用数据范围】 +【研究重点】 +1. 多时间窗口动量是否同向? +2. 成交量是否验证价格走势? +3. 是否出现过热或背离信号? + +【数据边界】 +- 可用字段: {data_scope} - -【动量特征】 +- 动量特征: {features} - -【市场背景】 +- 市场背景: {market_snapshot} - -【追加数据】 +- 追加数据: {supplements} -请基于以上数据进行动量分析并给出操作建议。输出必须是 JSON,字段如下: -{{ - "action": "BUY|BUY_S|BUY_M|BUY_L|SELL|HOLD", - "confidence": 0-1 之间的小数,表示信心, - "summary": "一句话概括动量分析结论", - "signals": ["动量信号要点", "..."], - "risks": ["动量风险点", "..."] -}} +请沿用【部门基础模板】的分析步骤与输出要求,重点量化趋势动能和量价配合度。 """, "variables": [ - "ts_code", "trade_date", "data_scope", - "features", "market_snapshot", "supplements" + "ts_code", + "trade_date", + "data_scope", + "features", + "market_snapshot", + "supplements", ], "required_context": [ - "ts_code", "trade_date", "features", "market_snapshot" + "ts_code", + "trade_date", + "features", + "market_snapshot", ], - "validation_rules": [ - "len(features) > 0", - "'momentum' in ' '.join(features.keys()).lower()" - ] - } + "metadata": { + "category": "department", + "preset": "momentum", + }, + }, + "value_dept": { + "name": "价值评估部门模板", + "description": "衡量估值与盈利质量的提示词", + "template": """ +部门:价值评估部门 +股票代码:{ts_code} +交易日:{trade_date} + +【角色定位】 +- 关注估值分位、盈利质量与安全边际。 +- 从中期配置角度评价当前价格的性价比。 + +【研究重点】 +1. 历史及同业视角的估值位置。 +2. 盈利与分红的可持续性。 +3. 潜在的估值修复催化或压制因素。 + +【数据边界】 +- 可用字段: +{data_scope} +- 估值与质量特征: +{features} +- 市场背景: +{market_snapshot} +- 追加数据: +{supplements} + +请按照【部门基础模板】的分析步骤输出结论,并明确估值安全边际来源。 +""", + "variables": [ + "ts_code", + "trade_date", + "data_scope", + "features", + "market_snapshot", + "supplements", + ], + "required_context": [ + "ts_code", + "trade_date", + "features", + "market_snapshot", + ], + "metadata": { + "category": "department", + "preset": "value", + }, + }, + "news_dept": { + "name": "新闻情绪部门模板", + "description": "针对舆情热度与事件影响的提示词", + "template": """ +部门:新闻情绪部门 +股票代码:{ts_code} +交易日:{trade_date} + +【角色定位】 +- 监控舆情热度、事件驱动与短期情绪。 +- 评估新闻对价格波动的正负面影响。 + +【研究重点】 +1. 新闻情绪是否集中且持续? +2. 主题与行情是否匹配? +3. 
情绪驱动的风险敞口。 + +【数据边界】 +- 可用字段: +{data_scope} +- 舆情特征: +{features} +- 市场背景: +{market_snapshot} +- 追加数据: +{supplements} + +请遵循【部门基础模板】的分析步骤,突出情绪驱动的力度与时效性。 +""", + "variables": [ + "ts_code", + "trade_date", + "data_scope", + "features", + "market_snapshot", + "supplements", + ], + "required_context": [ + "ts_code", + "trade_date", + "features", + "market_snapshot", + ], + "metadata": { + "category": "department", + "preset": "news", + }, + }, + "liquidity_dept": { + "name": "流动性评估部门模板", + "description": "衡量成交活跃度与执行成本的提示词", + "template": """ +部门:流动性评估部门 +股票代码:{ts_code} +交易日:{trade_date} + +【角色定位】 +- 评估成交活跃度、交易成本与可执行性。 +- 提醒潜在的流动性风险与仓位限制。 + +【研究重点】 +1. 当前成交量与历史均值的对比。 +2. 价量限制(涨跌停、停牌等)对执行的影响。 +3. 预估滑点与转手难度。 + +【数据边界】 +- 可用字段: +{data_scope} +- 流动性特征: +{features} +- 市场背景: +{market_snapshot} +- 追加数据: +{supplements} + +请遵循【部门基础模板】的分析步骤,重点描述执行可行性与仓位建议。 +""", + "variables": [ + "ts_code", + "trade_date", + "data_scope", + "features", + "market_snapshot", + "supplements", + ], + "required_context": [ + "ts_code", + "trade_date", + "features", + "market_snapshot", + ], + "metadata": { + "category": "department", + "preset": "liquidity", + }, + }, + "macro_dept": { + "name": "宏观研究部门模板", + "description": "宏观与行业景气度分析提示词", + "template": """ +部门:宏观研究部门 +股票代码:{ts_code} +交易日:{trade_date} + +【角色定位】 +- 追踪宏观周期、行业景气与相对强弱。 +- 评估宏观事件对该标的的方向性影响。 + +【研究重点】 +1. 行业相对大盘的表现与热点程度。 +2. 宏观/政策事件对行业或标的的指引。 +3. 需警惕的宏观风险与流动性环境。 + +【数据边界】 +- 可用字段: +{data_scope} +- 宏观特征: +{features} +- 市场背景: +{market_snapshot} +- 追加数据: +{supplements} + +请执行【部门基础模板】的分析步骤,并输出宏观驱动的信号与风险。 +""", + "variables": [ + "ts_code", + "trade_date", + "data_scope", + "features", + "market_snapshot", + "supplements", + ], + "required_context": [ + "ts_code", + "trade_date", + "features", + "market_snapshot", + ], + "metadata": { + "category": "department", + "preset": "macro", + }, + }, + "risk_dept": { + "name": "风险控制部门模板", + "description": "识别极端风险与限制条件的提示词", + "template": """ +部门:风险控制部门 +股票代码:{ts_code} +交易日:{trade_date} + +【角色定位】 +- 防范停牌、涨跌停、仓位与合规限制。 +- 必要时对高风险决策行使否决权。 + +【研究重点】 +1. 交易限制或异常波动情况。 +2. 仓位、集中度或风险指标是否触顶。 +3. 
潜在的黑天鹅或执行障碍。 + +【数据边界】 +- 可用字段: +{data_scope} +- 风险特征: +{features} +- 市场背景: +{market_snapshot} +- 追加数据: +{supplements} + +请按照【部门基础模板】的分析步骤,必要时明确阻止交易的理由。 +""", + "variables": [ + "ts_code", + "trade_date", + "data_scope", + "features", + "market_snapshot", + "supplements", + ], + "required_context": [ + "ts_code", + "trade_date", + "features", + "market_snapshot", + ], + "metadata": { + "category": "department", + "preset": "risk", + }, + }, } diff --git a/app/ui/views/backtest.py b/app/ui/views/backtest.py index 7841c6d..2cf9359 100644 --- a/app/ui/views/backtest.py +++ b/app/ui/views/backtest.py @@ -331,6 +331,7 @@ def render_backtest_review() -> None: ) specs: List[ParameterSpec] = [] + spec_labels: List[str] = [] action_values: List[float] = [] range_valid = True for idx, agent_name in enumerate(selected_agents): @@ -374,15 +375,111 @@ def render_backtest_review() -> None: maximum=max_val, ) ) + spec_labels.append(f"agent:{agent_name}") action_values.append(action_val) + controls_valid = True + with st.expander("部门 LLM 参数", expanded=False): + dept_codes = sorted(app_cfg.departments.keys()) + if not dept_codes: + st.caption("当前未配置部门。") + else: + selected_departments = st.multiselect( + "选择需要调整的部门", + dept_codes, + default=[], + key="decision_env_departments", + ) + tool_policy_values = ["auto", "none", "required"] + for dept_code in selected_departments: + settings = app_cfg.departments.get(dept_code) + if not settings: + continue + st.subheader(f"部门:{settings.title or dept_code}") + base_temp = 0.2 + if settings.llm and settings.llm.primary and settings.llm.primary.temperature is not None: + base_temp = float(settings.llm.primary.temperature) + prefix = f"decision_env_dept_{dept_code}" + col_tmin, col_tmax, col_tslider = st.columns([1, 1, 2]) + temp_min = col_tmin.number_input( + "温度最小值", + min_value=0.0, + max_value=2.0, + value=max(0.0, base_temp - 0.3), + step=0.05, + key=f"{prefix}_temp_min", + ) + temp_max = col_tmax.number_input( + "温度最大值", + min_value=0.0, + max_value=2.0, + value=min(2.0, base_temp + 0.3), + step=0.05, + key=f"{prefix}_temp_max", + ) + if temp_max <= temp_min: + controls_valid = False + st.warning("温度最大值必须大于最小值。") + temp_max = min(2.0, temp_min + 0.01) + span = temp_max - temp_min + if span <= 0: + ratio_default = 0.0 + else: + clamped = min(max(base_temp, temp_min), temp_max) + ratio_default = (clamped - temp_min) / span + temp_action = col_tslider.slider( + "动作值(映射至温度区间)", + min_value=0.0, + max_value=1.0, + value=float(ratio_default), + step=0.01, + key=f"{prefix}_temp_action", + ) + specs.append( + ParameterSpec( + name=f"dept_temperature_{dept_code}", + target=f"department.{dept_code}.temperature", + minimum=temp_min, + maximum=temp_max, + ) + ) + spec_labels.append(f"department:{dept_code}:temperature") + action_values.append(temp_action) + + col_tool, col_hint = st.columns([1, 2]) + tool_choice = col_tool.selectbox( + "函数调用策略", + tool_policy_values, + index=tool_policy_values.index("auto"), + key=f"{prefix}_tool_choice", + ) + col_hint.caption("映射提示:0→auto,0.5→none,1→required。") + if len(tool_policy_values) > 1: + tool_value = tool_policy_values.index(tool_choice) / (len(tool_policy_values) - 1) + else: + tool_value = 0.0 + specs.append( + ParameterSpec( + name=f"dept_tool_{dept_code}", + target=f"department.{dept_code}.function_policy", + values=tool_policy_values, + ) + ) + spec_labels.append(f"department:{dept_code}:tool_choice") + action_values.append(tool_value) + + if specs: + st.caption("动作维度顺序:" + ",".join(spec_labels)) + run_decision_env = 
st.button("执行单次调参", key="run_decision_env_button") just_finished_single = False if run_decision_env: - if not selected_agents: - st.warning("请至少选择一个代理进行调参。") - elif not range_valid: + if not specs: + st.warning("请至少配置一个动作维度(代理或部门参数)。") + elif selected_agents and not range_valid: st.error("请确保所有代理的最大权重大于最小权重。") + elif not controls_valid: + st.error("请修正部门参数的取值范围。") else: LOGGER.info( "离线调参(单次)按钮点击,已选择代理=%s 动作=%s disable_departments=%s", @@ -448,11 +545,11 @@ def render_backtest_review() -> None: resolved_experiment_id = experiment_id or str(uuid.uuid4()) resolved_strategy = strategy_label or "DecisionEnv" action_payload = { - name: value - for name, value in zip(selected_agents, action_values) + label: value for label, value in zip(spec_labels, action_values) } metrics_payload = dict(observation) metrics_payload["reward"] = reward + metrics_payload["department_controls"] = info.get("department_controls") log_success = False try: log_tuning_result( @@ -477,12 +574,14 @@ def render_backtest_review() -> None: "observation": dict(observation), "reward": float(reward), "weights": info.get("weights", {}), + "department_controls": info.get("department_controls"), + "actions": action_payload, "nav_series": info.get("nav_series"), "trades": info.get("trades"), "portfolio_snapshots": info.get("portfolio_snapshots"), "portfolio_trades": info.get("portfolio_trades"), "risk_breakdown": info.get("risk_breakdown"), - "selected_agents": list(selected_agents), + "spec_labels": list(spec_labels), "action_values": list(action_values), "experiment_id": resolved_experiment_id, "strategy_label": resolved_strategy, @@ -562,6 +661,16 @@ def render_backtest_review() -> None: with st.expander("风险事件统计", expanded=False): st.json(risk_breakdown) + department_info = single_result.get("department_controls") or {} + if department_info: + with st.expander("部门控制参数", expanded=False): + st.json(department_info) + + action_snapshot = single_result.get("actions") or {} + if action_snapshot: + with st.expander("动作明细", expanded=False): + st.json(action_snapshot) + if st.button("清除单次调参结果", key="clear_decision_env_single"): st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None) st.success("已清除单次调参结果缓存。") @@ -584,10 +693,12 @@ def render_backtest_review() -> None: run_batch = st.button("批量执行调参", key="run_decision_env_batch") batch_just_ran = False if run_batch: - if not selected_agents: - st.warning("请先选择调参代理。") - elif not range_valid: + if not specs: + st.warning("请至少配置一个动作维度。") + elif selected_agents and not range_valid: st.error("请确保所有代理的最大权重大于最小权重。") + elif not controls_valid: + st.error("请修正部门参数的取值范围。") else: LOGGER.info( "离线调参(批量)按钮点击,已选择代理=%s disable_departments=%s", @@ -693,11 +804,12 @@ def render_backtest_review() -> None: extra=LOG_EXTRA, ) action_payload = { - name: value - for name, value in zip(selected_agents, action_vals) + label: value + for label, value in zip(spec_labels, action_vals) } metrics_payload = dict(observation) metrics_payload["reward"] = reward + metrics_payload["department_controls"] = info.get("department_controls") weights_payload = info.get("weights", {}) try: log_tuning_result( @@ -713,13 +825,14 @@ def render_backtest_review() -> None: results.append( { "序号": idx, - "动作": action_vals, + "动作": action_payload, "状态": "ok", "总收益": observation.get("total_return", 0.0), "最大回撤": observation.get("max_drawdown", 0.0), "波动率": observation.get("volatility", 0.0), "奖励": reward, "权重": weights_payload, + "部门控制": info.get("department_controls"), } ) st.session_state[_DECISION_ENV_BATCH_RESULTS_KEY] = { diff 
--git a/docs/TODO.md b/docs/TODO.md index 241e874..4670b95 100644 --- a/docs/TODO.md +++ b/docs/TODO.md @@ -18,7 +18,7 @@ - 围绕动量、估值、流动性等核心信号扩展轻量高质量因子集,全部由程序生成,满足端到端自动决策需求。 ## 3. 决策优化与强化学习 -- 扩展 `DecisionEnv` 的动作空间(提示版本、部门温度、function 调用策略等),不仅限于代理权重调节。 +- ✅ 扩展 `DecisionEnv` 的动作空间(提示版本、部门温度、function 调用策略等),不仅限于代理权重调节。 - 引入 Bandit / 贝叶斯优化或 RL 算法探索动作空间,并将 `portfolio_snapshots`、`portfolio_trades` 指标纳入奖励约束。 - 构建实时持仓/成交数据写入链路,使线上监控与离线调参共用同一数据源。 - 借鉴 TradingAgents-CN 的做法:拆分环境与策略、提供训练脚本/配置,并输出丰富的评估指标(如 Sharpe、Sortino、基准对比)。 diff --git a/tests/test_decision_env.py b/tests/test_decision_env.py index 18e5724..d87d1f7 100644 --- a/tests/test_decision_env.py +++ b/tests/test_decision_env.py @@ -7,13 +7,58 @@ import pytest from app.backtest.decision_env import DecisionEnv, EpisodeMetrics, ParameterSpec from app.backtest.engine import BacktestResult, BtConfig +from app.utils.config import DepartmentSettings, LLMConfig, LLMEndpoint + + +class _StubDepartmentAgent: + def __init__(self) -> None: + self._tool_choice = "auto" + self._max_rounds = 3 + endpoint = LLMEndpoint(provider="openai", model="mock", temperature=0.2) + self.settings = DepartmentSettings( + code="momentum", + title="Momentum", + description="baseline", + prompt="baseline", + llm=LLMConfig(primary=endpoint), + ) + + @property + def tool_choice(self) -> str: + return self._tool_choice + + @tool_choice.setter + def tool_choice(self, value) -> None: + normalized = str(value).strip().lower() + if normalized not in {"auto", "none", "required"}: + raise ValueError("invalid tool choice") + self._tool_choice = normalized + + @property + def max_rounds(self) -> int: + return self._max_rounds + + @max_rounds.setter + def max_rounds(self, value) -> None: + numeric = int(round(float(value))) + if numeric < 1: + numeric = 1 + if numeric > 6: + numeric = 6 + self._max_rounds = numeric + + +class _StubManager: + def __init__(self) -> None: + self.agents = {"momentum": _StubDepartmentAgent()} class _StubEngine: def __init__(self, cfg: BtConfig) -> None: # noqa: D401 self.cfg = cfg self.weights = {} - self.department_manager = None + self.department_manager = _StubManager() + _StubEngine.last_instance = self def run(self) -> BacktestResult: result = BacktestResult() @@ -53,6 +98,9 @@ class _StubEngine: return result +_StubEngine.last_instance: _StubEngine | None = None + + def test_decision_env_returns_risk_metrics(monkeypatch): cfg = BtConfig( id="stub", @@ -96,3 +144,68 @@ def test_default_reward_penalizes_metrics(): ) reward = DecisionEnv._default_reward(metrics) assert reward == pytest.approx(0.1 - (0.5 * 0.2 + 0.05 * 2 + 0.1 * 0.3)) + + +def test_decision_env_department_controls(monkeypatch): + cfg = BtConfig( + id="stub", + name="stub", + start_date=date(2025, 1, 10), + end_date=date(2025, 1, 10), + universe=["000001.SZ"], + params={}, + ) + specs = [ + ParameterSpec(name="w_mom", target="agent_weights.A_mom", minimum=0.0, maximum=1.0), + ParameterSpec( + name="dept_prompt", + target="department.momentum.prompt", + values=["baseline", "aggressive"], + ), + ParameterSpec( + name="dept_temp", + target="department.momentum.temperature", + minimum=0.1, + maximum=0.9, + ), + ParameterSpec( + name="dept_tool", + target="department.momentum.function_policy", + values=["none", "auto", "required"], + ), + ParameterSpec( + name="dept_rounds", + target="department.momentum.max_rounds", + minimum=1, + maximum=5, + ), + ] + + env = DecisionEnv(bt_config=cfg, parameter_specs=specs, baseline_weights={"A_mom": 0.5}) + + 
monkeypatch.setattr("app.backtest.decision_env.BacktestEngine", _StubEngine) + monkeypatch.setattr(DecisionEnv, "_clear_portfolio_records", lambda self: None) + monkeypatch.setattr(DecisionEnv, "_fetch_portfolio_records", lambda self: ([], [])) + + obs, reward, done, info = env.step([0.3, 1.0, 0.75, 0.0, 1.0]) + + assert done is True + assert obs["total_return"] == pytest.approx(0.0) + + controls = info["department_controls"] + assert "momentum" in controls + momentum_ctrl = controls["momentum"] + assert momentum_ctrl["prompt"] == "aggressive" + assert momentum_ctrl["temperature"] == pytest.approx(0.7, abs=1e-6) + assert momentum_ctrl["tool_choice"] == "none" + assert momentum_ctrl["max_rounds"] == 5 + + assert env.last_department_controls == controls + + engine = _StubEngine.last_instance + assert engine is not None + agent = engine.department_manager.agents["momentum"] + assert agent.settings.prompt == "aggressive" + assert agent.settings.llm.primary.temperature == pytest.approx(0.7, abs=1e-6) + assert agent.tool_choice == "none" + assert agent.max_rounds == 5 diff --git a/tests/test_llm_templates.py b/tests/test_llm_templates.py index a9a9042..b314349 100644 --- a/tests/test_llm_templates.py +++ b/tests/test_llm_templates.py @@ -169,6 +169,12 @@ def test_default_templates(): assert momentum is not None assert "动量研究部门" in momentum.name + assert TemplateRegistry.get("value_dept") is not None + assert TemplateRegistry.get("news_dept") is not None + assert TemplateRegistry.get("liquidity_dept") is not None + assert TemplateRegistry.get("macro_dept") is not None + assert TemplateRegistry.get("risk_dept") is not None + # Validate template content assert all("{" + var + "}" in dept_base.template for var in dept_base.variables) assert all("{" + var + "}" in momentum.template for var in momentum.variables)
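-- 
Review notes

1) Usage sketch for the extended action space. A minimal example assuming the
module paths and BtConfig fields exercised in tests/test_decision_env.py; the
universe, dates and action values are illustrative only:

from datetime import date

from app.backtest.decision_env import DecisionEnv, ParameterSpec
from app.backtest.engine import BtConfig

cfg = BtConfig(
    id="demo",
    name="demo",
    start_date=date(2025, 1, 10),
    end_date=date(2025, 1, 10),
    universe=["000001.SZ"],
    params={},
)

specs = [
    # Continuous dimension: the [0, 1] action is rescaled into [minimum, maximum].
    ParameterSpec(name="w_mom", target="agent_weights.A_mom", minimum=0.0, maximum=1.0),
    ParameterSpec(name="dept_temp", target="department.momentum.temperature", minimum=0.1, maximum=0.9),
    # Discrete dimension: the action indexes into `values` ("function_policy" aliases tool_choice).
    ParameterSpec(
        name="dept_tool",
        target="department.momentum.function_policy",
        values=["none", "auto", "required"],
    ),
]

env = DecisionEnv(bt_config=cfg, parameter_specs=specs, baseline_weights={"A_mom": 0.5})
env.reset()
obs, reward, done, info = env.step([0.3, 0.75, 0.0])
print(info["weights"])               # {"A_mom": 0.3}
print(info["department_controls"])   # {"momentum": {"temperature": 0.7, "tool_choice": "none"}}

Each episode is a single step (done is always True), so a bandit or Bayesian
optimizer can treat step() as one black-box evaluation per action vector,
which is how the batch mode in app/ui/views/backtest.py drives it.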
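2) The action-to-parameter mapping in ParameterSpec.resolve, restated as a
self-contained check (same arithmetic as the patch; the spec names here are
hypothetical):

import math

from app.backtest.decision_env import ParameterSpec

# Continuous target: minimum + a * (maximum - minimum), with a clipped to [0, 1].
temp = ParameterSpec(name="t", target="department.x.temperature", minimum=0.1, maximum=0.9)
assert math.isclose(temp.resolve(0.75), 0.7)
assert math.isclose(temp.resolve(1.5), 0.9)   # out-of-range actions are clipped

# Discrete target: values[round(a * (len(values) - 1))].
tool = ParameterSpec(
    name="p",
    target="department.x.function_policy",
    values=["none", "auto", "required"],
)
assert tool.resolve(0.0) == "none"
assert tool.resolve(0.5) == "auto"       # round(0.5 * 2) == 1
assert tool.resolve(1.0) == "required"

The UI encodes its selectbox the same way (index / (len - 1)), so the chosen
option round-trips through the normalized action exactly.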
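3) Design note on _apply_department_controls: overrides mutate agent.settings,
so the code clones the settings before writing to keep each episode isolated
from the shared configuration. A sketch of that pattern with stand-in
dataclasses (the real types are DepartmentSettings and LLMConfig):

import copy
from dataclasses import dataclass, field, replace

@dataclass
class _Endpoint:  # stand-in for the real LLM endpoint config
    temperature: float = 0.2

@dataclass
class _Settings:  # stand-in for DepartmentSettings
    prompt: str = "baseline"
    llm: _Endpoint = field(default_factory=_Endpoint)

shared = _Settings()
episode = replace(shared)                # shallow dataclass copy
episode.llm = copy.deepcopy(shared.llm)  # detach the nested, mutable LLM config
episode.llm.temperature = 0.7
episode.prompt = "aggressive"
assert shared.prompt == "baseline" and shared.llm.temperature == 0.2  # untouched

Without the deepcopy, the nested llm object would still be shared, and a
temperature override from one episode would leak into every later run.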