update
This commit is contained in:
parent
3351ca6b5a
commit
d0a0340db6
@ -93,6 +93,38 @@ class DepartmentAgent:
|
||||
self._resolver = resolver
|
||||
self._broker = DataBroker()
|
||||
self._max_rounds = 3
|
||||
self._tool_choice = "auto"
|
||||
|
||||
@property
def max_rounds(self) -> int:
    """Return the current ``max_rounds`` setting (the setter keeps it within 1..6)."""
    return self._max_rounds

@max_rounds.setter
def max_rounds(self, value: Any) -> None:
    # Accept anything ``float()`` understands (ints, floats, numeric strings)
    # and round it to the nearest integer (``round`` uses banker's rounding
    # for exact .5 ties).
    #
    # ``round(float("nan"))`` raises ValueError and ``round(float("inf"))``
    # raises OverflowError; both are normalised to ValueError together with
    # non-numeric input, so callers only need to handle one exception type.
    try:
        numeric = int(round(float(value)))
    except (TypeError, ValueError, OverflowError):
        raise ValueError("max_rounds must be numeric") from None
    # Out-of-range values are clamped into [1, 6] rather than rejected.
    self._max_rounds = max(1, min(6, numeric))
|
||||
|
||||
@property
def tool_choice(self) -> str:
    """Return the tool-choice policy: one of "auto", "none" or "required"."""
    return self._tool_choice

@tool_choice.setter
def tool_choice(self, value: Any) -> None:
    # ``None`` resets the policy to the default; anything else is compared
    # case-insensitively, ignoring surrounding whitespace.
    if value is None:
        candidate = "auto"
    else:
        candidate = str(value).strip().lower()
        if candidate not in ("auto", "none", "required"):
            # Report the caller's raw value, not the normalised one.
            raise ValueError(f"Unsupported tool choice: {value}")
    self._tool_choice = candidate
