This commit is contained in:
sam 2025-10-06 13:21:43 +08:00
parent 3351ca6b5a
commit d0a0340db6
7 changed files with 758 additions and 79 deletions

View File

@ -93,6 +93,38 @@ class DepartmentAgent:
self._resolver = resolver
self._broker = DataBroker()
self._max_rounds = 3
self._tool_choice = "auto"
@property
def max_rounds(self) -> int:
return self._max_rounds
@max_rounds.setter
def max_rounds(self, value: Any) -> None:
try:
numeric = int(round(float(value)))
except (TypeError, ValueError):
raise ValueError("max_rounds must be numeric") from None
if numeric < 1:
numeric = 1
if numeric > 6:
numeric = 6
self._max_rounds = numeric
@property
def tool_choice(self) -> str:
return self._tool_choice
@tool_choice.setter
def tool_choice(self, value: Any) -> None:
if value is None:
self._tool_choice = "auto"
return
normalized = str(value).strip().lower()
allowed = {"auto", "none", "required"}
if normalized not in allowed:
raise ValueError(f"Unsupported tool choice: {value}")
self._tool_choice = normalized
def _get_llm_config(self) -> LLMConfig:
if self._resolver:
@ -159,7 +191,7 @@ class DepartmentAgent:
primary_endpoint,
messages,
tools=tools,
tool_choice="auto",
tool_choice=self._tool_choice,
)
except LLMError as exc:
LOGGER.warning(

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import json
import math
import copy
from dataclasses import dataclass, replace
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
@ -18,17 +19,27 @@ LOG_EXTRA = {"stage": "decision_env"}
@dataclass(frozen=True)
class ParameterSpec:
"""Defines how a scalar action dimension maps to strategy parameters."""
"""Defines how an action dimension maps to strategy parameters or behaviors."""
name: str
target: str
minimum: float = 0.0
maximum: float = 1.0
values: Optional[Sequence[Any]] = None
def clamp(self, value: float) -> float:
clipped = max(0.0, min(1.0, float(value)))
return self.minimum + clipped * (self.maximum - self.minimum)
def resolve(self, value: float) -> Any:
if self.values is not None:
if not self.values:
raise ValueError(f"ParameterSpec {self.name} configured with empty values list")
clipped = max(0.0, min(1.0, float(value)))
index = int(round(clipped * (len(self.values) - 1)))
return self.values[index]
return self.clamp(value)
@dataclass
class EpisodeMetrics:
@ -68,6 +79,7 @@ class DecisionEnv:
self._reward_fn = reward_fn or self._default_reward
self._last_metrics: Optional[EpisodeMetrics] = None
self._last_action: Optional[Tuple[float, ...]] = None
self._last_department_controls: Optional[Dict[str, Dict[str, Any]]] = None
self._episode = 0
self._disable_departments = bool(disable_departments)
@ -75,10 +87,15 @@ class DecisionEnv:
def action_dim(self) -> int:
return len(self._specs)
@property
def last_department_controls(self) -> Optional[Dict[str, Dict[str, Any]]]:
return self._last_department_controls
def reset(self) -> Dict[str, float]:
self._episode += 1
self._last_metrics = None
self._last_action = None
self._last_department_controls = None
return {
"episode": float(self._episode),
"baseline_return": 0.0,
@ -90,14 +107,24 @@ class DecisionEnv:
action_array = [float(val) for val in action]
self._last_action = tuple(action_array)
weights = self._build_weights(action_array)
LOGGER.info("episode=%s action=%s weights=%s", self._episode, action_array, weights, extra=LOG_EXTRA)
weights, department_controls = self._prepare_actions(action_array)
LOGGER.info(
"episode=%s action=%s weights=%s controls=%s",
self._episode,
action_array,
weights,
department_controls,
extra=LOG_EXTRA,
)
cfg = replace(self._template_cfg)
engine = BacktestEngine(cfg)
engine.weights = weight_map(weights)
if self._disable_departments:
engine.department_manager = None
applied_controls: Dict[str, Dict[str, Any]] = {}
else:
applied_controls = self._apply_department_controls(engine, department_controls)
self._clear_portfolio_records()
@ -135,19 +162,153 @@ class DecisionEnv:
"risk_events": getattr(result, "risk_events", []),
"portfolio_snapshots": snapshots,
"portfolio_trades": trades_override,
"department_controls": applied_controls,
}
self._last_department_controls = applied_controls
return observation, reward, True, info
def _build_weights(self, action: Sequence[float]) -> Dict[str, float]:
def _prepare_actions(
self,
action: Sequence[float],
) -> Tuple[Dict[str, float], Dict[str, Dict[str, Any]]]:
weights = dict(self._baseline_weights)
department_controls: Dict[str, Dict[str, Any]] = {}
for idx, spec in enumerate(self._specs):
value = spec.clamp(action[idx])
try:
resolved = spec.resolve(action[idx])
except ValueError as exc:
LOGGER.warning("参数 %s 解析失败:%s", spec.name, exc, extra=LOG_EXTRA)
continue
if spec.target.startswith("agent_weights."):
agent_name = spec.target.split(".", 1)[1]
weights[agent_name] = value
try:
weights[agent_name] = float(resolved)
except (TypeError, ValueError):
LOGGER.debug(
"spec %s produced non-numeric weight %s; skipping",
spec.name,
resolved,
extra=LOG_EXTRA,
)
continue
if spec.target.startswith("department."):
target_path = spec.target.split(".")[1:]
if len(target_path) < 2:
LOGGER.debug("未识别的部门目标:%s", spec.target, extra=LOG_EXTRA)
continue
dept_code = target_path[0]
field = ".".join(target_path[1:])
dept_controls = department_controls.setdefault(dept_code, {})
dept_controls[field] = resolved
continue
else:
LOGGER.debug("暂未支持的参数目标:%s", spec.target, extra=LOG_EXTRA)
return weights
return weights, department_controls
def _apply_department_controls(
self,
engine: BacktestEngine,
controls: Mapping[str, Mapping[str, Any]],
) -> Dict[str, Dict[str, Any]]:
manager = getattr(engine, "department_manager", None)
if not manager or not getattr(manager, "agents", None):
return {}
applied: Dict[str, Dict[str, Any]] = {}
for dept_code, payload in controls.items():
agent = manager.agents.get(dept_code)
if not agent or not isinstance(payload, Mapping):
continue
applied_fields: Dict[str, Any] = {}
# Ensure mutable settings clone to avoid global side-effects
try:
original_settings = agent.settings
cloned_settings = replace(original_settings)
cloned_settings.llm = copy.deepcopy(original_settings.llm)
agent.settings = cloned_settings
except Exception as exc: # noqa: BLE001
LOGGER.warning(
"复制部门 %s 配置失败:%s",
dept_code,
exc,
extra=LOG_EXTRA,
)
continue
for raw_field, value in payload.items():
field = raw_field.lower()
if field == "function_policy":
field = "tool_choice"
if field in {"prompt", "instruction"}:
agent.settings.prompt = str(value)
applied_fields[field] = agent.settings.prompt
continue
if field == "description":
agent.settings.description = str(value)
applied_fields[field] = agent.settings.description
continue
if field in {"prompt_template_id", "prompt_template"}:
agent.settings.prompt_template_id = str(value)
applied_fields["prompt_template_id"] = agent.settings.prompt_template_id
continue
if field == "prompt_template_version":
agent.settings.prompt_template_version = str(value)
applied_fields["prompt_template_version"] = agent.settings.prompt_template_version
continue
if field in {"temperature", "llm.temperature"}:
try:
temperature = max(0.0, min(2.0, float(value)))
agent.settings.llm.primary.temperature = temperature
applied_fields["temperature"] = temperature
except (TypeError, ValueError):
LOGGER.debug(
"无效的温度值 %s for %s",
value,
dept_code,
extra=LOG_EXTRA,
)
continue
if field in {"tool_choice", "tool_strategy"}:
try:
agent.tool_choice = value
applied_fields["tool_choice"] = agent.tool_choice
except ValueError:
LOGGER.debug(
"部门 %s 工具策略 %s 无效",
dept_code,
value,
extra=LOG_EXTRA,
)
continue
if field == "max_rounds":
try:
agent.max_rounds = value
applied_fields["max_rounds"] = agent.max_rounds
except ValueError:
LOGGER.debug(
"部门 %s max_rounds %s 无效",
dept_code,
value,
extra=LOG_EXTRA,
)
continue
if field == "prompt_template_override":
agent.settings.prompt = str(value)
applied_fields["prompt"] = agent.settings.prompt
continue
LOGGER.debug(
"部门 %s 未识别的控制项 %s",
dept_code,
raw_field,
extra=LOG_EXTRA,
)
if applied_fields:
applied[dept_code] = applied_fields
return applied
def _compute_metrics(
self,

View File

@ -252,98 +252,352 @@ class TemplateRegistry:
DEFAULT_TEMPLATES = {
"department_base": {
"name": "部门基础模板",
"description": "通用的部门分析提示模板",
"description": "所有部门通用的审慎分析提示词骨架",
"template": """
部门名称{title}
部门{title}
股票代码{ts_code}
交易日{trade_date}
角色说明{description}
职责指令{instruction}
角色定位
- 角色说明{description}
- 行动守则{instruction}
可用数据范围
数据边界
- 可用字段
{data_scope}
核心特征
- 核心特征
{features}
市场背景
- 市场背景
{market_snapshot}
追加数据
- 追加数据
{supplements}
请基于以上数据给出该部门对当前股票的操作建议输出必须是 JSON字段如下
分析步骤
1. 判断信息是否充分如不充分请说明缺口并优先调用工具 `fetch_data`仅限 `daily``daily_basic`
2. 梳理 2-3 个关键支撑信号与潜在风险确保基于提供的数据
3. 结合量化证据与限制条件给出操作建议和信心来源避免主观臆测
输出要求
仅返回一个 JSON 对象不要添加额外文本
{{
"action": "BUY|BUY_S|BUY_M|BUY_L|SELL|HOLD",
"confidence": 0-1 之间的小数表示信心,
"summary": "一句话概括理由",
"signals": ["详细要点", "..."],
"risks": ["风险点", "..."]
"confidence": 0-1 之间的小数
"summary": "一句话结论",
"signals": ["关键支撑要点", "..."],
"risks": ["关键风险", "..."]
}}
如需额外数据请调用工具 `fetch_data`仅支持请求 `daily` `daily_basic`
请严格返回单个 JSON 对象不要添加额外文本
如需说明未完成的数据请求请在 `risks` `signals` 中明确
""",
"variables": [
"title", "ts_code", "trade_date", "description", "instruction",
"data_scope", "features", "market_snapshot", "supplements"
"title",
"ts_code",
"trade_date",
"description",
"instruction",
"data_scope",
"features",
"market_snapshot",
"supplements",
],
"required_context": [
"ts_code", "trade_date", "features", "market_snapshot"
"ts_code",
"trade_date",
"features",
"market_snapshot",
],
"validation_rules": [
"len(features) > 0",
"len(market_snapshot) > 0"
]
"metadata": {
"category": "department",
"preset": "base",
},
},
"momentum_dept": {
"name": "动量研究部门",
"description": "专注于动量因子分析的部门模板",
"name": "动量研究部门模板",
"description": "围绕价格与量能动量的决策提示",
"template": """
部门名称动量研究部门
部门动量研究部门
股票代码{ts_code}
交易日{trade_date}
角色说明专注于分析股票价格动量成交量动量和技术指标动量
职责指令重点关注以下方面:
1. 价格趋势强度和持续性
2. 成交量配合度
3. 技术指标背离
角色定位
- 专注价格动量成交量共振与技术指标背离
- 保持纪律识别趋势延续与反转风险
可用数据范围
研究重点
1. 多时间窗口动量是否同向
2. 成交量是否验证价格走势
3. 是否出现过热或背离信号
数据边界
- 可用字段
{data_scope}
动量特征
- 动量特征
{features}
市场背景
- 市场背景
{market_snapshot}
追加数据
- 追加数据
{supplements}
请基于以上数据进行动量分析并给出操作建议输出必须是 JSON字段如下
{{
"action": "BUY|BUY_S|BUY_M|BUY_L|SELL|HOLD",
"confidence": 0-1 之间的小数表示信心,
"summary": "一句话概括动量分析结论",
"signals": ["动量信号要点", "..."],
"risks": ["动量风险点", "..."]
}}
请沿用部门基础模板的分析步骤与输出要求重点量化趋势动能和量价配合度
""",
"variables": [
"ts_code", "trade_date", "data_scope",
"features", "market_snapshot", "supplements"
"ts_code",
"trade_date",
"data_scope",
"features",
"market_snapshot",
"supplements",
],
"required_context": [
"ts_code", "trade_date", "features", "market_snapshot"
"ts_code",
"trade_date",
"features",
"market_snapshot",
],
"validation_rules": [
"len(features) > 0",
"'momentum' in ' '.join(features.keys()).lower()"
]
}
"metadata": {
"category": "department",
"preset": "momentum",
},
},
"value_dept": {
"name": "价值评估部门模板",
"description": "衡量估值与盈利质量的提示词",
"template": """
部门价值评估部门
股票代码{ts_code}
交易日{trade_date}
角色定位
- 关注估值分位盈利质量与安全边际
- 从中期配置角度评价当前价格的性价比
研究重点
1. 历史及同业视角的估值位置
2. 盈利与分红的可持续性
3. 潜在的估值修复催化或压制因素
数据边界
- 可用字段
{data_scope}
- 估值与质量特征
{features}
- 市场背景
{market_snapshot}
- 追加数据
{supplements}
请按照部门基础模板的分析步骤输出结论并明确估值安全边际来源
""",
"variables": [
"ts_code",
"trade_date",
"data_scope",
"features",
"market_snapshot",
"supplements",
],
"required_context": [
"ts_code",
"trade_date",
"features",
"market_snapshot",
],
"metadata": {
"category": "department",
"preset": "value",
},
},
"news_dept": {
"name": "新闻情绪部门模板",
"description": "针对舆情热度与事件影响的提示词",
"template": """
部门新闻情绪部门
股票代码{ts_code}
交易日{trade_date}
角色定位
- 监控舆情热度事件驱动与短期情绪
- 评估新闻对价格波动的正负面影响
研究重点
1. 新闻情绪是否集中且持续
2. 主题与行情是否匹配
3. 情绪驱动的风险敞口
数据边界
- 可用字段
{data_scope}
- 舆情特征
{features}
- 市场背景
{market_snapshot}
- 追加数据
{supplements}
请遵循部门基础模板的分析步骤突出情绪驱动的力度与时效性
""",
"variables": [
"ts_code",
"trade_date",
"data_scope",
"features",
"market_snapshot",
"supplements",
],
"required_context": [
"ts_code",
"trade_date",
"features",
"market_snapshot",
],
"metadata": {
"category": "department",
"preset": "news",
},
},
"liquidity_dept": {
"name": "流动性评估部门模板",
"description": "衡量成交活跃度与执行成本的提示词",
"template": """
部门流动性评估部门
股票代码{ts_code}
交易日{trade_date}
角色定位
- 评估成交活跃度交易成本与可执行性
- 提醒潜在的流动性风险与仓位限制
研究重点
1. 当前成交量与历史均值的对比
2. 价量限制涨跌停停牌等对执行的影响
3. 预估滑点与转手难度
数据边界
- 可用字段
{data_scope}
- 流动性特征
{features}
- 市场背景
{market_snapshot}
- 追加数据
{supplements}
请遵循部门基础模板的分析步骤重点描述执行可行性与仓位建议
""",
"variables": [
"ts_code",
"trade_date",
"data_scope",
"features",
"market_snapshot",
"supplements",
],
"required_context": [
"ts_code",
"trade_date",
"features",
"market_snapshot",
],
"metadata": {
"category": "department",
"preset": "liquidity",
},
},
"macro_dept": {
"name": "宏观研究部门模板",
"description": "宏观与行业景气度分析提示词",
"template": """
部门宏观研究部门
股票代码{ts_code}
交易日{trade_date}
角色定位
- 追踪宏观周期行业景气与相对强弱
- 评估宏观事件对该标的的方向性影响
研究重点
1. 行业相对大盘的表现与热点程度
2. 宏观/政策事件对行业或标的的指引
3. 需警惕的宏观风险与流动性环境
数据边界
- 可用字段
{data_scope}
- 宏观特征
{features}
- 市场背景
{market_snapshot}
- 追加数据
{supplements}
请执行部门基础模板的分析步骤并输出宏观驱动的信号与风险
""",
"variables": [
"ts_code",
"trade_date",
"data_scope",
"features",
"market_snapshot",
"supplements",
],
"required_context": [
"ts_code",
"trade_date",
"features",
"market_snapshot",
],
"metadata": {
"category": "department",
"preset": "macro",
},
},
"risk_dept": {
"name": "风险控制部门模板",
"description": "识别极端风险与限制条件的提示词",
"template": """
部门风险控制部门
股票代码{ts_code}
交易日{trade_date}
角色定位
- 防范停牌涨跌停仓位与合规限制
- 必要时对高风险决策行使否决权
研究重点
1. 交易限制或异常波动情况
2. 仓位集中度或风险指标是否触顶
3. 潜在的黑天鹅或执行障碍
数据边界
- 可用字段
{data_scope}
- 风险特征
{features}
- 市场背景
{market_snapshot}
- 追加数据
{supplements}
请按照部门基础模板的分析步骤必要时明确阻止交易的理由
""",
"variables": [
"ts_code",
"trade_date",
"data_scope",
"features",
"market_snapshot",
"supplements",
],
"required_context": [
"ts_code",
"trade_date",
"features",
"market_snapshot",
],
"metadata": {
"category": "department",
"preset": "risk",
},
},
}

View File

@ -331,6 +331,7 @@ def render_backtest_review() -> None:
)
specs: List[ParameterSpec] = []
spec_labels: List[str] = []
action_values: List[float] = []
range_valid = True
for idx, agent_name in enumerate(selected_agents):
@ -374,15 +375,111 @@ def render_backtest_review() -> None:
maximum=max_val,
)
)
spec_labels.append(f"agent:{agent_name}")
action_values.append(action_val)
controls_valid = True
with st.expander("部门 LLM 参数", expanded=False):
dept_codes = sorted(app_cfg.departments.keys())
if not dept_codes:
st.caption("当前未配置部门。")
else:
selected_departments = st.multiselect(
"选择需要调整的部门",
dept_codes,
default=[],
key="decision_env_departments",
)
tool_policy_values = ["auto", "none", "required"]
for dept_code in selected_departments:
settings = app_cfg.departments.get(dept_code)
if not settings:
continue
st.subheader(f"部门:{settings.title or dept_code}")
base_temp = 0.2
if settings.llm and settings.llm.primary and settings.llm.primary.temperature is not None:
base_temp = float(settings.llm.primary.temperature)
prefix = f"decision_env_dept_{dept_code}"
col_tmin, col_tmax, col_tslider = st.columns([1, 1, 2])
temp_min = col_tmin.number_input(
"温度最小值",
min_value=0.0,
max_value=2.0,
value=max(0.0, base_temp - 0.3),
step=0.05,
key=f"{prefix}_temp_min",
)
temp_max = col_tmax.number_input(
"温度最大值",
min_value=0.0,
max_value=2.0,
value=min(2.0, base_temp + 0.3),
step=0.05,
key=f"{prefix}_temp_max",
)
if temp_max <= temp_min:
controls_valid = False
st.warning("温度最大值必须大于最小值。")
temp_max = min(2.0, temp_min + 0.01)
span = temp_max - temp_min
if span <= 0:
ratio_default = 0.0
else:
clamped = min(max(base_temp, temp_min), temp_max)
ratio_default = (clamped - temp_min) / span
temp_action = col_tslider.slider(
"动作值(映射至温度区间)",
min_value=0.0,
max_value=1.0,
value=float(ratio_default),
step=0.01,
key=f"{prefix}_temp_action",
)
specs.append(
ParameterSpec(
name=f"dept_temperature_{dept_code}",
target=f"department.{dept_code}.temperature",
minimum=temp_min,
maximum=temp_max,
)
)
spec_labels.append(f"department:{dept_code}:temperature")
action_values.append(temp_action)
col_tool, col_hint = st.columns([1, 2])
tool_choice = col_tool.selectbox(
"函数调用策略",
tool_policy_values,
index=tool_policy_values.index("auto"),
key=f"{prefix}_tool_choice",
)
col_hint.caption("映射提示0→auto0.5→none1→required。")
if len(tool_policy_values) > 1:
tool_value = tool_policy_values.index(tool_choice) / (len(tool_policy_values) - 1)
else:
tool_value = 0.0
specs.append(
ParameterSpec(
name=f"dept_tool_{dept_code}",
target=f"department.{dept_code}.function_policy",
values=tool_policy_values,
)
)
spec_labels.append(f"department:{dept_code}:tool_choice")
action_values.append(tool_value)
if specs:
st.caption("动作维度顺序:" + "".join(spec_labels))
run_decision_env = st.button("执行单次调参", key="run_decision_env_button")
just_finished_single = False
if run_decision_env:
if not selected_agents:
st.warning("请至少选择一个代理进行调参。")
elif not range_valid:
if not specs:
st.warning("请至少配置一个动作维度(代理或部门参数)")
elif selected_agents and not range_valid:
st.error("请确保所有代理的最大权重大于最小权重。")
elif not controls_valid:
st.error("请修正部门参数的取值范围。")
else:
LOGGER.info(
"离线调参(单次)按钮点击,已选择代理=%s 动作=%s disable_departments=%s",
@ -448,11 +545,11 @@ def render_backtest_review() -> None:
resolved_experiment_id = experiment_id or str(uuid.uuid4())
resolved_strategy = strategy_label or "DecisionEnv"
action_payload = {
name: value
for name, value in zip(selected_agents, action_values)
label: value for label, value in zip(spec_labels, action_values)
}
metrics_payload = dict(observation)
metrics_payload["reward"] = reward
metrics_payload["department_controls"] = info.get("department_controls")
log_success = False
try:
log_tuning_result(
@ -477,12 +574,14 @@ def render_backtest_review() -> None:
"observation": dict(observation),
"reward": float(reward),
"weights": info.get("weights", {}),
"department_controls": info.get("department_controls"),
"actions": action_payload,
"nav_series": info.get("nav_series"),
"trades": info.get("trades"),
"portfolio_snapshots": info.get("portfolio_snapshots"),
"portfolio_trades": info.get("portfolio_trades"),
"risk_breakdown": info.get("risk_breakdown"),
"selected_agents": list(selected_agents),
"spec_labels": list(spec_labels),
"action_values": list(action_values),
"experiment_id": resolved_experiment_id,
"strategy_label": resolved_strategy,
@ -562,6 +661,16 @@ def render_backtest_review() -> None:
with st.expander("风险事件统计", expanded=False):
st.json(risk_breakdown)
department_info = single_result.get("department_controls") or {}
if department_info:
with st.expander("部门控制参数", expanded=False):
st.json(department_info)
action_snapshot = single_result.get("actions") or {}
if action_snapshot:
with st.expander("动作明细", expanded=False):
st.json(action_snapshot)
if st.button("清除单次调参结果", key="clear_decision_env_single"):
st.session_state.pop(_DECISION_ENV_SINGLE_RESULT_KEY, None)
st.success("已清除单次调参结果缓存。")
@ -584,10 +693,12 @@ def render_backtest_review() -> None:
run_batch = st.button("批量执行调参", key="run_decision_env_batch")
batch_just_ran = False
if run_batch:
if not selected_agents:
st.warning("先选择调参代理")
elif not range_valid:
if not specs:
st.warning("至少配置一个动作维度")
elif selected_agents and not range_valid:
st.error("请确保所有代理的最大权重大于最小权重。")
elif not controls_valid:
st.error("请修正部门参数的取值范围。")
else:
LOGGER.info(
"离线调参(批量)按钮点击,已选择代理=%s disable_departments=%s",
@ -693,11 +804,12 @@ def render_backtest_review() -> None:
extra=LOG_EXTRA,
)
action_payload = {
name: value
for name, value in zip(selected_agents, action_vals)
label: value
for label, value in zip(spec_labels, action_vals)
}
metrics_payload = dict(observation)
metrics_payload["reward"] = reward
metrics_payload["department_controls"] = info.get("department_controls")
weights_payload = info.get("weights", {})
try:
log_tuning_result(
@ -713,13 +825,14 @@ def render_backtest_review() -> None:
results.append(
{
"序号": idx,
"动作": action_vals,
"动作": action_payload,
"状态": "ok",
"总收益": observation.get("total_return", 0.0),
"最大回撤": observation.get("max_drawdown", 0.0),
"波动率": observation.get("volatility", 0.0),
"奖励": reward,
"权重": weights_payload,
"部门控制": info.get("department_controls"),
}
)
st.session_state[_DECISION_ENV_BATCH_RESULTS_KEY] = {

View File

@ -18,7 +18,7 @@
- 围绕动量、估值、流动性等核心信号扩展轻量高质量因子集,全部由程序生成,满足端到端自动决策需求。
## 3. 决策优化与强化学习
- 扩展 `DecisionEnv` 的动作空间提示版本、部门温度、function 调用策略等),不仅限于代理权重调节。
- 扩展 `DecisionEnv` 的动作空间提示版本、部门温度、function 调用策略等),不仅限于代理权重调节。
- 引入 Bandit / 贝叶斯优化或 RL 算法探索动作空间,并将 `portfolio_snapshots`、`portfolio_trades` 指标纳入奖励约束。
- 构建实时持仓/成交数据写入链路,使线上监控与离线调参共用同一数据源。
- 借鉴 TradingAgents-CN 的做法:拆分环境与策略、提供训练脚本/配置,并输出丰富的评估指标(如 Sharpe、Sortino、基准对比

View File

@ -7,13 +7,58 @@ import pytest
from app.backtest.decision_env import DecisionEnv, EpisodeMetrics, ParameterSpec
from app.backtest.engine import BacktestResult, BtConfig
from app.utils.config import DepartmentSettings, LLMConfig, LLMEndpoint
class _StubDepartmentAgent:
def __init__(self) -> None:
self._tool_choice = "auto"
self._max_rounds = 3
endpoint = LLMEndpoint(provider="openai", model="mock", temperature=0.2)
self.settings = DepartmentSettings(
code="momentum",
title="Momentum",
description="baseline",
prompt="baseline",
llm=LLMConfig(primary=endpoint),
)
@property
def tool_choice(self) -> str:
return self._tool_choice
@tool_choice.setter
def tool_choice(self, value) -> None:
normalized = str(value).strip().lower()
if normalized not in {"auto", "none", "required"}:
raise ValueError("invalid tool choice")
self._tool_choice = normalized
@property
def max_rounds(self) -> int:
return self._max_rounds
@max_rounds.setter
def max_rounds(self, value) -> None:
numeric = int(round(float(value)))
if numeric < 1:
numeric = 1
if numeric > 6:
numeric = 6
self._max_rounds = numeric
class _StubManager:
def __init__(self) -> None:
self.agents = {"momentum": _StubDepartmentAgent()}
class _StubEngine:
def __init__(self, cfg: BtConfig) -> None: # noqa: D401
self.cfg = cfg
self.weights = {}
self.department_manager = None
self.department_manager = _StubManager()
_StubEngine.last_instance = self
def run(self) -> BacktestResult:
result = BacktestResult()
@ -53,6 +98,9 @@ class _StubEngine:
return result
_StubEngine.last_instance: _StubEngine | None = None
def test_decision_env_returns_risk_metrics(monkeypatch):
cfg = BtConfig(
id="stub",
@ -96,3 +144,68 @@ def test_default_reward_penalizes_metrics():
)
reward = DecisionEnv._default_reward(metrics)
assert reward == pytest.approx(0.1 - (0.5 * 0.2 + 0.05 * 2 + 0.1 * 0.3))
def test_decision_env_department_controls(monkeypatch):
cfg = BtConfig(
id="stub",
name="stub",
start_date=date(2025, 1, 10),
end_date=date(2025, 1, 10),
universe=["000001.SZ"],
params={},
)
specs = [
ParameterSpec(name="w_mom", target="agent_weights.A_mom", minimum=0.0, maximum=1.0),
ParameterSpec(
name="dept_prompt",
target="department.momentum.prompt",
values=["baseline", "aggressive"],
),
ParameterSpec(
name="dept_temp",
target="department.momentum.temperature",
minimum=0.1,
maximum=0.9,
),
ParameterSpec(
name="dept_tool",
target="department.momentum.function_policy",
values=["none", "auto", "required"],
),
ParameterSpec(
name="dept_rounds",
target="department.momentum.max_rounds",
minimum=1,
maximum=5,
),
]
env = DecisionEnv(bt_config=cfg, parameter_specs=specs, baseline_weights={"A_mom": 0.5})
monkeypatch.setattr("app.backtest.decision_env.BacktestEngine", _StubEngine)
monkeypatch.setattr(DecisionEnv, "_clear_portfolio_records", lambda self: None)
monkeypatch.setattr(DecisionEnv, "_fetch_portfolio_records", lambda self: ([], []))
obs, reward, done, info = env.step([0.3, 1.0, 0.75, 0.0, 1.0])
assert done is True
assert obs["total_return"] == pytest.approx(0.0)
controls = info["department_controls"]
assert "momentum" in controls
momentum_ctrl = controls["momentum"]
assert momentum_ctrl["prompt"] == "aggressive"
assert momentum_ctrl["temperature"] == pytest.approx(0.7, abs=1e-6)
assert momentum_ctrl["tool_choice"] == "none"
assert momentum_ctrl["max_rounds"] == 5
assert env.last_department_controls == controls
engine = _StubEngine.last_instance
assert engine is not None
agent = engine.department_manager.agents["momentum"]
assert agent.settings.prompt == "aggressive"
assert agent.settings.llm.primary.temperature == pytest.approx(0.7, abs=1e-6)
assert agent.tool_choice == "none"
assert agent.max_rounds == 5

View File

@ -169,6 +169,12 @@ def test_default_templates():
assert momentum is not None
assert "动量研究部门" in momentum.name
assert TemplateRegistry.get("value_dept") is not None
assert TemplateRegistry.get("news_dept") is not None
assert TemplateRegistry.get("liquidity_dept") is not None
assert TemplateRegistry.get("macro_dept") is not None
assert TemplateRegistry.get("risk_dept") is not None
# Validate template content
assert all("{" + var + "}" in dept_base.template for var in dept_base.variables)
assert all("{" + var + "}" in momentum.template for var in momentum.variables)