update
This commit is contained in:
parent
3351ca6b5a
commit
d0a0340db6
@ -93,6 +93,38 @@ class DepartmentAgent:
|
|||||||
self._resolver = resolver
|
self._resolver = resolver
|
||||||
self._broker = DataBroker()
|
self._broker = DataBroker()
|
||||||
self._max_rounds = 3
|
self._max_rounds = 3
|
||||||
|
self._tool_choice = "auto"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_rounds(self) -> int:
|
||||||
|
return self._max_rounds
|
||||||
|
|
||||||
|
@max_rounds.setter
|
||||||
|
def max_rounds(self, value: Any) -> None:
|
||||||
|
try:
|
||||||
|
numeric = int(round(float(value)))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
raise ValueError("max_rounds must be numeric") from None
|
||||||
|
if numeric < 1:
|
||||||
|
numeric = 1
|
||||||
|
if numeric > 6:
|
||||||
|
numeric = 6
|
||||||
|
self._max_rounds = numeric
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tool_choice(self) -> str:
|
||||||
|
return self._tool_choice
|
||||||
|
|
||||||
|
@tool_choice.setter
|
||||||
|
def tool_choice(self, value: Any) -> None:
|
||||||
|
if value is None:
|
||||||
|
self._tool_choice = "auto"
|
||||||
|
return
|
||||||
|
normalized = str(value).strip().lower()
|
||||||
|
allowed = {"auto", "none", "required"}
|
||||||
|
if normalized not in allowed:
|
||||||
|
raise ValueError(f"Unsupported tool choice: {value}")
|
||||||
|
self._tool_choice = normalized
|
||||||
|
|
||||||
def _get_llm_config(self) -> LLMConfig:
|
def _get_llm_config(self) -> LLMConfig:
|
||||||
if self._resolver:
|
if self._resolver:
|
||||||
@ -159,7 +191,7 @@ class DepartmentAgent:
|
|||||||
primary_endpoint,
|
primary_endpoint,
|
||||||
messages,
|
messages,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
tool_choice="auto",
|
tool_choice=self._tool_choice,
|
||||||
)
|
)
|
||||||
except LLMError as exc:
|
except LLMError as exc:
|
||||||
LOGGER.warning(
|
LOGGER.warning(
|
||||||
|
|||||||
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
|
import copy
|
||||||
from dataclasses import dataclass, replace
|
from dataclasses import dataclass, replace
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
|
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
|
||||||
|
|
||||||
@ -18,17 +19,27 @@ LOG_EXTRA = {"stage": "decision_env"}
|
|||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ParameterSpec:
|
class ParameterSpec:
|
||||||
"""Defines how a scalar action dimension maps to strategy parameters."""
|
"""Defines how an action dimension maps to strategy parameters or behaviors."""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
target: str
|
target: str
|
||||||
minimum: float = 0.0
|
minimum: float = 0.0
|
||||||
maximum: float = 1.0
|
maximum: float = 1.0
|
||||||
|
values: Optional[Sequence[Any]] = None
|
||||||
|
|
||||||
def clamp(self, value: float) -> float:
|
def clamp(self, value: float) -> float:
|
||||||
clipped = max(0.0, min(1.0, float(value)))
|
clipped = max(0.0, min(1.0, float(value)))
|
||||||
return self.minimum + clipped * (self.maximum - self.minimum)
|
return self.minimum + clipped * (self.maximum - self.minimum)
|
||||||
|
|
||||||
|
def resolve(self, value: float) -> Any:
|
||||||
|
if self.values is not None:
|
||||||
|
if not self.values:
|
||||||
|
raise ValueError(f"ParameterSpec {self.name} configured with empty values list")
|
||||||
|
clipped = max(0.0, min(1.0, float(value)))
|
||||||
|
index = int(round(clipped * (len(self.values) - 1)))
|
||||||
|
return self.values[index]
|
||||||
|
return self.clamp(value)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EpisodeMetrics:
|
class EpisodeMetrics:
|
||||||
@ -68,6 +79,7 @@ class DecisionEnv:
|
|||||||
self._reward_fn = reward_fn or self._default_reward
|
self._reward_fn = reward_fn or self._default_reward
|
||||||
self._last_metrics: Optional[EpisodeMetrics] = None
|
self._last_metrics: Optional[EpisodeMetrics] = None
|
||||||
self._last_action: Optional[Tuple[float, ...]] = None
|
self._last_action: Optional[Tuple[float, ...]] = None
|
||||||
|
self._last_department_controls: Optional[Dict[str, Dict[str, Any]]] = None
|
||||||
self._episode = 0
|
self._episode = 0
|
||||||
self._disable_departments = bool(disable_departments)
|
self._disable_departments = bool(disable_departments)
|
||||||
|
|
||||||
@ -75,10 +87,15 @@ class DecisionEnv:
|
|||||||
def action_dim(self) -> int:
|
def action_dim(self) -> int:
|
||||||
return len(self._specs)
|
return len(self._specs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def last_department_controls(self) -> Optional[Dict[str, Dict[str, Any]]]:
|
||||||
|
return self._last_department_controls
|
||||||
|
|
||||||
def reset(self) -> Dict[str, float]:
|
def reset(self) -> Dict[str, float]:
|
||||||
self._episode += 1
|
self._episode += 1
|
||||||
self._last_metrics = None
|
self._last_metrics = None
|
||||||
self._last_action = None
|
self._last_action = None
|
||||||
|
self._last_department_controls = None
|
||||||
return {
|
return {
|
||||||
"episode": float(self._episode),
|
"episode": float(self._episode),
|
||||||
"baseline_return": 0.0,
|
"baseline_return": 0.0,
|
||||||
@ -90,14 +107,24 @@ class DecisionEnv:
|
|||||||
action_array = [float(val) for val in action]
|
action_array = [float(val) for val in action]
|
||||||
self._last_action = tuple(action_array)
|
self._last_action = tuple(action_array)
|
||||||
|
|
||||||
weights = self._build_weights(action_array)
|
weights, department_controls = self._prepare_actions(action_array)
|
||||||
LOGGER.info("episode=%s action=%s weights=%s", self._episode, action_array, weights, extra=LOG_EXTRA)
|
LOGGER.info(
|
||||||
|
"episode=%s action=%s weights=%s controls=%s",
|
||||||
|
self._episode,
|
||||||
|
action_array,
|
||||||
|
weights,
|
||||||
|
department_controls,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
|
||||||
cfg = replace(self._template_cfg)
|
cfg = replace(self._template_cfg)
|
||||||
engine = BacktestEngine(cfg)
|
engine = BacktestEngine(cfg)
|
||||||
engine.weights = weight_map(weights)
|
engine.weights = weight_map(weights)
|
||||||
if self._disable_departments:
|
if self._disable_departments:
|
||||||
engine.department_manager = None
|
engine.department_manager = None
|
||||||
|
applied_controls: Dict[str, Dict[str, Any]] = {}
|
||||||
|
else:
|
||||||
|
applied_controls = self._apply_department_controls(engine, department_controls)
|
||||||
|
|
||||||
self._clear_portfolio_records()
|
self._clear_portfolio_records()
|
||||||
|
|
||||||
@ -135,19 +162,153 @@ class DecisionEnv:
|
|||||||
"risk_events": getattr(result, "risk_events", []),
|
"risk_events": getattr(result, "risk_events", []),
|
||||||
"portfolio_snapshots": snapshots,
|
"portfolio_snapshots": snapshots,
|
||||||
"portfolio_trades": trades_override,
|
"portfolio_trades": trades_override,
|
||||||
|
"department_controls": applied_controls,
|
||||||
}
|
}
|
||||||
|
self._last_department_controls = applied_controls
|
||||||
return observation, reward, True, info
|
return observation, reward, True, info
|
||||||
|
|
||||||
def _build_weights(self, action: Sequence[float]) -> Dict[str, float]:
|
def _prepare_actions(
|
||||||
|
self,
|
||||||
|
action: Sequence[float],
|
||||||
|
) -> Tuple[Dict[str, float], Dict[str, Dict[str, Any]]]:
|
||||||
weights = dict(self._baseline_weights)
|
weights = dict(self._baseline_weights)
|
||||||
|
department_controls: Dict[str, Dict[str, Any]] = {}
|
||||||
for idx, spec in enumerate(self._specs):
|
for idx, spec in enumerate(self._specs):
|
||||||
value = spec.clamp(action[idx])
|
try:
|
||||||
|
resolved = spec.resolve(action[idx])
|
||||||
|
except ValueError as exc:
|
||||||
|
LOGGER.warning("参数 %s 解析失败:%s", spec.name, exc, extra=LOG_EXTRA)
|
||||||
|
continue
|
||||||
if spec.target.startswith("agent_weights."):
|
if spec.target.startswith("agent_weights."):
|
||||||
agent_name = spec.target.split(".", 1)[1]
|
agent_name = spec.target.split(".", 1)[1]
|
||||||
weights[agent_name] = value
|
try:
|
||||||
|
weights[agent_name] = float(resolved)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
LOGGER.debug(
|
||||||
|
"spec %s produced non-numeric weight %s; skipping",
|
||||||
|
spec.name,
|
||||||
|
resolved,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
if spec.target.startswith("department."):
|
||||||
|
target_path = spec.target.split(".")[1:]
|
||||||
|
if len(target_path) < 2:
|
||||||
|
LOGGER.debug("未识别的部门目标:%s", spec.target, extra=LOG_EXTRA)
|
||||||
|
continue
|
||||||
|
dept_code = target_path[0]
|
||||||
|
field = ".".join(target_path[1:])
|
||||||
|
dept_controls = department_controls.setdefault(dept_code, {})
|
||||||
|
dept_controls[field] = resolved
|
||||||
|
continue
|
||||||
else:
|
else:
|
||||||
LOGGER.debug("暂未支持的参数目标:%s", spec.target, extra=LOG_EXTRA)
|
LOGGER.debug("暂未支持的参数目标:%s", spec.target, extra=LOG_EXTRA)
|
||||||
return weights
|
return weights, department_controls
|
||||||
|
|
||||||
|
def _apply_department_controls(
|
||||||
|
self,
|
||||||
|
engine: BacktestEngine,
|
||||||
|
controls: Mapping[str, Mapping[str, Any]],
|
||||||
|
) -> Dict[str, Dict[str, Any]]:
|
||||||
|
manager = getattr(engine, "department_manager", None)
|
||||||
|
if not manager or not getattr(manager, "agents", None):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
applied: Dict[str, Dict[str, Any]] = {}
|
||||||
|
for dept_code, payload in controls.items():
|
||||||
|
agent = manager.agents.get(dept_code)
|
||||||
|
if not agent or not isinstance(payload, Mapping):
|
||||||
|
continue
|
||||||
|
|
||||||
|
applied_fields: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
# Ensure mutable settings clone to avoid global side-effects
|
||||||
|
try:
|
||||||
|
original_settings = agent.settings
|
||||||
|
cloned_settings = replace(original_settings)
|
||||||
|
cloned_settings.llm = copy.deepcopy(original_settings.llm)
|
||||||
|
agent.settings = cloned_settings
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
LOGGER.warning(
|
||||||
|
"复制部门 %s 配置失败:%s",
|
||||||
|
dept_code,
|
||||||
|
exc,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
for raw_field, value in payload.items():
|
||||||
|
field = raw_field.lower()
|
||||||
|
if field == "function_policy":
|
||||||
|
field = "tool_choice"
|
||||||
|
if field in {"prompt", "instruction"}:
|
||||||
|
agent.settings.prompt = str(value)
|
||||||
|
applied_fields[field] = agent.settings.prompt
|
||||||
|
continue
|
||||||
|
if field == "description":
|
||||||
|
agent.settings.description = str(value)
|
||||||
|
applied_fields[field] = agent.settings.description
|
||||||
|
continue
|
||||||
|
if field in {"prompt_template_id", "prompt_template"}:
|
||||||
|
agent.settings.prompt_template_id = str(value)
|
||||||
|
applied_fields["prompt_template_id"] = agent.settings.prompt_template_id
|
||||||
|
continue
|
||||||
|
if field == "prompt_template_version":
|
||||||
|
agent.settings.prompt_template_version = str(value)
|
||||||
|
applied_fields["prompt_template_version"] = agent.settings.prompt_template_version
|
||||||
|
continue
|
||||||
|
if field in {"temperature", "llm.temperature"}:
|
||||||
|
try:
|
||||||
|
temperature = max(0.0, min(2.0, float(value)))
|
||||||
|
agent.settings.llm.primary.temperature = temperature
|
||||||
|
applied_fields["temperature"] = temperature
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
LOGGER.debug(
|
||||||
|
"无效的温度值 %s for %s",
|
||||||
|
value,
|
||||||
|
dept_code,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
if field in {"tool_choice", "tool_strategy"}:
|
||||||
|
try:
|
||||||
|
agent.tool_choice = value
|
||||||
|
applied_fields["tool_choice"] = agent.tool_choice
|
||||||
|
except ValueError:
|
||||||
|
LOGGER.debug(
|
||||||
|
"部门 %s 工具策略 %s 无效",
|
||||||
|
dept_code,
|
||||||
|
value,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
if field == "max_rounds":
|
||||||
|
try:
|
||||||
|
agent.max_rounds = value
|
||||||
|
applied_fields["max_rounds"] = agent.max_rounds
|
||||||
|
except ValueError:
|
||||||
|
LOGGER.debug(
|
||||||
|
"部门 %s max_rounds %s 无效",
|
||||||
|
dept_code,
|
||||||
|
value,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
if field == "prompt_template_override":
|
||||||
|
agent.settings.prompt = str(value)
|
||||||
|
applied_fields["prompt"] = agent.settings.prompt
|
||||||
|
continue
|
||||||
|
LOGGER.debug(
|
||||||
|
"部门 %s 未识别的控制项 %s",
|
||||||
|
dept_code,
|
||||||
|
raw_field,
|
||||||
|
extra=LOG_EXTRA,
|
||||||
|
)
|
||||||
|
|
||||||
|
if applied_fields:
|
||||||
|
applied[dept_code] = applied_fields
|
||||||
|
|
||||||
|
return applied
|
||||||
|
|
||||||
def _compute_metrics(
|
def _compute_metrics(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -252,98 +252,352 @@ class TemplateRegistry:
|
|||||||
DEFAULT_TEMPLATES = {
|
DEFAULT_TEMPLATES = {
|
||||||
"department_base": {
|
"department_base": {
|
||||||
"name": "部门基础模板",
|
"name": "部门基础模板",
|
||||||
"description": "通用的部门分析提示模板",
|
"description": "所有部门通用的审慎分析提示词骨架",
|
||||||
"template": """
|
"template": """
|
||||||
部门名称:{title}
|
部门:{title}
|
||||||
股票代码:{ts_code}
|
股票代码:{ts_code}
|
||||||
交易日:{trade_date}
|
交易日:{trade_date}
|
||||||
|
|
||||||
角色说明:{description}
|
【角色定位】
|
||||||
职责指令:{instruction}
|
- 角色说明:{description}
|
||||||
|
- 行动守则:{instruction}
|
||||||
|
|
||||||
【可用数据范围】
|
【数据边界】
|
||||||
|
- 可用字段:
|
||||||
{data_scope}
|
{data_scope}
|
||||||
|
- 核心特征:
|
||||||
【核心特征】
|
|
||||||
{features}
|
{features}
|
||||||
|
- 市场背景:
|
||||||
【市场背景】
|
|
||||||
{market_snapshot}
|
{market_snapshot}
|
||||||
|
- 追加数据:
|
||||||
【追加数据】
|
|
||||||
{supplements}
|
{supplements}
|
||||||
|
|
||||||
请基于以上数据给出该部门对当前股票的操作建议。输出必须是 JSON,字段如下:
|
【分析步骤】
|
||||||
|
1. 判断信息是否充分,如不充分,请说明缺口并优先调用工具 `fetch_data`(仅限 `daily`、`daily_basic`)。
|
||||||
|
2. 梳理 2-3 个关键支撑信号与潜在风险,确保基于提供的数据。
|
||||||
|
3. 结合量化证据与限制条件,给出操作建议和信心来源,避免主观臆测。
|
||||||
|
|
||||||
|
【输出要求】
|
||||||
|
仅返回一个 JSON 对象,不要添加额外文本:
|
||||||
{{
|
{{
|
||||||
"action": "BUY|BUY_S|BUY_M|BUY_L|SELL|HOLD",
|
"action": "BUY|BUY_S|BUY_M|BUY_L|SELL|HOLD",
|
||||||
"confidence": 0-1 之间的小数,表示信心,
|
"confidence": 0-1 之间的小数,
|
||||||
"summary": "一句话概括理由",
|
"summary": "一句话结论",
|
||||||
"signals": ["详细要点", "..."],
|
"signals": ["关键支撑要点", "..."],
|
||||||
"risks": ["风险点", "..."]
|
"risks": ["关键风险要点", "..."]
|
||||||
}}
|
}}
|
||||||
|
如需说明未完成的数据请求,请在 `risks` 或 `signals` 中明确。
|
||||||
如需额外数据,请调用工具 `fetch_data`,仅支持请求 `daily` 或 `daily_basic` 表。
|
|
||||||
请严格返回单个 JSON 对象,不要添加额外文本。
|
|
||||||
""",
|
""",
|
||||||
"variables": [
|
"variables": [
|
||||||
"title", "ts_code", "trade_date", "description", "instruction",
|
"title",
|
||||||
"data_scope", "features", "market_snapshot", "supplements"
|
"ts_code",
|
||||||
|
"trade_date",
|
||||||
|
"description",
|
||||||
|
"instruction",
|
||||||
|
"data_scope",
|
||||||
|
"features",
|
||||||
|
"market_snapshot",
|
||||||
|
"supplements",
|
||||||
],
|
],
|
||||||
"required_context": [
|
"required_context": [
|
||||||
"ts_code", "trade_date", "features", "market_snapshot"
|
"ts_code",
|
||||||
|
"trade_date",
|
||||||
|
"features",
|
||||||
|
"market_snapshot",
|
||||||
],
|
],
|
||||||
"validation_rules": [
|
"metadata": {
|
||||||
"len(features) > 0",
|
"category": "department",
|
||||||
"len(market_snapshot) > 0"
|
"preset": "base",
|
||||||
]
|
},
|
||||||
},
|
},
|
||||||
"momentum_dept": {
|
"momentum_dept": {
|
||||||
"name": "动量研究部门",
|
"name": "动量研究部门模板",
|
||||||
"description": "专注于动量因子分析的部门模板",
|
"description": "围绕价格与量能动量的决策提示",
|
||||||
"template": """
|
"template": """
|
||||||
部门名称:动量研究部门
|
部门:动量研究部门
|
||||||
股票代码:{ts_code}
|
股票代码:{ts_code}
|
||||||
交易日:{trade_date}
|
交易日:{trade_date}
|
||||||
|
|
||||||
角色说明:专注于分析股票价格动量、成交量动量和技术指标动量
|
【角色定位】
|
||||||
职责指令:重点关注以下方面:
|
- 专注价格动量、成交量共振与技术指标背离。
|
||||||
1. 价格趋势强度和持续性
|
- 保持纪律,识别趋势延续与反转风险。
|
||||||
2. 成交量配合度
|
|
||||||
3. 技术指标背离
|
|
||||||
|
|
||||||
【可用数据范围】
|
【研究重点】
|
||||||
|
1. 多时间窗口动量是否同向?
|
||||||
|
2. 成交量是否验证价格走势?
|
||||||
|
3. 是否出现过热或背离信号?
|
||||||
|
|
||||||
|
【数据边界】
|
||||||
|
- 可用字段:
|
||||||
{data_scope}
|
{data_scope}
|
||||||
|
- 动量特征:
|
||||||
【动量特征】
|
|
||||||
{features}
|
{features}
|
||||||
|
- 市场背景:
|
||||||
【市场背景】
|
|
||||||
{market_snapshot}
|
{market_snapshot}
|
||||||
|
- 追加数据:
|
||||||
【追加数据】
|
|
||||||
{supplements}
|
{supplements}
|
||||||
|
|
||||||
请基于以上数据进行动量分析并给出操作建议。输出必须是 JSON,字段如下:
|
请沿用【部门基础模板】的分析步骤与输出要求,重点量化趋势动能和量价配合度。
|
||||||
{{
|
|
||||||
"action": "BUY|BUY_S|BUY_M|BUY_L|SELL|HOLD",
|
|
||||||
"confidence": 0-1 之间的小数,表示信心,
|
|
||||||
"summary": "一句话概括动量分析结论",
|
|
||||||
"signals": ["动量信号要点", "..."],
|
|
||||||
"risks": ["动量风险点", "..."]
|
|
||||||
}}
|
|
||||||
""",
|
""",
|
||||||
"variables": [
|
"variables": [
|
||||||
"ts_code", "trade_date", "data_scope",
|
"ts_code",
|
||||||
"features", "market_snapshot", "supplements"
|
"trade_date",
|
||||||
|
"data_scope",
|
||||||
|
"features",
|
||||||
|
"market_snapshot",
|
||||||
|
"supplements",
|
||||||
],
|
],
|
||||||
"required_context": [
|
"required_context": [
|
||||||
"ts_code", "trade_date", "features", "market_snapshot"
|
"ts_code",
|
||||||
|
"trade_date",
|
||||||
|
"features",
|
||||||
|
"market_snapshot",
|
||||||
],
|
],
|
||||||
"validation_rules": [
|
"metadata": {
|
||||||
"len(features) > 0",
|
"category": "department",
|
||||||
"'momentum' in ' '.join(features.keys()).lower()"
|
"preset": "momentum",
|
||||||
]
|
},
|
||||||
}
|
},
|
||||||
|
"value_dept": {
|
||||||
|
"name": "价值评估部门模板",
|
||||||
|
"description": "衡量估值与盈利质量的提示词",
|
||||||
|
"template": """
|
||||||
|
部门:价值评估部门
|
||||||
|
股票代码:{ts_code}
|
||||||
|
交易日:{trade_date}
|
||||||
|
|
||||||
|
【角色定位】
|
||||||
|
- 关注估值分位、盈利质量与安全边际。
|
||||||
|
- 从中期配置角度评价当前价格的性价比。
|
||||||
|
|
||||||
|
【研究重点】
|
||||||
|
1. 历史及同业视角的估值位置。
|
||||||
|
2. 盈利与分红的可持续性。
|
||||||
|
3. 潜在的估值修复催化或压制因素。
|
||||||
|
|
||||||
|
【数据边界】
|
||||||
|
- 可用字段:
|
||||||
|
{data_scope}
|
||||||
|
- 估值与质量特征:
|
||||||
|
{features}
|
||||||
|
- 市场背景:
|
||||||
|
{market_snapshot}
|
||||||
|
- 追加数据:
|
||||||
|
{supplements}
|
||||||
|
|
||||||
|
请按照【部门基础模板】的分析步骤输出结论,并明确估值安全边际来源。
|
||||||
|
""",
|
||||||
|
"variables": [
|
||||||
|
"ts_code",
|
||||||
|
"trade_date",
|
||||||
|
"data_scope",
|
||||||
|
"features",
|
||||||
|
"market_snapshot",
|
||||||
|
"supplements",
|
||||||
|
],
|
||||||
|
"required_context": [
|
||||||
|
"ts_code",
|
||||||
|
"trade_date",
|
||||||
|
"features",
|
||||||
|
"market_snapshot",
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"category": "department",
|
||||||
|
"preset": "value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"news_dept": {
|
||||||
|
"name": "新闻情绪部门模板",
|
||||||
|
"description": "针对舆情热度与事件影响的提示词",
|
||||||
|
"template": """
|
||||||
|
部门:新闻情绪部门
|
||||||
|
股票代码:{ts_code}
|
||||||
|
交易日:{trade_date}
|
||||||
|
|
||||||
|
【角色定位】
|
||||||
|
- 监控舆情热度、事件驱动与短期情绪。
|
||||||
|
- 评估新闻对价格波动的正负面影响。
|
||||||
|
|
||||||
|
【研究重点】
|
||||||
|
1. 新闻情绪是否集中且持续?
|
||||||
|
2. 主题与行情是否匹配?
|
||||||
|
3. 情绪驱动的风险敞口。
|
||||||
|
|
||||||
|
【数据边界】
|
||||||
|
- 可用字段:
|
||||||
|
{data_scope}
|
||||||
|
- 舆情特征:
|
||||||
|
{features}
|
||||||
|
- 市场背景:
|
||||||
|
{market_snapshot}
|
||||||
|
- 追加数据:
|
||||||
|
{supplements}
|
||||||
|
|
||||||
|
请遵循【部门基础模板】的分析步骤,突出情绪驱动的力度与时效性。
|
||||||
|
""",
|
||||||
|
"variables": [
|
||||||
|
"ts_code",
|
||||||
|
"trade_date",
|
||||||
|
"data_scope",
|
||||||
|
"features",
|
||||||
|
"market_snapshot",
|
||||||
|
"supplements",
|
||||||
|
],
|
||||||
|
"required_context": [
|
||||||
|
"ts_code",
|
||||||
|
"trade_date",
|
||||||
|
"features",
|
||||||
|
"market_snapshot",
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"category": "department",
|
||||||
|
"preset": "news",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"liquidity_dept": {
|
||||||
|
"name": "流动性评估部门模板",
|
||||||
|
"description": "衡量成交活跃度与执行成本的提示词",
|
||||||
|
"template": """
|
||||||
|
部门:流动性评估部门
|
||||||
|
股票代码:{ts_code}
|
||||||
|
交易日:{trade_date}
|
||||||
|
|
||||||
|
【角色定位】
|
||||||
|
- 评估成交活跃度、交易成本与可执行性。
|
||||||
|
- 提醒潜在的流动性风险与仓位限制。
|
||||||
|
|
||||||
|
【研究重点】
|
||||||
|
1. 当前成交量与历史均值的对比。
|
||||||
|
2. 价量限制(涨跌停、停牌等)对执行的影响。
|
||||||
|
3. 预估滑点与转手难度。
|
||||||
|
|
||||||
|
【数据边界】
|
||||||
|
- 可用字段:
|
||||||
|
{data_scope}
|
||||||
|
- 流动性特征:
|
||||||
|
{features}
|
||||||
|
- 市场背景:
|
||||||
|
{market_snapshot}
|
||||||
|
- 追加数据:
|
||||||
|
{supplements}
|
||||||
|
|
||||||
|
请遵循【部门基础模板】的分析步骤,重点描述执行可行性与仓位建议。
|
||||||
|
""",
|
||||||
|
"variables": [
|
||||||
|
"ts_code",
|
||||||
|
"trade_date",
|
||||||
|
"data_scope",
|
||||||
|
"features",
|
||||||
|
"market_snapshot",
|
||||||
|
"supplements",
|
||||||
|
],
|
||||||
|
"required_context": [
|
||||||
|
"ts_code",
|
||||||
|
"trade_date",
|
||||||
|
"features",
|
||||||
|
"market_snapshot",
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"category": "department",
|
||||||
|
"preset": "liquidity",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"macro_dept": {
|
||||||
|
"name": "宏观研究部门模板",
|
||||||
|
"description": "宏观与行业景气度分析提示词",
|
||||||
|
"template": """
|
||||||
|
部门:宏观研究部门
|
||||||
|
股票代码:{ts_code}
|
||||||
|
交易日:{trade_date}
|
||||||
|
|
||||||
|
【角色定位】
|
||||||
|
- 追踪宏观周期、行业景气与相对强弱。
|
||||||
|
- 评估宏观事件对该标的的方向性影响。
|
||||||
|
|
||||||
|
【研究重点】
|
||||||
|
1. 行业相对大盘的表现与热点程度。
|
||||||
|
2. 宏观/政策事件对行业或标的的指引。
|
||||||
|
3. 需警惕的宏观风险与流动性环境。
|
||||||
|
|
||||||
|
【数据边界】
|
||||||
|
- 可用字段:
|
||||||
|
{data_scope}
|
||||||
|
- 宏观特征:
|
||||||
|
{features}
|
||||||
|
- 市场背景:
|
||||||
|
{market_snapshot}
|
||||||
|
- 追加数据:
|
||||||
|
{supplements}
|
||||||
|
|
||||||
|
请执行【部门基础模板】的分析步骤,并输出宏观驱动的信号与风险。
|
||||||
|
""",
|
||||||
|
"variables": [
|
||||||
|
"ts_code",
|
||||||
|
"trade_date",
|
||||||
|
"data_scope",
|
||||||
|
"features",
|
||||||
|
"market_snapshot",
|
||||||
|
"supplements",
|
||||||
|
],
|
||||||
|
"required_context": [
|
||||||
|
"ts_code",
|
||||||
|
"trade_date",
|
||||||
|
"features",
|
||||||
|
"market_snapshot",
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"category": "department",
|
||||||
|
"preset": "macro",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"risk_dept": {
|
||||||
|
"name": "风险控制部门模板",
|
||||||
|
"description": "识别极端风险与限制条件的提示词",
|
||||||
|
"template": """
|
||||||
|
部门:风险控制部门
|
||||||
|
股票代码:{ts_code}
|
||||||
|
交易日:{trade_date}
|
||||||
|
|
||||||
|
【角色定位】
|
||||||
|
- 防范停牌、涨跌停、仓位与合规限制。
|
||||||
|
- 必要时对高风险决策行使否决权。
|
||||||
|
|
||||||
|
【研究重点】
|
||||||
|
1. 交易限制或异常波动情况。
|
||||||
|
2. 仓位、集中度或风险指标是否触顶。
|
||||||
|
3. 潜在的黑天鹅或执行障碍。
|
||||||
|
|
||||||
|
【数据边界】
|
||||||
|
- 可用字段:
|
||||||
|
{data_scope}
|
||||||
|
- 风险特征:
|
||||||
|
{features}
|
||||||
|
- 市场背景:
|
||||||
|
{market_snapshot}
|
||||||
|
- 追加数据:
|
||||||
|
{supplements}
|
||||||
|
|
||||||
|
请按照【部门基础模板】的分析步骤,必要时明确阻止交易的理由。
|
||||||
|
""",
|
||||||
|
"variables": [
|
||||||
|
"ts_code",
|
||||||
|
"trade_date",
|
||||||
|
"data_scope",
|
||||||
|
"features",
|
||||||
|
"market_snapshot",
|
||||||
|
"supplements",
|
||||||
|
],
|
||||||
|
"required_context": [
|
||||||
|
"ts_code",
|
||||||
|
"trade_date",
|
||||||
|
"features",
|
||||||
|
"market_snapshot",
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"category": "department",
|
||||||
|
"preset": "risk",
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -331,6 +331,7 @@ def render_backtest_review() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
specs: List[ParameterSpec] = []
|
specs: List[ParameterSpec] = []
|
||||||
|
spec_labels: List[str] = []
|
||||||
action_values: List[float] = []
|
action_values: List[float] = []
|
||||||
range_valid = True
|
range_valid = True
|
||||||
for idx, agent_name in enumerate(selected_agents):
|
for idx, agent_name in enumerate(selected_agents):
|
||||||
@ -374,15 +375,111 @@ def render_backtest_review() -> None:
|
|||||||
maximum=max_val,
|
maximum=max_val,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
spec_labels.append(f"agent:{agent_name}")
|
||||||
action_values.append(action_val)
|
action_values.append(action_val)
|
||||||
|
|
||||||
|
controls_valid = True
|
||||||
|
with st.expander("部门 LLM 参数", expanded=False):
|
||||||
|
dept_codes = sorted(app_cfg.departments.keys())
|
||||||
|
if not dept_codes:
|
||||||
|
st.caption("当前未配置部门。")
|
||||||
|
else:
|
||||||
|
selected_departments = st.multiselect(
|
||||||
|
"选择需要调整的部门",
|
||||||
|
dept_codes,
|
||||||
|
default=[],
|
||||||
|
key="decision_env_departments",
|
||||||
|
)
|
||||||
|
tool_policy_values = ["auto", "none", "required"]
|
||||||
|
for dept_code in selected_departments:
|
||||||
|
settings = app_cfg.departments.get(dept_code)
|
||||||
|
if not settings:
|
||||||
|
continue
|
||||||
|
st.subheader(f"部门:{settings.title or dept_code}")
|
||||||
|
base_temp = 0.2
|
||||||
|
if settings.llm and settings.llm.primary and settings.llm.primary.temperature is not None:
|
||||||
|
base_temp = float(settings.llm.primary.temperature)
|
||||||
|
prefix = f"decision_env_dept_{dept_code}"
|
||||||
|
col_tmin, col_tmax, col_tslider = st.columns([1, 1, 2])
|
||||||
|
temp_min = col_tmin.number_input(
|
||||||
|
"温度最小值",
|
||||||
|
min_value=0.0,
|
||||||
|
max_value=2.0,
|
||||||
|
value=max(0.0, base_temp - 0.3),
|
||||||
|
step=0.05,
|
||||||
|
key=f"{prefix}_temp_min",
|
||||||
|
)
|
||||||
|
temp_max = col_tmax.number_input(
|
||||||
|
"温度最大值",
|
||||||
|
min_value=0.0,
|
||||||
|
max_value=2.0,
|
||||||
|
value=min(2.0, base_temp + 0.3),
|
||||||
|
step=0.05,
|
||||||
|
key=f"{prefix}_temp_max",
|
||||||
|
)
|
||||||
|
if temp_max <= temp_min:
|
||||||
|
controls_valid = False
|
||||||
|
st.warning("温度最大值必须大于最小值。")
|
||||||
|
temp_max = min(2.0, temp_min + 0.01)
|
||||||
|
span = temp_max - temp_min
|
||||||
|
if span <= 0:
|
||||||
|
ratio_default = 0.0
|
||||||
|
else:
|
||||||
|
clamped = min(max(base_temp, temp_min), temp_max)
|
||||||
|
ratio_default = (clamped - temp_min) / span
|
||||||
|
temp_action = col_tslider.slider(
|
||||||
|
"动作值(映射至温度区间)",
|
||||||
|
min_value=0.0,
|
||||||
|
max_value=1.0,
|
||||||
|
value=float(ratio_default),
|
||||||
|
step=0.01,
|
||||||
|
key=f"{prefix}_temp_action",
|
||||||
|
)
|
||||||
|
specs.append(
|
||||||
|
ParameterSpec(
|
||||||
|
name=f"dept_temperature_{dept_code}",
|
||||||
|
target=f"department.{dept_code}.temperature",
|
||||||
|
minimum=temp_min,
|
||||||
|
maximum=temp_max,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
spec_labels.append(f"department:{dept_code}:temperature")
|
||||||
|
action_values.append(temp_action)
|
||||||
|
|
||||||
|
col_tool, col_hint = st.columns([1, 2])
|
||||||
|
tool_choice = col_tool.selectbox(
|
||||||
|
"函数调用策略",
|
||||||
|
tool_policy_values,
|
||||||
|
index=tool_policy_values.index("auto"),
|
||||||
|
key=f"{prefix}_tool_choice",
|
||||||
|
)
|
||||||
|
col_hint.caption("映射提示:0→auto,0.5→none,1→required。")
|
||||||
|
if len(tool_policy_values) > 1:
|
||||||
|
tool_value = tool_policy_values.index(tool_choice) / (len(tool_policy_values) - 1)
|
||||||
|
else:
|
||||||
|
tool_value = 0.0
|
||||||
|
specs.append(
|
||||||
|
ParameterSpec(
|
||||||
|
name=f"dept_tool_{dept_code}",
|
||||||
|
target=f"department.{dept_code}.function_policy",
|
||||||
|
values=tool_policy_values,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
spec_labels.append(f"department:{dept_code}:tool_choice")
|
||||||
|
action_values.append(tool_value)
|
||||||
|
|
||||||
|
if specs:
|
||||||
|
st.caption("动作维度顺序:" + ",".join(spec_labels))
|
||||||
|
|
||||||
run_decision_env = st.button("执行单次调参", key="run_decision_env_button")
|
run_decision_env = st.button("执行单次调参", key="run_decision_env_button")
|
||||||
just_finished_single = False
|
just_finished_single = False
|
||||||
if run_decision_env:
|
if run_decision_env:
|
||||||
if not selected_agents:
|
if not specs:
|
||||||
st.warning("请至少选择一个代理进行调参。")
|
st.warning("请至少配置一个动作维度(代理或部门参数)。")
|
||||||
elif not range_valid:
|
elif selected_agents and not range_valid:
|
||||||
st.error("请确保所有代理的最大权重大于最小权重。")
|
st.error("请确保所有代理的最大权重大于最小权重。")
|
||||||
|
elif not controls_valid:
|
||||||
|
st.error("请修正部门参数的取值范围。")
|
||||||
else:
|
else:
|
||||||
LOGGER.info(
|
LOGGER.info(
|
||||||
"离线调参(单次)按钮点击,已选择代理=%s 动作=%s disable_departments=%s",
|
"离线调参(单次)按钮点击,已选择代理=%s 动作=%s disable_departments=%s",
|
||||||
@ -448,11 +545,11 @@ def render_backtest_review() -> None:
|
|||||||
resolved_experiment_id = experiment_id or str(uuid.uuid4())
|
resolved_experiment_id = experiment_id or str(uuid.uuid4())
|
||||||
resolved_strategy = strategy_label or "DecisionEnv"
|
resolved_strategy = strategy_label or "DecisionEnv"
|
||||||
action_payload = {
|
action_payload = {
|
||||||
name: value
|
label: value for label, value in zip(spec_labels, action_values)
|
||||||
for name, value in zip(selected_agents, action_values)
|
|
||||||
}
|
}
|
||||||
metrics_payload = dict(observation)
|
metrics_payload = dict(observation)
|
||||||
metrics_payload["reward"] = reward
|
metrics_payload["reward"] = reward
|
||||||
|
metrics_payload["department_controls"] = info.get("department_controls")
|
||||||
log_success = False
|
log_success = False
|
||||||
try:
|
try:
|
||||||
log_tuning_result(
|
log_tuning_result(
|
||||||
@ -477,12 +574,14 @@ def render_backtest_review() -> None:
|
|||||||
"observation": dict(observation),
|
"observation": dict(observation),
|
||||||
"reward": float(reward),
|
"reward": float(reward),
|
||||||
"weights": info.get("weights", {}),
|
"weights": info.get("weights", {}),
|
||||||
|
"department_controls": info.get("department_controls"),
|
||||||
|
"actions": action_payload,
|
||||||
"nav_series": info.get("nav_series"),
|
"nav_series": info.get("nav_series"),
|
||||||
"trades": info.get("trades"),
|
"trades": info.get("trades"),
|
||||||
"portfolio_snapshots": info.get("portfolio_snapshots"),
|
"portfolio_snapshots": info.get("portfolio_snapshots"),
|
||||||
"portfolio_trades": info.get("portfolio_trades"),
|
"portfolio_trades": info.get("portfolio_trades"),
|
||||||
"risk_breakdown": info.get("risk_breakdown"),
|
"risk_breakdown": info.get("risk_breakdown"),
|
||||||
"selected_agents": list(selected_agents),
|
"spec_labels": list(spec_labels),
|
||||||
"action_values": list(action_values),
|
"action_values": list(action_values),
|
||||||
"experiment_id": resolved_experiment_id,
|
"experiment_id": resolved_experiment_id,
|
||||||
"strategy_label": resolved_strategy,
|
"strategy_label": resolved_strategy,
|
||||||
@ -562,6 +661,16 @@ def render_backtest_review() -> None:
|
|||||||
with st.expander("风险事件统计", expanded=False):
|
with st.expander("风险事件统计", expanded=False):
|
||||||
st.json(risk_breakdown)
|
st.json(risk_breakdown)
|
||||||
|
|
||||||
|
department_info = single_result.get("department_controls") or {}
|
||||||
|
if department_info:
|
||||||
|
with st.expander("部门控制参数", expanded=False):
|
||||||
|
st.json(department_info)
|
||||||
|
|
||||||
|
action_snapshot = single_result.get("actions") or {}
|
||||||
|
if action_snapshot:
|
||||||
|
with st.expander("动作明细", expanded=False):
|
||||||
|
st.json(action_snapshot)
|
||||||
|
|
||||||
if st.button("清除单次调参结果", key="clear_decision_env_single"):
|
if st.button("清除单次调参结果", key="clear_decision_env_single"):
|
||||||
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
|
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
|
||||||
st.success("已清除单次调参结果缓存。")
|
st.success("已清除单次调参结果缓存。")
|
||||||
@ -584,10 +693,12 @@ def render_backtest_review() -> None:
|
|||||||
run_batch = st.button("批量执行调参", key="run_decision_env_batch")
|
run_batch = st.button("批量执行调参", key="run_decision_env_batch")
|
||||||
batch_just_ran = False
|
batch_just_ran = False
|
||||||
if run_batch:
|
if run_batch:
|
||||||
if not selected_agents:
|
if not specs:
|
||||||
st.warning("请先选择调参代理。")
|
st.warning("请至少配置一个动作维度。")
|
||||||
elif not range_valid:
|
elif selected_agents and not range_valid:
|
||||||
st.error("请确保所有代理的最大权重大于最小权重。")
|
st.error("请确保所有代理的最大权重大于最小权重。")
|
||||||
|
elif not controls_valid:
|
||||||
|
st.error("请修正部门参数的取值范围。")
|
||||||
else:
|
else:
|
||||||
LOGGER.info(
|
LOGGER.info(
|
||||||
"离线调参(批量)按钮点击,已选择代理=%s disable_departments=%s",
|
"离线调参(批量)按钮点击,已选择代理=%s disable_departments=%s",
|
||||||
@ -693,11 +804,12 @@ def render_backtest_review() -> None:
|
|||||||
extra=LOG_EXTRA,
|
extra=LOG_EXTRA,
|
||||||
)
|
)
|
||||||
action_payload = {
|
action_payload = {
|
||||||
name: value
|
label: value
|
||||||
for name, value in zip(selected_agents, action_vals)
|
for label, value in zip(spec_labels, action_vals)
|
||||||
}
|
}
|
||||||
metrics_payload = dict(observation)
|
metrics_payload = dict(observation)
|
||||||
metrics_payload["reward"] = reward
|
metrics_payload["reward"] = reward
|
||||||
|
metrics_payload["department_controls"] = info.get("department_controls")
|
||||||
weights_payload = info.get("weights", {})
|
weights_payload = info.get("weights", {})
|
||||||
try:
|
try:
|
||||||
log_tuning_result(
|
log_tuning_result(
|
||||||
@ -713,13 +825,14 @@ def render_backtest_review() -> None:
|
|||||||
results.append(
|
results.append(
|
||||||
{
|
{
|
||||||
"序号": idx,
|
"序号": idx,
|
||||||
"动作": action_vals,
|
"动作": action_payload,
|
||||||
"状态": "ok",
|
"状态": "ok",
|
||||||
"总收益": observation.get("total_return", 0.0),
|
"总收益": observation.get("total_return", 0.0),
|
||||||
"最大回撤": observation.get("max_drawdown", 0.0),
|
"最大回撤": observation.get("max_drawdown", 0.0),
|
||||||
"波动率": observation.get("volatility", 0.0),
|
"波动率": observation.get("volatility", 0.0),
|
||||||
"奖励": reward,
|
"奖励": reward,
|
||||||
"权重": weights_payload,
|
"权重": weights_payload,
|
||||||
|
"部门控制": info.get("department_controls"),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
st.session_state[_DECISION_ENV_BATCH_RESULTS_KEY] = {
|
st.session_state[_DECISION_ENV_BATCH_RESULTS_KEY] = {
|
||||||
|
|||||||
@ -18,7 +18,7 @@
|
|||||||
- 围绕动量、估值、流动性等核心信号扩展轻量高质量因子集,全部由程序生成,满足端到端自动决策需求。
|
- 围绕动量、估值、流动性等核心信号扩展轻量高质量因子集,全部由程序生成,满足端到端自动决策需求。
|
||||||
|
|
||||||
## 3. 决策优化与强化学习
|
## 3. 决策优化与强化学习
|
||||||
- 扩展 `DecisionEnv` 的动作空间(提示版本、部门温度、function 调用策略等),不仅限于代理权重调节。
|
- ✅ 扩展 `DecisionEnv` 的动作空间(提示版本、部门温度、function 调用策略等),不仅限于代理权重调节。
|
||||||
- 引入 Bandit / 贝叶斯优化或 RL 算法探索动作空间,并将 `portfolio_snapshots`、`portfolio_trades` 指标纳入奖励约束。
|
- 引入 Bandit / 贝叶斯优化或 RL 算法探索动作空间,并将 `portfolio_snapshots`、`portfolio_trades` 指标纳入奖励约束。
|
||||||
- 构建实时持仓/成交数据写入链路,使线上监控与离线调参共用同一数据源。
|
- 构建实时持仓/成交数据写入链路,使线上监控与离线调参共用同一数据源。
|
||||||
- 借鉴 TradingAgents-CN 的做法:拆分环境与策略、提供训练脚本/配置,并输出丰富的评估指标(如 Sharpe、Sortino、基准对比)。
|
- 借鉴 TradingAgents-CN 的做法:拆分环境与策略、提供训练脚本/配置,并输出丰富的评估指标(如 Sharpe、Sortino、基准对比)。
|
||||||
|
|||||||
@ -7,13 +7,58 @@ import pytest
|
|||||||
|
|
||||||
from app.backtest.decision_env import DecisionEnv, EpisodeMetrics, ParameterSpec
|
from app.backtest.decision_env import DecisionEnv, EpisodeMetrics, ParameterSpec
|
||||||
from app.backtest.engine import BacktestResult, BtConfig
|
from app.backtest.engine import BacktestResult, BtConfig
|
||||||
|
from app.utils.config import DepartmentSettings, LLMConfig, LLMEndpoint
|
||||||
|
|
||||||
|
|
||||||
|
class _StubDepartmentAgent:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._tool_choice = "auto"
|
||||||
|
self._max_rounds = 3
|
||||||
|
endpoint = LLMEndpoint(provider="openai", model="mock", temperature=0.2)
|
||||||
|
self.settings = DepartmentSettings(
|
||||||
|
code="momentum",
|
||||||
|
title="Momentum",
|
||||||
|
description="baseline",
|
||||||
|
prompt="baseline",
|
||||||
|
llm=LLMConfig(primary=endpoint),
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tool_choice(self) -> str:
|
||||||
|
return self._tool_choice
|
||||||
|
|
||||||
|
@tool_choice.setter
|
||||||
|
def tool_choice(self, value) -> None:
|
||||||
|
normalized = str(value).strip().lower()
|
||||||
|
if normalized not in {"auto", "none", "required"}:
|
||||||
|
raise ValueError("invalid tool choice")
|
||||||
|
self._tool_choice = normalized
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_rounds(self) -> int:
|
||||||
|
return self._max_rounds
|
||||||
|
|
||||||
|
@max_rounds.setter
|
||||||
|
def max_rounds(self, value) -> None:
|
||||||
|
numeric = int(round(float(value)))
|
||||||
|
if numeric < 1:
|
||||||
|
numeric = 1
|
||||||
|
if numeric > 6:
|
||||||
|
numeric = 6
|
||||||
|
self._max_rounds = numeric
|
||||||
|
|
||||||
|
|
||||||
|
class _StubManager:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.agents = {"momentum": _StubDepartmentAgent()}
|
||||||
|
|
||||||
|
|
||||||
class _StubEngine:
|
class _StubEngine:
|
||||||
def __init__(self, cfg: BtConfig) -> None: # noqa: D401
|
def __init__(self, cfg: BtConfig) -> None: # noqa: D401
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.weights = {}
|
self.weights = {}
|
||||||
self.department_manager = None
|
self.department_manager = _StubManager()
|
||||||
|
_StubEngine.last_instance = self
|
||||||
|
|
||||||
def run(self) -> BacktestResult:
|
def run(self) -> BacktestResult:
|
||||||
result = BacktestResult()
|
result = BacktestResult()
|
||||||
@ -53,6 +98,9 @@ class _StubEngine:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
_StubEngine.last_instance: _StubEngine | None = None
|
||||||
|
|
||||||
|
|
||||||
def test_decision_env_returns_risk_metrics(monkeypatch):
|
def test_decision_env_returns_risk_metrics(monkeypatch):
|
||||||
cfg = BtConfig(
|
cfg = BtConfig(
|
||||||
id="stub",
|
id="stub",
|
||||||
@ -96,3 +144,68 @@ def test_default_reward_penalizes_metrics():
|
|||||||
)
|
)
|
||||||
reward = DecisionEnv._default_reward(metrics)
|
reward = DecisionEnv._default_reward(metrics)
|
||||||
assert reward == pytest.approx(0.1 - (0.5 * 0.2 + 0.05 * 2 + 0.1 * 0.3))
|
assert reward == pytest.approx(0.1 - (0.5 * 0.2 + 0.05 * 2 + 0.1 * 0.3))
|
||||||
|
|
||||||
|
|
||||||
|
def test_decision_env_department_controls(monkeypatch):
|
||||||
|
cfg = BtConfig(
|
||||||
|
id="stub",
|
||||||
|
name="stub",
|
||||||
|
start_date=date(2025, 1, 10),
|
||||||
|
end_date=date(2025, 1, 10),
|
||||||
|
universe=["000001.SZ"],
|
||||||
|
params={},
|
||||||
|
)
|
||||||
|
specs = [
|
||||||
|
ParameterSpec(name="w_mom", target="agent_weights.A_mom", minimum=0.0, maximum=1.0),
|
||||||
|
ParameterSpec(
|
||||||
|
name="dept_prompt",
|
||||||
|
target="department.momentum.prompt",
|
||||||
|
values=["baseline", "aggressive"],
|
||||||
|
),
|
||||||
|
ParameterSpec(
|
||||||
|
name="dept_temp",
|
||||||
|
target="department.momentum.temperature",
|
||||||
|
minimum=0.1,
|
||||||
|
maximum=0.9,
|
||||||
|
),
|
||||||
|
ParameterSpec(
|
||||||
|
name="dept_tool",
|
||||||
|
target="department.momentum.function_policy",
|
||||||
|
values=["none", "auto", "required"],
|
||||||
|
),
|
||||||
|
ParameterSpec(
|
||||||
|
name="dept_rounds",
|
||||||
|
target="department.momentum.max_rounds",
|
||||||
|
minimum=1,
|
||||||
|
maximum=5,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
env = DecisionEnv(bt_config=cfg, parameter_specs=specs, baseline_weights={"A_mom": 0.5})
|
||||||
|
|
||||||
|
monkeypatch.setattr("app.backtest.decision_env.BacktestEngine", _StubEngine)
|
||||||
|
monkeypatch.setattr(DecisionEnv, "_clear_portfolio_records", lambda self: None)
|
||||||
|
monkeypatch.setattr(DecisionEnv, "_fetch_portfolio_records", lambda self: ([], []))
|
||||||
|
|
||||||
|
obs, reward, done, info = env.step([0.3, 1.0, 0.75, 0.0, 1.0])
|
||||||
|
|
||||||
|
assert done is True
|
||||||
|
assert obs["total_return"] == pytest.approx(0.0)
|
||||||
|
|
||||||
|
controls = info["department_controls"]
|
||||||
|
assert "momentum" in controls
|
||||||
|
momentum_ctrl = controls["momentum"]
|
||||||
|
assert momentum_ctrl["prompt"] == "aggressive"
|
||||||
|
assert momentum_ctrl["temperature"] == pytest.approx(0.7, abs=1e-6)
|
||||||
|
assert momentum_ctrl["tool_choice"] == "none"
|
||||||
|
assert momentum_ctrl["max_rounds"] == 5
|
||||||
|
|
||||||
|
assert env.last_department_controls == controls
|
||||||
|
|
||||||
|
engine = _StubEngine.last_instance
|
||||||
|
assert engine is not None
|
||||||
|
agent = engine.department_manager.agents["momentum"]
|
||||||
|
assert agent.settings.prompt == "aggressive"
|
||||||
|
assert agent.settings.llm.primary.temperature == pytest.approx(0.7, abs=1e-6)
|
||||||
|
assert agent.tool_choice == "none"
|
||||||
|
assert agent.max_rounds == 5
|
||||||
|
|||||||
@ -169,6 +169,12 @@ def test_default_templates():
|
|||||||
assert momentum is not None
|
assert momentum is not None
|
||||||
assert "动量研究部门" in momentum.name
|
assert "动量研究部门" in momentum.name
|
||||||
|
|
||||||
|
assert TemplateRegistry.get("value_dept") is not None
|
||||||
|
assert TemplateRegistry.get("news_dept") is not None
|
||||||
|
assert TemplateRegistry.get("liquidity_dept") is not None
|
||||||
|
assert TemplateRegistry.get("macro_dept") is not None
|
||||||
|
assert TemplateRegistry.get("risk_dept") is not None
|
||||||
|
|
||||||
# Validate template content
|
# Validate template content
|
||||||
assert all("{" + var + "}" in dept_base.template for var in dept_base.variables)
|
assert all("{" + var + "}" in dept_base.template for var in dept_base.variables)
|
||||||
assert all("{" + var + "}" in momentum.template for var in momentum.variables)
|
assert all("{" + var + "}" in momentum.template for var in momentum.variables)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user