|
||||
|
||||
def _get_llm_config(self) -> LLMConfig:
|
||||
if self._resolver:
|
||||
@ -159,7 +191,7 @@ class DepartmentAgent:
|
||||
primary_endpoint,
|
||||
messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
tool_choice=self._tool_choice,
|
||||
)
|
||||
except LLMError as exc:
|
||||
LOGGER.warning(
|
||||
|
||||
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import math
|
||||
import copy
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
|
||||
|
||||
@ -18,17 +19,27 @@ LOG_EXTRA = {"stage": "decision_env"}
|
||||
|
||||
@dataclass(frozen=True)
class ParameterSpec:
    """Describe one action dimension and how its raw value is decoded.

    A raw action value is first clipped to the unit interval ``[0, 1]``. It is
    then either scaled linearly into ``[minimum, maximum]`` or, when ``values``
    is provided, used to pick one entry from that sequence.
    """

    # Identifier used in log and error messages.
    name: str
    # Dotted path naming what the decoded value is applied to.
    target: str
    # Bounds of the continuous range produced by ``clamp``.
    minimum: float = 0.0
    maximum: float = 1.0
    # Optional discrete choices; when set, ``resolve`` selects from these.
    values: Optional[Sequence[Any]] = None

    def clamp(self, value: float) -> float:
        """Clip ``value`` to ``[0, 1]`` and scale it into ``[minimum, maximum]``."""
        fraction = min(1.0, float(value))
        fraction = max(0.0, fraction)
        span = self.maximum - self.minimum
        return self.minimum + fraction * span

    def resolve(self, value: float) -> Any:
        """Decode ``value`` into a discrete choice or a scaled float.

        Without ``values`` this is the same as ``clamp``. With ``values`` the
        clipped action is mapped onto an index by rounding
        ``fraction * (len(values) - 1)``.

        Raises:
            ValueError: if ``values`` is set but empty.
        """
        choices = self.values
        if choices is None:
            return self.clamp(value)
        if not choices:
            raise ValueError(f"ParameterSpec {self.name} configured with empty values list")
        fraction = max(0.0, min(1.0, float(value)))
        last_index = len(choices) - 1
        return choices[int(round(fraction * last_index))]
|
||||
|
||||
|
||||
@dataclass
|
||||
class EpisodeMetrics:
|
||||
@ -68,6 +79,7 @@ class DecisionEnv:
|
||||
self._reward_fn = reward_fn or self._default_reward
|
||||
self._last_metrics: Optional[EpisodeMetrics] = None
|
||||
self._last_action: Optional[Tuple[float, ...]] = None
|
||||
self._last_department_controls: Optional[Dict[str, Dict[str, Any]]] = None
|
||||
self._episode = 0
|
||||
self._disable_departments = bool(disable_departments)
|
||||
|
||||
@ -75,10 +87,15 @@ class DecisionEnv:
|
||||
def action_dim(self) -> int:
|
||||
return len(self._specs)
|
||||
|
||||
@property
def last_department_controls(self) -> Optional[Dict[str, Dict[str, Any]]]:
    """Return the department controls stored by the latest step, or None after a reset."""
    return self._last_department_controls
|
||||
|
||||
def reset(self) -> Dict[str, float]:
|
||||
self._episode += 1
|
||||
self._last_metrics = None
|
||||
self._last_action = None
|
||||
self._last_department_controls = None
|
||||
return {
|
||||
"episode": float(self._episode),
|
||||
"baseline_return": 0.0,
|
||||
@ -90,14 +107,24 @@ class DecisionEnv:
|
||||
action_array = [float(val) for val in action]
|
||||
self._last_action = tuple(action_array)
|
||||
|
||||
weights = self._build_weights(action_array)
|
||||
LOGGER.info("episode=%s action=%s weights=%s", self._episode, action_array, weights, extra=LOG_EXTRA)
|
||||
weights, department_controls = self._prepare_actions(action_array)
|
||||
LOGGER.info(
|
||||
"episode=%s action=%s weights=%s controls=%s",
|
||||
self._episode,
|
||||
action_array,
|
||||
weights,
|
||||
department_controls,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
|
||||
cfg = replace(self._template_cfg)
|
||||
engine = BacktestEngine(cfg)
|
||||
engine.weights = weight_map(weights)
|
||||
if self._disable_departments:
|
||||
engine.department_manager = None
|
||||
applied_controls: Dict[str, Dict[str, Any]] = {}
|
||||
else:
|
||||
applied_controls = self._apply_department_controls(engine, department_controls)
|
||||
|
||||
self._clear_portfolio_records()
|
||||
|
||||
@ -135,19 +162,153 @@ class DecisionEnv:
|
||||
"risk_events": getattr(result, "risk_events", []),
|
||||
"portfolio_snapshots": snapshots,
|
||||
"portfolio_trades": trades_override,
|
||||
"department_controls": applied_controls,
|
||||
}
|
||||
self._last_department_controls = applied_controls
|
||||
return observation, reward, True, info
|
||||
|
||||
def _prepare_actions(
    self,
    action: Sequence[float],
) -> Tuple[Dict[str, float], Dict[str, Dict[str, Any]]]:
    """Decode a raw action vector into agent weights and department controls.

    Each spec decodes the action entry at its own position. Targets starting
    with ``agent_weights.`` override an entry in a copy of the baseline
    weights. Targets of the form ``department.<code>.<field>`` are collected
    per department code. Any other target is logged and ignored.

    Returns:
        A ``(weights, department_controls)`` tuple.
    """
    weight_prefix = "agent_weights."
    department_prefix = "department."

    weights = dict(self._baseline_weights)
    department_controls: Dict[str, Dict[str, Any]] = {}

    for position, spec in enumerate(self._specs):
        try:
            resolved = spec.resolve(action[position])
        except ValueError as exc:
            # A spec that cannot be decoded is skipped, not fatal.
            LOGGER.warning("参数 %s 解析失败:%s", spec.name, exc, extra=LOG_EXTRA)
            continue

        target = spec.target
        if target.startswith(weight_prefix):
            agent_name = target[len(weight_prefix):]
            try:
                weights[agent_name] = float(resolved)
            except (TypeError, ValueError):
                # Discrete specs may yield non-numeric choices; keep the
                # baseline weight in that case.
                LOGGER.debug(
                    "spec %s produced non-numeric weight %s; skipping",
                    spec.name,
                    resolved,
                    extra=LOG_EXTRA,
                )
        elif target.startswith(department_prefix):
            # "department.<code>.<field...>": the field may itself be dotted.
            remainder = target[len(department_prefix):]
            dept_code, separator, field = remainder.partition(".")
            if not separator:
                LOGGER.debug("未识别的部门目标:%s", spec.target, extra=LOG_EXTRA)
            else:
                department_controls.setdefault(dept_code, {})[field] = resolved
        else:
            LOGGER.debug("暂未支持的参数目标:%s", spec.target, extra=LOG_EXTRA)

    return weights, department_controls
|
||||
|
||||
def _apply_department_controls(
    self,
    engine: BacktestEngine,
    controls: Mapping[str, Mapping[str, Any]],
) -> Dict[str, Dict[str, Any]]:
    """Apply per-department control values to the engine's department agents.

    Args:
        engine: Engine whose ``department_manager.agents`` mapping is looked
            up by department code.
        controls: Mapping of department code to ``{field: value}`` payloads.

    Returns:
        For each department where at least one field was accepted, the
        fields that were applied and the values read back from the agent.
        Unknown departments, non-mapping payloads, unrecognised fields and
        invalid values are skipped; invalid ones are logged.
    """
    manager = getattr(engine, "department_manager", None)
    if not manager or not getattr(manager, "agents", None):
        return {}

    applied: Dict[str, Dict[str, Any]] = {}
    for dept_code, payload in controls.items():
        agent = manager.agents.get(dept_code)
        if not agent or not isinstance(payload, Mapping):
            continue

        # Work on a private copy of the settings: a shallow dataclass copy
        # plus a deep copy of the nested ``llm`` config. The mutations below
        # then do not touch the objects the agent held before (presumably
        # shared configuration -- confirm against the manager).
        try:
            original_settings = agent.settings
            cloned_settings = replace(original_settings)
            cloned_settings.llm = copy.deepcopy(original_settings.llm)
            agent.settings = cloned_settings
        except Exception as exc:  # noqa: BLE001
            LOGGER.warning(
                "复制部门 %s 配置失败:%s",
                dept_code,
                exc,
                extra=LOG_EXTRA,
            )
            continue

        applied_fields: Dict[str, Any] = {}
        for raw_field, value in payload.items():
            # ``str()`` keeps a non-string key from raising AttributeError;
            # it falls through to the "unrecognised" log instead.
            field = str(raw_field).lower()
            if field == "function_policy":
                field = "tool_choice"

            if field in {"prompt", "instruction"}:
                agent.settings.prompt = str(value)
                applied_fields[field] = agent.settings.prompt
            elif field == "description":
                agent.settings.description = str(value)
                applied_fields[field] = agent.settings.description
            elif field in {"prompt_template_id", "prompt_template"}:
                agent.settings.prompt_template_id = str(value)
                applied_fields["prompt_template_id"] = agent.settings.prompt_template_id
            elif field == "prompt_template_version":
                agent.settings.prompt_template_version = str(value)
                applied_fields["prompt_template_version"] = agent.settings.prompt_template_version
            elif field in {"temperature", "llm.temperature"}:
                try:
                    numeric = float(value)
                    # ``min(2.0, nan)`` returns 2.0, so without this check a
                    # NaN would silently become the maximum temperature.
                    if math.isnan(numeric):
                        raise ValueError("temperature is NaN")
                    temperature = max(0.0, min(2.0, numeric))
                    agent.settings.llm.primary.temperature = temperature
                    applied_fields["temperature"] = temperature
                except (TypeError, ValueError, OverflowError):
                    LOGGER.debug(
                        "无效的温度值 %s for %s",
                        value,
                        dept_code,
                        extra=LOG_EXTRA,
                    )
            elif field in {"tool_choice", "tool_strategy"}:
                try:
                    agent.tool_choice = value
                    applied_fields["tool_choice"] = agent.tool_choice
                except ValueError:
                    LOGGER.debug(
                        "部门 %s 工具策略 %s 无效",
                        dept_code,
                        value,
                        extra=LOG_EXTRA,
                    )
            elif field == "max_rounds":
                try:
                    agent.max_rounds = value
                    applied_fields["max_rounds"] = agent.max_rounds
                except (TypeError, ValueError, OverflowError):
                    # The numeric conversion can raise OverflowError for
                    # infinities and TypeError for e.g. None, depending on
                    # the agent implementation; none of them should abort
                    # the remaining controls.
                    LOGGER.debug(
                        "部门 %s max_rounds %s 无效",
                        dept_code,
                        value,
                        extra=LOG_EXTRA,
                    )
            elif field == "prompt_template_override":
                agent.settings.prompt = str(value)
                applied_fields["prompt"] = agent.settings.prompt
            else:
                LOGGER.debug(
                    "部门 %s 未识别的控制项 %s",
                    dept_code,
                    raw_field,
                    extra=LOG_EXTRA,
                )

        if applied_fields:
            applied[dept_code] = applied_fields

    return applied
|
||||
|
||||
def _compute_metrics(
|
||||
self,
|
||||
|
||||
@ -252,98 +252,352 @@ class TemplateRegistry:
|
||||
DEFAULT_TEMPLATES = {
|
||||
"department_base": {
|
||||
"name": "部门基础模板",
|
||||
"description": "通用的部门分析提示模板",
|
||||
"description": "所有部门通用的审慎分析提示词骨架",
|
||||
"template": """
|
||||
部门名称:{title}
|
||||
部门:{title}
|
||||
股票代码:{ts_code}
|
||||
交易日:{trade_date}
|
||||
|
||||
角色说明:{description}
|
||||
职责指令:{instruction}
|
||||
【角色定位】
|
||||
- 角色说明:{description}
|
||||
- 行动守则:{instruction}
|
||||
|
||||
【可用数据范围】
|
||||
【数据边界】
|
||||
- 可用字段:
|
||||
{data_scope}
|
||||
|
||||
【核心特征】
|
||||
- 核心特征:
|
||||
{features}
|
||||
|
||||
【市场背景】
|
||||
- 市场背景:
|
||||
{market_snapshot}
|
||||
|
||||
【追加数据】
|
||||
- 追加数据:
|
||||
{supplements}
|
||||
|
||||
请基于以上数据给出该部门对当前股票的操作建议。输出必须是 JSON,字段如下:
|
||||
【分析步骤】
|
||||
1. 判断信息是否充分,如不充分,请说明缺口并优先调用工具 `fetch_data`(仅限 `daily`、`daily_basic`)。
|
||||
2. 梳理 2-3 个关键支撑信号与潜在风险,确保基于提供的数据。
|
||||
3. 结合量化证据与限制条件,给出操作建议和信心来源,避免主观臆测。
|
||||
|
||||
【输出要求】
|
||||
仅返回一个 JSON 对象,不要添加额外文本:
|
||||
{{
|
||||
"action": "BUY|BUY_S|BUY_M|BUY_L|SELL|HOLD",
|
||||
"confidence": 0-1 之间的小数,表示信心,
|
||||
"summary": "一句话概括理由",
|
||||
"signals": ["详细要点", "..."],
|
||||
"risks": ["风险点", "..."]
|
||||
"confidence": 0-1 之间的小数,
|
||||
"summary": "一句话结论",
|
||||
"signals": ["关键支撑要点", "..."],
|
||||
"risks": ["关键风险要点", "..."]
|
||||
}}
|
||||
|
||||
如需额外数据,请调用工具 `fetch_data`,仅支持请求 `daily` 或 `daily_basic` 表。
|
||||
请严格返回单个 JSON 对象,不要添加额外文本。
|
||||
如需说明未完成的数据请求,请在 `risks` 或 `signals` 中明确。
|
||||
""",
|
||||
"variables": [
|
||||
"title", "ts_code", "trade_date", "description", "instruction",
|
||||
"data_scope", "features", "market_snapshot", "supplements"
|
||||
"title",
|
||||
"ts_code",
|
||||
"trade_date",
|
||||
"description",
|
||||
"instruction",
|
||||
"data_scope",
|
||||
"features",
|
||||
"market_snapshot",
|
||||
"supplements",
|
||||
],
|
||||
"required_context": [
|
||||
"ts_code", "trade_date", "features", "market_snapshot"
|
||||
"ts_code",
|
||||
"trade_date",
|
||||
"features",
|
||||
"market_snapshot",
|
||||
],
|
||||
"validation_rules": [
|
||||
"len(features) > 0",
|
||||
"len(market_snapshot) > 0"
|
||||
]
|
||||
"metadata": {
|
||||
"category": "department",
|
||||
"preset": "base",
|
||||
},
|
||||
},
|
||||
"momentum_dept": {
|
||||
"name": "动量研究部门",
|
||||
"description": "专注于动量因子分析的部门模板",
|
||||
"name": "动量研究部门模板",
|
||||
"description": "围绕价格与量能动量的决策提示",
|
||||
"template": """
|
||||
部门名称:动量研究部门
|
||||
部门:动量研究部门
|
||||
股票代码:{ts_code}
|
||||
交易日:{trade_date}
|
||||
|
||||
角色说明:专注于分析股票价格动量、成交量动量和技术指标动量
|
||||
职责指令:重点关注以下方面:
|
||||
1. 价格趋势强度和持续性
|
||||
2. 成交量配合度
|
||||
3. 技术指标背离
|
||||
【角色定位】
|
||||
- 专注价格动量、成交量共振与技术指标背离。
|
||||
- 保持纪律,识别趋势延续与反转风险。
|
||||
|
||||
【可用数据范围】
|
||||
【研究重点】
|
||||
1. 多时间窗口动量是否同向?
|
||||
2. 成交量是否验证价格走势?
|
||||
3. 是否出现过热或背离信号?
|
||||
|
||||
【数据边界】
|
||||
- 可用字段:
|
||||
{data_scope}
|
||||
|
||||
【动量特征】
|
||||
- 动量特征:
|
||||
{features}
|
||||
|
||||
【市场背景】
|
||||
- 市场背景:
|
||||
{market_snapshot}
|
||||
|
||||
【追加数据】
|
||||
- 追加数据:
|
||||
{supplements}
|
||||
|
||||
请基于以上数据进行动量分析并给出操作建议。输出必须是 JSON,字段如下:
|
||||
{{
|
||||
"action": "BUY|BUY_S|BUY_M|BUY_L|SELL|HOLD",
|
||||
"confidence": 0-1 之间的小数,表示信心,
|
||||
"summary": "一句话概括动量分析结论",
|
||||
"signals": ["动量信号要点", "..."],
|
||||
"risks": ["动量风险点", "..."]
|
||||
}}
|
||||
请沿用【部门基础模板】的分析步骤与输出要求,重点量化趋势动能和量价配合度。
|
||||
""",
|
||||
"variables": [
|
||||
"ts_code", "trade_date", "data_scope",
|
||||
"features", "market_snapshot", "supplements"
|
||||
"ts_code",
|
||||
"trade_date",
|
||||
"data_scope",
|
||||
"features",
|
||||
"market_snapshot",
|
||||
"supplements",
|
||||
],
|
||||
"required_context": [
|
||||
"ts_code", "trade_date", "features", "market_snapshot"
|
||||
"ts_code",
|
||||
"trade_date",
|
||||
"features",
|
||||
"market_snapshot",
|
||||
],
|
||||
"validation_rules": [
|
||||
"len(features) > 0",
|
||||
"'momentum' in ' '.join(features.keys()).lower()"
|
||||
]
|
||||
}
|
||||
"metadata": {
|
||||
"category": "department",
|
||||
"preset": "momentum",
|
||||
},
|
||||
},
|
||||
"value_dept": {
|
||||
"name": "价值评估部门模板",
|
||||
"description": "衡量估值与盈利质量的提示词",
|
||||
"template": """
|
||||
部门:价值评估部门
|
||||
股票代码:{ts_code}
|
||||
交易日:{trade_date}
|
||||
|
||||
【角色定位】
|
||||
- 关注估值分位、盈利质量与安全边际。
|
||||
- 从中期配置角度评价当前价格的性价比。
|
||||
|
||||
【研究重点】
|
||||
1. 历史及同业视角的估值位置。
|
||||
2. 盈利与分红的可持续性。
|
||||
3. 潜在的估值修复催化或压制因素。
|
||||
|
||||
【数据边界】
|
||||
- 可用字段:
|
||||
{data_scope}
|
||||
- 估值与质量特征:
|
||||
{features}
|
||||
- 市场背景:
|
||||
{market_snapshot}
|
||||
- 追加数据:
|
||||
{supplements}
|
||||
|
||||
请按照【部门基础模板】的分析步骤输出结论,并明确估值安全边际来源。
|
||||
""",
|
||||
"variables": [
|
||||
"ts_code",
|
||||
"trade_date",
|
||||
"data_scope",
|
||||
"features",
|
||||
"market_snapshot",
|
||||
"supplements",
|
||||
],
|
||||
"required_context": [
|
||||
"ts_code",
|
||||
"trade_date",
|
||||
"features",
|
||||
"market_snapshot",
|
||||
],
|
||||
"metadata": {
|
||||
"category": "department",
|
||||
"preset": "value",
|
||||
},
|
||||
},
|
||||
"news_dept": {
|
||||
"name": "新闻情绪部门模板",
|
||||
"description": "针对舆情热度与事件影响的提示词",
|
||||
"template": """
|
||||
部门:新闻情绪部门
|
||||
股票代码:{ts_code}
|
||||
交易日:{trade_date}
|
||||
|
||||
【角色定位】
|
||||
- 监控舆情热度、事件驱动与短期情绪。
|
||||
- 评估新闻对价格波动的正负面影响。
|
||||
|
||||
【研究重点】
|
||||
1. 新闻情绪是否集中且持续?
|
||||
2. 主题与行情是否匹配?
|
||||
3. 情绪驱动的风险敞口。
|
||||
|
||||
【数据边界】
|
||||
- 可用字段:
|
||||
{data_scope}
|
||||
- 舆情特征:
|
||||
{features}
|
||||
- 市场背景:
|
||||
{market_snapshot}
|
||||
- 追加数据:
|
||||
{supplements}
|
||||
|
||||
请遵循【部门基础模板】的分析步骤,突出情绪驱动的力度与时效性。
|
||||
""",
|
||||
"variables": [
|
||||
"ts_code",
|
||||
"trade_date",
|
||||
"data_scope",
|
||||
"features",
|
||||
"market_snapshot",
|
||||
"supplements",
|
||||
],
|
||||
"required_context": [
|
||||
"ts_code",
|
||||
"trade_date",
|
||||
"features",
|
||||
"market_snapshot",
|
||||
],
|
||||
"metadata": {
|
||||
"category": "department",
|
||||
"preset": "news",
|
||||
},
|
||||
},
|
||||
"liquidity_dept": {
|
||||
"name": "流动性评估部门模板",
|
||||
"description": "衡量成交活跃度与执行成本的提示词",
|
||||
"template": """
|
||||
部门:流动性评估部门
|
||||
股票代码:{ts_code}
|
||||
交易日:{trade_date}
|
||||
|
||||
【角色定位】
|
||||
- 评估成交活跃度、交易成本与可执行性。
|
||||
- 提醒潜在的流动性风险与仓位限制。
|
||||
|
||||
【研究重点】
|
||||
1. 当前成交量与历史均值的对比。
|
||||
2. 价量限制(涨跌停、停牌等)对执行的影响。
|
||||
3. 预估滑点与转手难度。
|
||||
|
||||
【数据边界】
|
||||
- 可用字段:
|
||||
{data_scope}
|
||||
- 流动性特征:
|
||||
{features}
|
||||
- 市场背景:
|
||||
{market_snapshot}
|
||||
- 追加数据:
|
||||
{supplements}
|
||||
|
||||
请遵循【部门基础模板】的分析步骤,重点描述执行可行性与仓位建议。
|
||||
""",
|
||||
"variables": [
|
||||
"ts_code",
|
||||
"trade_date",
|
||||
"data_scope",
|
||||
"features",
|
||||
"market_snapshot",
|
||||
"supplements",
|
||||
],
|
||||
"required_context": [
|
||||
"ts_code",
|
||||
"trade_date",
|
||||
"features",
|
||||
"market_snapshot",
|
||||
],
|
||||
"metadata": {
|
||||
"category": "department",
|
||||
"preset": "liquidity",
|
||||
},
|
||||
},
|
||||
"macro_dept": {
|
||||
"name": "宏观研究部门模板",
|
||||
"description": "宏观与行业景气度分析提示词",
|
||||
"template": """
|
||||
部门:宏观研究部门
|
||||
股票代码:{ts_code}
|
||||
交易日:{trade_date}
|
||||
|
||||
【角色定位】
|
||||
- 追踪宏观周期、行业景气与相对强弱。
|
||||
- 评估宏观事件对该标的的方向性影响。
|
||||
|
||||
【研究重点】
|
||||
1. 行业相对大盘的表现与热点程度。
|
||||
2. 宏观/政策事件对行业或标的的指引。
|
||||
3. 需警惕的宏观风险与流动性环境。
|
||||
|
||||
【数据边界】
|
||||
- 可用字段:
|
||||
{data_scope}
|
||||
- 宏观特征:
|
||||
{features}
|
||||
- 市场背景:
|
||||
{market_snapshot}
|
||||
- 追加数据:
|
||||
{supplements}
|
||||
|
||||
请执行【部门基础模板】的分析步骤,并输出宏观驱动的信号与风险。
|
||||
""",
|
||||
"variables": [
|
||||
"ts_code",
|
||||
"trade_date",
|
||||
"data_scope",
|
||||
"features",
|
||||
"market_snapshot",
|
||||
"supplements",
|
||||
],
|
||||
"required_context": [
|
||||
"ts_code",
|
||||
"trade_date",
|
||||
"features",
|
||||
"market_snapshot",
|
||||
],
|
||||
"metadata": {
|
||||
"category": "department",
|
||||
"preset": "macro",
|
||||
},
|
||||
},
|
||||
"risk_dept": {
|
||||
"name": "风险控制部门模板",
|
||||
"description": "识别极端风险与限制条件的提示词",
|
||||
"template": """
|
||||
部门:风险控制部门
|
||||
股票代码:{ts_code}
|
||||
交易日:{trade_date}
|
||||
|
||||
【角色定位】
|
||||
- 防范停牌、涨跌停、仓位与合规限制。
|
||||
- 必要时对高风险决策行使否决权。
|
||||
|
||||
【研究重点】
|
||||
1. 交易限制或异常波动情况。
|
||||
2. 仓位、集中度或风险指标是否触顶。
|
||||
3. 潜在的黑天鹅或执行障碍。
|
||||
|
||||
【数据边界】
|
||||
- 可用字段:
|
||||
{data_scope}
|
||||
- 风险特征:
|
||||
{features}
|
||||
- 市场背景:
|
||||
{market_snapshot}
|
||||
- 追加数据:
|
||||
{supplements}
|
||||
|
||||
请按照【部门基础模板】的分析步骤,必要时明确阻止交易的理由。
|
||||
""",
|
||||
"variables": [
|
||||
"ts_code",
|
||||
"trade_date",
|
||||
"data_scope",
|
||||
"features",
|
||||
"market_snapshot",
|
||||
"supplements",
|
||||
],
|
||||
"required_context": [
|
||||
"ts_code",
|
||||
"trade_date",
|
||||
"features",
|
||||
"market_snapshot",
|
||||
],
|
||||
"metadata": {
|
||||
"category": "department",
|
||||
"preset": "risk",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -331,6 +331,7 @@ def render_backtest_review() -> None:
|
||||
)
|
||||
|
||||
specs: List[ParameterSpec] = []
|
||||
spec_labels: List[str] = []
|
||||
action_values: List[float] = []
|
||||
range_valid = True
|
||||
for idx, agent_name in enumerate(selected_agents):
|
||||
@ -374,15 +375,111 @@ def render_backtest_review() -> None:
|
||||
maximum=max_val,
|
||||
)
|
||||
)
|
||||
spec_labels.append(f"agent:{agent_name}")
|
||||
action_values.append(action_val)
|
||||
|
||||
controls_valid = True
|
||||
with st.expander("部门 LLM 参数", expanded=False):
|
||||
dept_codes = sorted(app_cfg.departments.keys())
|
||||
if not dept_codes:
|
||||
st.caption("当前未配置部门。")
|
||||
else:
|
||||
selected_departments = st.multiselect(
|
||||
"选择需要调整的部门",
|
||||
dept_codes,
|
||||
default=[],
|
||||
key="decision_env_departments",
|
||||
)
|
||||
tool_policy_values = ["auto", "none", "required"]
|
||||
for dept_code in selected_departments:
|
||||
settings = app_cfg.departments.get(dept_code)
|
||||
if not settings:
|
||||
continue
|
||||
st.subheader(f"部门:{settings.title or dept_code}")
|
||||
base_temp = 0.2
|
||||
if settings.llm and settings.llm.primary and settings.llm.primary.temperature is not None:
|
||||
base_temp = float(settings.llm.primary.temperature)
|
||||
prefix = f"decision_env_dept_{dept_code}"
|
||||
col_tmin, col_tmax, col_tslider = st.columns([1, 1, 2])
|
||||
temp_min = col_tmin.number_input(
|
||||
"温度最小值",
|
||||
min_value=0.0,
|
||||
max_value=2.0,
|
||||
value=max(0.0, base_temp - 0.3),
|
||||
step=0.05,
|
||||
key=f"{prefix}_temp_min",
|
||||
)
|
||||
temp_max = col_tmax.number_input(
|
||||
"温度最大值",
|
||||
min_value=0.0,
|
||||
max_value=2.0,
|
||||
value=min(2.0, base_temp + 0.3),
|
||||
step=0.05,
|
||||
key=f"{prefix}_temp_max",
|
||||
)
|
||||
if temp_max <= temp_min:
|
||||
controls_valid = False
|
||||
st.warning("温度最大值必须大于最小值。")
|
||||
temp_max = min(2.0, temp_min + 0.01)
|
||||
span = temp_max - temp_min
|
||||
if span <= 0:
|
||||
ratio_default = 0.0
|
||||
else:
|
||||
clamped = min(max(base_temp, temp_min), temp_max)
|
||||
ratio_default = (clamped - temp_min) / span
|
||||
temp_action = col_tslider.slider(
|
||||
"动作值(映射至温度区间)",
|
||||
min_value=0.0,
|
||||
max_value=1.0,
|
||||
value=float(ratio_default),
|
||||
step=0.01,
|
||||
key=f"{prefix}_temp_action",
|
||||
)
|
||||
specs.append(
|
||||
ParameterSpec(
|
||||
name=f"dept_temperature_{dept_code}",
|
||||
target=f"department.{dept_code}.temperature",
|
||||
minimum=temp_min,
|
||||
maximum=temp_max,
|
||||
)
|
||||
)
|
||||
spec_labels.append(f"department:{dept_code}:temperature")
|
||||
action_values.append(temp_action)
|
||||
|
||||
col_tool, col_hint = st.columns([1, 2])
|
||||
tool_choice = col_tool.selectbox(
|
||||
"函数调用策略",
|
||||
tool_policy_values,
|
||||
index=tool_policy_values.index("auto"),
|
||||
key=f"{prefix}_tool_choice",
|
||||
)
|
||||
col_hint.caption("映射提示:0→auto,0.5→none,1→required。")
|
||||
if len(tool_policy_values) > 1:
|
||||
tool_value = tool_policy_values.index(tool_choice) / (len(tool_policy_values) - 1)
|
||||
else:
|
||||
tool_value = 0.0
|
||||
specs.append(
|
||||
ParameterSpec(
|
||||
name=f"dept_tool_{dept_code}",
|
||||
target=f"department.{dept_code}.function_policy",
|
||||
values=tool_policy_values,
|
||||
)
|
||||
)
|
||||
spec_labels.append(f"department:{dept_code}:tool_choice")
|
||||
action_values.append(tool_value)
|
||||
|
||||
if specs:
|
||||
st.caption("动作维度顺序:" + ",".join(spec_labels))
|
||||
|
||||
run_decision_env = st.button("执行单次调参", key="run_decision_env_button")
|
||||
just_finished_single = False
|
||||
if run_decision_env:
|
||||
if not selected_agents:
|
||||
st.warning("请至少选择一个代理进行调参。")
|
||||
elif not range_valid:
|
||||
if not specs:
|
||||
st.warning("请至少配置一个动作维度(代理或部门参数)。")
|
||||
elif selected_agents and not range_valid:
|
||||
st.error("请确保所有代理的最大权重大于最小权重。")
|
||||
elif not controls_valid:
|
||||
st.error("请修正部门参数的取值范围。")
|
||||
else:
|
||||
LOGGER.info(
|
||||
"离线调参(单次)按钮点击,已选择代理=%s 动作=%s disable_departments=%s",
|
||||
@ -448,11 +545,11 @@ def render_backtest_review() -> None:
|
||||
resolved_experiment_id = experiment_id or str(uuid.uuid4())
|
||||
resolved_strategy = strategy_label or "DecisionEnv"
|
||||
action_payload = {
|
||||
name: value
|
||||
for name, value in zip(selected_agents, action_values)
|
||||
label: value for label, value in zip(spec_labels, action_values)
|
||||
}
|
||||
metrics_payload = dict(observation)
|
||||
metrics_payload["reward"] = reward
|
||||
metrics_payload["department_controls"] = info.get("department_controls")
|
||||
log_success = False
|
||||
try:
|
||||
log_tuning_result(
|
||||
@ -477,12 +574,14 @@ def render_backtest_review() -> None:
|
||||
"observation": dict(observation),
|
||||
"reward": float(reward),
|
||||
"weights": info.get("weights", {}),
|
||||
"department_controls": info.get("department_controls"),
|
||||
"actions": action_payload,
|
||||
"nav_series": info.get("nav_series"),
|
||||
"trades": info.get("trades"),
|
||||
"portfolio_snapshots": info.get("portfolio_snapshots"),
|
||||
"portfolio_trades": info.get("portfolio_trades"),
|
||||
"risk_breakdown": info.get("risk_breakdown"),
|
||||
"selected_agents": list(selected_agents),
|
||||
"spec_labels": list(spec_labels),
|
||||
"action_values": list(action_values),
|
||||
"experiment_id": resolved_experiment_id,
|
||||
"strategy_label": resolved_strategy,
|
||||
@ -562,6 +661,16 @@ def render_backtest_review() -> None:
|
||||
with st.expander("风险事件统计", expanded=False):
|
||||
st.json(risk_breakdown)
|
||||
|
||||
department_info = single_result.get("department_controls") or {}
|
||||
if department_info:
|
||||
with st.expander("部门控制参数", expanded=False):
|
||||
st.json(department_info)
|
||||
|
||||
action_snapshot = single_result.get("actions") or {}
|
||||
if action_snapshot:
|
||||
with st.expander("动作明细", expanded=False):
|
||||
st.json(action_snapshot)
|
||||
|
||||
if st.button("清除单次调参结果", key="clear_decision_env_single"):
|
||||
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
|
||||
st.success("已清除单次调参结果缓存。")
|
||||
@ -584,10 +693,12 @@ def render_backtest_review() -> None:
|
||||
run_batch = st.button("批量执行调参", key="run_decision_env_batch")
|
||||
batch_just_ran = False
|
||||
if run_batch:
|
||||
if not selected_agents:
|
||||
st.warning("请先选择调参代理。")
|
||||
elif not range_valid:
|
||||
if not specs:
|
||||
st.warning("请至少配置一个动作维度。")
|
||||
elif selected_agents and not range_valid:
|
||||
st.error("请确保所有代理的最大权重大于最小权重。")
|
||||
elif not controls_valid:
|
||||
st.error("请修正部门参数的取值范围。")
|
||||
else:
|
||||
LOGGER.info(
|
||||
"离线调参(批量)按钮点击,已选择代理=%s disable_departments=%s",
|
||||
@ -693,11 +804,12 @@ def render_backtest_review() -> None:
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
action_payload = {
|
||||
name: value
|
||||
for name, value in zip(selected_agents, action_vals)
|
||||
label: value
|
||||
for label, value in zip(spec_labels, action_vals)
|
||||
}
|
||||
metrics_payload = dict(observation)
|
||||
metrics_payload["reward"] = reward
|
||||
metrics_payload["department_controls"] = info.get("department_controls")
|
||||
weights_payload = info.get("weights", {})
|
||||
try:
|
||||
log_tuning_result(
|
||||
@ -713,13 +825,14 @@ def render_backtest_review() -> None:
|
||||
results.append(
|
||||
{
|
||||
"序号": idx,
|
||||
"动作": action_vals,
|
||||
"动作": action_payload,
|
||||
"状态": "ok",
|
||||
"总收益": observation.get("total_return", 0.0),
|
||||
"最大回撤": observation.get("max_drawdown", 0.0),
|
||||
"波动率": observation.get("volatility", 0.0),
|
||||
"奖励": reward,
|
||||
"权重": weights_payload,
|
||||
"部门控制": info.get("department_controls"),
|
||||
}
|
||||
)
|
||||
st.session_state[_DECISION_ENV_BATCH_RESULTS_KEY] = {
|
||||
|
||||
@ -18,7 +18,7 @@
|
||||
- 围绕动量、估值、流动性等核心信号扩展轻量高质量因子集,全部由程序生成,满足端到端自动决策需求。
|
||||
|
||||
## 3. 决策优化与强化学习
|
||||
- 扩展 `DecisionEnv` 的动作空间(提示版本、部门温度、function 调用策略等),不仅限于代理权重调节。
|
||||
- ✅ 扩展 `DecisionEnv` 的动作空间(提示版本、部门温度、function 调用策略等),不仅限于代理权重调节。
|
||||
- 引入 Bandit / 贝叶斯优化或 RL 算法探索动作空间,并将 `portfolio_snapshots`、`portfolio_trades` 指标纳入奖励约束。
|
||||
- 构建实时持仓/成交数据写入链路,使线上监控与离线调参共用同一数据源。
|
||||
- 借鉴 TradingAgents-CN 的做法:拆分环境与策略、提供训练脚本/配置,并输出丰富的评估指标(如 Sharpe、Sortino、基准对比)。
|
||||
|
||||
@ -7,13 +7,58 @@ import pytest
|
||||
|
||||
from app.backtest.decision_env import DecisionEnv, EpisodeMetrics, ParameterSpec
|
||||
from app.backtest.engine import BacktestResult, BtConfig
|
||||
from app.utils.config import DepartmentSettings, LLMConfig, LLMEndpoint
|
||||
|
||||
|
||||
class _StubDepartmentAgent:
    """Test double for a department agent.

    It exposes the attributes the department-control code touches:
    ``settings`` plus the validated ``tool_choice`` and ``max_rounds``
    properties. The setters follow the same contract as the real agent:
    ``None`` resets the tool choice to "auto", and every invalid value is
    reported as ``ValueError``.
    """

    def __init__(self) -> None:
        self._tool_choice = "auto"
        self._max_rounds = 3
        endpoint = LLMEndpoint(provider="openai", model="mock", temperature=0.2)
        self.settings = DepartmentSettings(
            code="momentum",
            title="Momentum",
            description="baseline",
            prompt="baseline",
            llm=LLMConfig(primary=endpoint),
        )

    @property
    def tool_choice(self) -> str:
        """Return the current tool-choice policy."""
        return self._tool_choice

    @tool_choice.setter
    def tool_choice(self, value) -> None:
        if value is None:
            self._tool_choice = "auto"
            return
        normalized = str(value).strip().lower()
        if normalized not in {"auto", "none", "required"}:
            raise ValueError("invalid tool choice")
        self._tool_choice = normalized

    @property
    def max_rounds(self) -> int:
        """Return the current round limit (kept within 1..6)."""
        return self._max_rounds

    @max_rounds.setter
    def max_rounds(self, value) -> None:
        # Normalise every conversion failure (None, text, NaN, infinity) to
        # ValueError so callers only need to handle one exception type.
        try:
            numeric = int(round(float(value)))
        except (TypeError, ValueError, OverflowError):
            raise ValueError("max_rounds must be numeric") from None
        self._max_rounds = max(1, min(6, numeric))
|
||||
|
||||
|
||||
class _StubManager:
    """Minimal department-manager stand-in exposing one "momentum" stub agent."""

    def __init__(self) -> None:
        # Keyed by department code, mirroring the ``agents`` mapping the
        # department-control code looks up.
        self.agents = {"momentum": _StubDepartmentAgent()}
|
||||
|
||||
|
||||
class _StubEngine:
|
||||
def __init__(self, cfg: BtConfig) -> None: # noqa: D401
|
||||
self.cfg = cfg
|
||||
self.weights = {}
|
||||
self.department_manager = None
|
||||
self.department_manager = _StubManager()
|
||||
_StubEngine.last_instance = self
|
||||
|
||||
def run(self) -> BacktestResult:
|
||||
result = BacktestResult()
|
||||
@ -53,6 +98,9 @@ class _StubEngine:
|
||||
return result
|
||||
|
||||
|
||||
_StubEngine.last_instance: _StubEngine | None = None
|
||||
|
||||
|
||||
def test_decision_env_returns_risk_metrics(monkeypatch):
|
||||
cfg = BtConfig(
|
||||
id="stub",
|
||||
@ -96,3 +144,68 @@ def test_default_reward_penalizes_metrics():
|
||||
)
|
||||
reward = DecisionEnv._default_reward(metrics)
|
||||
assert reward == pytest.approx(0.1 - (0.5 * 0.2 + 0.05 * 2 + 0.1 * 0.3))
|
||||
|
||||
|
||||
def test_decision_env_department_controls(monkeypatch):
    """A single step decodes department specs and applies them to the stub agent."""
    cfg = BtConfig(
        id="stub",
        name="stub",
        start_date=date(2025, 1, 10),
        end_date=date(2025, 1, 10),
        universe=["000001.SZ"],
        params={},
    )
    # One spec per action dimension; the order must match the action vector
    # passed to ``env.step`` below.
    specs = [
        ParameterSpec(name="w_mom", target="agent_weights.A_mom", minimum=0.0, maximum=1.0),
        ParameterSpec(
            name="dept_prompt",
            target="department.momentum.prompt",
            values=["baseline", "aggressive"],
        ),
        ParameterSpec(
            name="dept_temp",
            target="department.momentum.temperature",
            minimum=0.1,
            maximum=0.9,
        ),
        ParameterSpec(
            name="dept_tool",
            target="department.momentum.function_policy",
            values=["none", "auto", "required"],
        ),
        ParameterSpec(
            name="dept_rounds",
            target="department.momentum.max_rounds",
            minimum=1,
            maximum=5,
        ),
    ]

    env = DecisionEnv(bt_config=cfg, parameter_specs=specs, baseline_weights={"A_mom": 0.5})

    # Swap in the stub engine and neutralise the portfolio-record helpers.
    monkeypatch.setattr("app.backtest.decision_env.BacktestEngine", _StubEngine)
    monkeypatch.setattr(DecisionEnv, "_clear_portfolio_records", lambda self: None)
    monkeypatch.setattr(DecisionEnv, "_fetch_portfolio_records", lambda self: ([], []))

    # Expected decoding of the action vector:
    #   1.0  -> last prompt choice ("aggressive")
    #   0.75 -> 0.1 + 0.75 * (0.9 - 0.1) = 0.7
    #   0.0  -> first tool policy ("none")
    #   1.0  -> 1 + 1.0 * (5 - 1) = 5 rounds
    obs, reward, done, info = env.step([0.3, 1.0, 0.75, 0.0, 1.0])

    assert done is True
    assert obs["total_return"] == pytest.approx(0.0)

    controls = info["department_controls"]
    assert "momentum" in controls
    momentum_ctrl = controls["momentum"]
    assert momentum_ctrl["prompt"] == "aggressive"
    assert momentum_ctrl["temperature"] == pytest.approx(0.7, abs=1e-6)
    # "function_policy" is reported under the "tool_choice" key.
    assert momentum_ctrl["tool_choice"] == "none"
    assert momentum_ctrl["max_rounds"] == 5

    assert env.last_department_controls == controls

    # The values must also have reached the agent of the engine built in step().
    engine = _StubEngine.last_instance
    assert engine is not None
    agent = engine.department_manager.agents["momentum"]
    assert agent.settings.prompt == "aggressive"
    assert agent.settings.llm.primary.temperature == pytest.approx(0.7, abs=1e-6)
    assert agent.tool_choice == "none"
    assert agent.max_rounds == 5
|
||||
|
||||
@ -169,6 +169,12 @@ def test_default_templates():
|
||||
assert momentum is not None
|
||||
assert "动量研究部门" in momentum.name
|
||||
|
||||
assert TemplateRegistry.get("value_dept") is not None
|
||||
assert TemplateRegistry.get("news_dept") is not None
|
||||
assert TemplateRegistry.get("liquidity_dept") is not None
|
||||
assert TemplateRegistry.get("macro_dept") is not None
|
||||
assert TemplateRegistry.get("risk_dept") is not None
|
||||
|
||||
# Validate template content
|
||||
assert all("{" + var + "}" in dept_base.template for var in dept_base.variables)
|
||||
assert all("{" + var + "}" in momentum.template for var in momentum.variables)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user