commit 2e98e81715 (parent 20bbcbb898)
Author: sam
Date: 2025-09-29 13:12:46 +08:00
6 changed files with 339 additions and 52 deletions

View File

@@ -1,12 +1,18 @@
"""Department-level LLM agents coordinating multi-model decisions."""
from __future__ import annotations
import hashlib
import json
from dataclasses import dataclass, field
from typing import Any, Callable, ClassVar, Dict, List, Mapping, Optional, Sequence, Tuple
from app.agents.base import AgentAction
from app.llm.client import call_endpoint_with_messages, run_llm_with_config, LLMError
from app.llm.client import (
call_endpoint_with_messages,
resolve_endpoint,
run_llm_with_config,
LLMError,
)
from app.llm.prompts import department_prompt
from app.utils.config import AppConfig, DepartmentSettings, LLMConfig
from app.utils.logging import get_logger, get_conversation_logger
@@ -48,6 +54,7 @@ class DepartmentDecision:
risks: List[str] = field(default_factory=list)
supplements: List[Dict[str, Any]] = field(default_factory=list)
dialogue: List[str] = field(default_factory=list)
telemetry: Dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]:
return {
@@ -60,6 +67,7 @@ class DepartmentDecision:
"raw_response": self.raw_response,
"supplements": self.supplements,
"dialogue": self.dialogue,
"telemetry": self.telemetry,
}
@@ -102,18 +110,30 @@ class DepartmentAgent:
messages: List[Dict[str, object]] = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append(
{
"role": "user",
"content": department_prompt(self.settings, mutable_context),
}
)
prompt_body = department_prompt(self.settings, mutable_context)
prompt_checksum = hashlib.sha1(prompt_body.encode("utf-8")).hexdigest()
prompt_preview = prompt_body[:240]
messages.append({"role": "user", "content": prompt_body})
transcript: List[str] = []
delivered_requests: set[Tuple[str, int, str]] = set()
primary_endpoint = llm_cfg.primary
try:
resolved_primary = resolve_endpoint(primary_endpoint)
except LLMError as exc:
LOGGER.warning(
"部门 %s 无法解析 LLM 端点,回退传统提示:%s",
self.settings.code,
exc,
extra=LOG_EXTRA,
)
return self._analyze_legacy(mutable_context, system_prompt)
final_message: Optional[Dict[str, Any]] = None
usage_records: List[Dict[str, Any]] = []
tool_call_records: List[Dict[str, Any]] = []
rounds_executed = 0
CONV_LOGGER.info(
"dept=%s ts_code=%s trade_date=%s start",
self.settings.code,
@@ -138,6 +158,14 @@
)
return self._analyze_legacy(mutable_context, system_prompt)
rounds_executed = round_idx + 1
usage = response.get("usage") if isinstance(response, Mapping) else None
if isinstance(usage, Mapping):
usage_payload = {"round": round_idx + 1}
usage_payload.update(dict(usage))
usage_records.append(usage_payload)
choice = (response.get("choices") or [{}])[0]
message = choice.get("message", {})
transcript.append(_message_to_text(message))
@@ -159,12 +187,36 @@
tool_calls = message.get("tool_calls") or []
if tool_calls:
for call in tool_calls:
function_block = call.get("function") or {}
tool_response, delivered = self._handle_tool_call(
mutable_context,
call,
delivered_requests,
round_idx,
)
tables_summary: List[Dict[str, Any]] = []
for item in tool_response.get("results") or []:
if isinstance(item, Mapping):
tables_summary.append(
{
"table": item.get("table"),
"window": item.get("window"),
"trade_date": item.get("trade_date"),
"row_count": len(item.get("rows") or []),
}
)
tool_call_records.append(
{
"round": round_idx + 1,
"id": call.get("id"),
"name": function_block.get("name"),
"arguments": function_block.get("arguments"),
"status": tool_response.get("status"),
"results": len(tool_response.get("results") or []),
"tables": tables_summary,
"skipped": list(tool_response.get("skipped") or []),
}
)
transcript.append(
json.dumps({"tool_response": tool_response}, ensure_ascii=False)
)
@@ -216,6 +268,64 @@
if isinstance(risks, str):
risks = [risks]
def _safe_int(value: Any) -> int:
try:
return int(value)
except (TypeError, ValueError): # noqa: PERF203 - clarity
return 0
prompt_tokens_total = 0
completion_tokens_total = 0
total_tokens_reported = 0
for usage_payload in usage_records:
prompt_tokens_total += _safe_int(
usage_payload.get("prompt_tokens")
or usage_payload.get("prompt_tokens_total")
)
completion_tokens_total += _safe_int(
usage_payload.get("completion_tokens")
or usage_payload.get("completion_tokens_total")
)
reported_total = _safe_int(
usage_payload.get("total_tokens")
or usage_payload.get("total_tokens_total")
)
if reported_total:
total_tokens_reported += reported_total
total_tokens = (
total_tokens_reported
if total_tokens_reported
else prompt_tokens_total + completion_tokens_total
)
telemetry: Dict[str, Any] = {
"provider": resolved_primary.get("provider_key"),
"model": resolved_primary.get("model"),
"temperature": resolved_primary.get("temperature"),
"timeout": resolved_primary.get("timeout"),
"endpoint_prompt_template": resolved_primary.get("prompt_template"),
"rounds": rounds_executed,
"tool_call_count": len(tool_call_records),
"tool_trace": tool_call_records,
"usage_by_round": usage_records,
"tokens": {
"prompt": prompt_tokens_total,
"completion": completion_tokens_total,
"total": total_tokens,
},
"prompt": {
"checksum": prompt_checksum,
"length": len(prompt_body),
"preview": prompt_preview,
"role_description": self.settings.description,
"instruction": self.settings.prompt,
"system": system_prompt,
},
"messages_exchanged": len(messages),
"supplement_rounds": len(tool_call_records),
}
decision = DepartmentDecision(
department=self.settings.code,
action=action,
@@ -226,6 +336,7 @@
raw_response=content_text,
supplements=list(mutable_context.raw.get("supplement_data", [])),
dialogue=list(transcript),
telemetry=telemetry,
)
LOGGER.debug(
"部门 %s 决策action=%s confidence=%.2f",
@@ -241,6 +352,11 @@
decision.confidence,
summary or "",
)
CONV_LOGGER.info(
"dept=%s telemetry=%s",
self.settings.code,
json.dumps(telemetry, ensure_ascii=False),
)
return decision
@staticmethod

View File

@@ -324,6 +324,8 @@ class BacktestEngine:
metadata["_supplements"] = dept_decision.supplements
if dept_decision.dialogue:
metadata["_dialogue"] = dept_decision.dialogue
if dept_decision.telemetry:
metadata["_telemetry"] = dept_decision.telemetry
payload_json = {**action_scores, **metadata}
rows.append(
(
@@ -355,6 +357,11 @@
for code, dept in decision.department_decisions.items()
if dept.dialogue
},
"_department_telemetry": {
code: dept.telemetry
for code, dept in decision.department_decisions.items()
if dept.telemetry
},
}
rows.append(
(

View File

@@ -105,7 +105,7 @@ def _request_openai_chat(
return response.json()
def _prepare_endpoint(endpoint: LLMEndpoint) -> Dict[str, object]:
def resolve_endpoint(endpoint: LLMEndpoint) -> Dict[str, object]:
cfg = get_config()
provider_key = (endpoint.provider or "ollama").lower()
provider_cfg = cfg.llm_providers.get(provider_key)
@@ -152,7 +152,7 @@ def _prepare_endpoint(endpoint: LLMEndpoint) -> Dict[str, object]:
def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> str:
resolved = _prepare_endpoint(endpoint)
resolved = resolve_endpoint(endpoint)
provider_key = resolved["provider_key"]
mode = resolved["mode"]
prompt_template = resolved["prompt_template"]
@@ -188,7 +188,7 @@
tools: Optional[List[Dict[str, object]]] = None,
tool_choice: Optional[object] = None,
) -> Dict[str, object]:
resolved = _prepare_endpoint(endpoint)
resolved = resolve_endpoint(endpoint)
provider_key = resolved["provider_key"]
mode = resolved["mode"]
base_url = resolved["base_url"]

View File

@@ -1,10 +1,11 @@
"""Simple runtime metrics collector for LLM calls."""
from __future__ import annotations
import logging
from collections import deque
from dataclasses import dataclass, field
from threading import Lock
from typing import Deque, Dict, List, Optional
from typing import Callable, Deque, Dict, List, Optional
@dataclass
@@ -20,6 +21,9 @@ class _Metrics:
_METRICS = _Metrics()
_LOCK = Lock()
_LISTENERS: List[Callable[[Dict[str, object]], None]] = []
LOGGER = logging.getLogger(__name__)
def record_call(
@@ -45,6 +49,7 @@
_METRICS.total_prompt_tokens += int(prompt_tokens)
if completion_tokens:
_METRICS.total_completion_tokens += int(completion_tokens)
_notify_listeners()
def snapshot(reset: bool = False) -> Dict[str, object]:
@@ -75,6 +80,7 @@ def reset() -> None:
"""Reset all collected metrics."""
snapshot(reset=True)
_notify_listeners()
def record_decision(
@@ -103,6 +109,7 @@ def record_decision(
_METRICS.decision_action_counts[action] = (
_METRICS.decision_action_counts.get(action, 0) + 1
)
_notify_listeners()
def recent_decisions(limit: int = 50) -> List[Dict[str, object]]:
@@ -112,3 +119,42 @@ def recent_decisions(limit: int = 50) -> List[Dict[str, object]]:
if limit <= 0:
return []
return list(_METRICS.decisions)[-limit:]
def register_listener(callback: Callable[[Dict[str, object]], None]) -> None:
"""Register a callback invoked whenever metrics change."""
if not callable(callback):
return
with _LOCK:
if callback in _LISTENERS:
should_invoke = False
else:
_LISTENERS.append(callback)
should_invoke = True
if should_invoke:
try:
callback(snapshot())
except Exception: # noqa: BLE001
LOGGER.exception("Metrics listener failed on initial callback")
def unregister_listener(callback: Callable[[Dict[str, object]], None]) -> None:
"""Remove a previously registered metrics callback."""
with _LOCK:
if callback in _LISTENERS:
_LISTENERS.remove(callback)
def _notify_listeners() -> None:
with _LOCK:
listeners = list(_LISTENERS)
if not listeners:
return
data = snapshot()
for callback in listeners:
try:
callback(data)
except Exception: # noqa: BLE001
LOGGER.exception("Metrics listener execution failed")

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import sys
import time
from dataclasses import asdict
from datetime import date, datetime, timedelta
from pathlib import Path
@@ -28,9 +29,10 @@ from app.ingest.checker import run_boot_check
from app.ingest.tushare import FetchJob, run_ingestion
from app.llm.client import llm_config_snapshot, run_llm
from app.llm.metrics import (
recent_decisions as llm_recent_decisions,
register_listener as register_llm_metrics_listener,
reset as reset_llm_metrics,
snapshot as snapshot_llm_metrics,
recent_decisions as llm_recent_decisions,
)
from app.utils.config import (
ALLOWED_LLM_STRATEGIES,
@@ -49,6 +51,11 @@ from app.utils.logging import get_logger
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "ui"}
_SIDEBAR_THROTTLE_SECONDS = 0.75
def _sidebar_metrics_listener(metrics: Dict[str, object]) -> None:
_update_dashboard_sidebar(metrics, throttled=True)
def render_global_dashboard() -> None:
@@ -56,54 +63,118 @@ def render_global_dashboard() -> None:
metrics_container = st.sidebar.container()
decisions_container = st.sidebar.container()
st.session_state["dashboard_placeholders"] = (metrics_container, decisions_container)
st.session_state["dashboard_containers"] = (metrics_container, decisions_container)
_ensure_dashboard_elements(metrics_container, decisions_container)
if not st.session_state.get("dashboard_listener_registered"):
register_llm_metrics_listener(_sidebar_metrics_listener)
st.session_state["dashboard_listener_registered"] = True
_update_dashboard_sidebar()
def _update_dashboard_sidebar(metrics: Optional[Dict[str, object]] = None) -> None:
placeholders = st.session_state.get("dashboard_placeholders")
if not placeholders:
def _update_dashboard_sidebar(
metrics: Optional[Dict[str, object]] = None,
*,
throttled: bool = False,
) -> None:
containers = st.session_state.get("dashboard_containers")
if not containers:
return
metrics_container, decisions_container = placeholders
metrics_container, decisions_container = containers
elements = st.session_state.get("dashboard_elements")
if elements is None:
elements = _ensure_dashboard_elements(metrics_container, decisions_container)
if throttled:
now = time.monotonic()
last_update = st.session_state.get("dashboard_last_update_ts", 0.0)
if now - last_update < _SIDEBAR_THROTTLE_SECONDS:
if metrics is not None:
st.session_state["dashboard_pending_metrics"] = metrics
return
st.session_state["dashboard_last_update_ts"] = now
else:
st.session_state["dashboard_last_update_ts"] = time.monotonic()
if metrics is None:
metrics = st.session_state.pop("dashboard_pending_metrics", None)
if metrics is None:
metrics = snapshot_llm_metrics()
else:
st.session_state.pop("dashboard_pending_metrics", None)
metrics = metrics or snapshot_llm_metrics()
metrics_container.empty()
with metrics_container.container():
st.header("系统监控")
col_a, col_b, col_c = st.columns(3)
col_a.metric("LLM 调用", metrics.get("total_calls", 0))
col_b.metric("Prompt Tokens", metrics.get("total_prompt_tokens", 0))
col_c.metric("Completion Tokens", metrics.get("total_completion_tokens", 0))
elements["metrics_calls"].metric("LLM 调用", metrics.get("total_calls", 0))
elements["metrics_prompt"].metric("Prompt Tokens", metrics.get("total_prompt_tokens", 0))
elements["metrics_completion"].metric(
"Completion Tokens", metrics.get("total_completion_tokens", 0)
)
provider_calls = metrics.get("provider_calls", {})
model_calls = metrics.get("model_calls", {})
if provider_calls or model_calls:
with st.expander("调用分布", expanded=False):
if provider_calls:
st.write("按 Provider")
st.json(provider_calls)
if model_calls:
st.write("按模型:")
st.json(model_calls)
provider_calls = metrics.get("provider_calls", {})
model_calls = metrics.get("model_calls", {})
provider_placeholder = elements["provider_distribution"]
provider_placeholder.empty()
if provider_calls:
provider_placeholder.json(provider_calls)
else:
provider_placeholder.info("暂无 Provider 分布数据。")
decisions_container.empty()
with decisions_container.container():
st.subheader("最新决策")
decisions = metrics.get("recent_decisions") or llm_recent_decisions(10)
if decisions:
for record in reversed(decisions[-10:]):
ts_code = record.get("ts_code")
trade_date = record.get("trade_date")
action = record.get("action")
confidence = record.get("confidence", 0.0)
summary = record.get("summary")
st.markdown(
f"**{trade_date} {ts_code}** → {action} (置信度 {confidence:.2f})"
)
if summary:
st.caption(summary)
else:
st.caption("暂无决策记录。执行回测或实时评估后可在此查看。")
model_placeholder = elements["model_distribution"]
model_placeholder.empty()
if model_calls:
model_placeholder.json(model_calls)
else:
model_placeholder.info("暂无模型分布数据。")
decisions = metrics.get("recent_decisions") or llm_recent_decisions(10)
if decisions:
lines = []
for record in reversed(decisions[-10:]):
ts_code = record.get("ts_code")
trade_date = record.get("trade_date")
action = record.get("action")
confidence = record.get("confidence", 0.0)
summary = record.get("summary")
line = f"**{trade_date} {ts_code}** → {action} (置信度 {confidence:.2f})"
if summary:
line += f"\n<small>{summary}</small>"
lines.append(line)
decisions_placeholder = elements["decisions_list"]
decisions_placeholder.empty()
decisions_placeholder.markdown("\n\n".join(lines), unsafe_allow_html=True)
else:
decisions_placeholder = elements["decisions_list"]
decisions_placeholder.empty()
decisions_placeholder.info("暂无决策记录。执行回测或实时评估后可在此查看。")
def _ensure_dashboard_elements(metrics_container, decisions_container) -> Dict[str, object]:
elements = st.session_state.get("dashboard_elements")
if elements:
return elements
metrics_container.header("系统监控")
col_a, col_b, col_c = metrics_container.columns(3)
metrics_calls = col_a.empty()
metrics_prompt = col_b.empty()
metrics_completion = col_c.empty()
distribution_expander = metrics_container.expander("调用分布", expanded=False)
provider_distribution = distribution_expander.empty()
model_distribution = distribution_expander.empty()
decisions_container.subheader("最新决策")
decisions_list = decisions_container.empty()
elements = {
"metrics_calls": metrics_calls,
"metrics_prompt": metrics_prompt,
"metrics_completion": metrics_completion,
"provider_distribution": provider_distribution,
"model_distribution": model_distribution,
"decisions_list": decisions_list,
}
st.session_state["dashboard_elements"] = elements
return elements
def _discover_provider_models(provider: LLMProvider, base_override: str = "", api_override: Optional[str] = None) -> tuple[list[str], Optional[str]]:
"""Attempt to query provider API and return available model ids."""
@@ -335,6 +406,7 @@ def render_today_plan() -> None:
"turnover_series": utils.get("_turnover_series", []),
"department_supplements": utils.get("_department_supplements", {}),
"department_dialogue": utils.get("_department_dialogue", {}),
"department_telemetry": utils.get("_department_telemetry", {}),
}
continue
@@ -344,6 +416,7 @@
risks = utils.get("_risks", [])
supplements = utils.get("_supplements", [])
dialogue = utils.get("_dialogue", [])
telemetry = utils.get("_telemetry", {})
dept_records.append(
{
"部门": code,
@@ -362,6 +435,7 @@
"summary": utils.get("_summary", ""),
"signals": signals,
"risks": risks,
"telemetry": telemetry if isinstance(telemetry, dict) else {},
}
else:
score_map = {
@@ -407,6 +481,7 @@
st.json(global_info["turnover_series"])
dept_sup = global_info.get("department_supplements") or {}
dept_dialogue = global_info.get("department_dialogue") or {}
dept_telemetry = global_info.get("department_telemetry") or {}
if dept_sup or dept_dialogue:
with st.expander("部门补数与对话记录", expanded=False):
if dept_sup:
@@ -415,6 +490,9 @@
if dept_dialogue:
st.write("对话片段:")
st.json(dept_dialogue)
if dept_telemetry:
with st.expander("部门 LLM 元数据", expanded=False):
st.json(dept_telemetry)
else:
st.info("暂未写入全局策略摘要。")
@@ -437,6 +515,10 @@
st.markdown(f"**回合 {idx}:** {line}")
else:
st.caption("无额外对话。")
telemetry = details.get("telemetry") or {}
if telemetry:
st.write("LLM 元数据:")
st.json(telemetry)
else:
st.info("暂无部门记录。")

View File

@@ -0,0 +1,36 @@
# Decision Optimization Discussion Notes
## Core Goals
- On top of the existing rule-based and department-LLM collaboration framework in `app/agents`, define a unified "return maximization + risk constraints" metric (formalized below) for judging the relative merit of prompts, temperatures, multi-agent interplay, and invocation strategies.
- Use the daily-frequency backtest environment in `app/backtest/engine.py` to build a repeatable simulation loop that supplies offline training and validation data for reinforcement learning or policy search.
- Optimize the high-level policy (prompt templates, dialogue rounds, function-calling patterns) rather than intervening directly in low-level market prediction, so that RL concentrates on "decision-flow orchestration".
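One hedged way to formalize the unified metric is a penalized objective over the policy parameters θ (prompt template, temperature, voting weights); the λ weights are placeholders to be calibrated, not settled values:

```math
\max_{\theta}\; \mathbb{E}\big[R_T(\theta)\big] \;-\; \lambda_1\,\mathrm{MaxDD}(\theta) \;-\; \lambda_2\,\mathrm{Cost}(\theta)
\quad \text{s.t. the position and halt rules enforced by } A_{\mathrm{risk}}
```

Here `R_T` is the cumulative net return of one backtest run; the same terms reappear in the reward function sketched later in this note.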
## Key Challenges
- State space: each round of the game involves market factors, agent signals, historical behavior, and log state (`app/data/logs`); this must be compressed into a compact vector or summary usable for RL training (see the encoding sketch after this list).
- Reward design: using strategy returns directly leads to sparse gradients; consider a multi-objective weighting of return, drawdown, execution constraints, and confidence consistency.
- Data efficiency: real trading allows only limited exploration, so we must rely on offline simulation plus counterfactual evaluation, or introduce a model-based world model to mitigate distribution shift.
- Multi-agent non-stationarity: the combined policy of rule agents and LLMs drifts as parameters are tuned, calling for a stable Nash-style aggregation or explicit opponent modeling.
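A minimal sketch of that state compression, assuming hypothetical inputs (a factor vector, a per-department signal dict, recent action ids); the department codes other than `A_risk` are illustrative:

```python
import numpy as np

DEPT_CODES = ["A_fund", "A_tech", "A_risk"]  # illustrative ordering; only A_risk appears in this note

def encode_state(factor_vec, dept_signals, recent_actions, k=5):
    """Compress one day's context into a fixed-length RL state vector."""
    # A fixed department order keeps each signal at a stable position.
    signals = np.asarray([float(dept_signals.get(code, 0.0)) for code in DEPT_CODES])
    # Pad or truncate the action history to exactly k entries.
    actions = np.asarray((list(recent_actions) + [0] * k)[:k], dtype=float)
    return np.concatenate([np.asarray(factor_vec, dtype=float), signals, actions])
```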
## Suggested Path
- **Data logging**: extend `app/agent_utils` to write prompt version, temperature, function-call counts, department signals, and other metadata, forming trajectories usable by RL; align the log files (e.g. `app/data/logs/agent_20250929_0754.log`) with the same fields.
- **Environment wrapping**: define a `DecisionEnv` around `app/backtest/engine.py` that maps "actions" to strategy configuration (prompt choice, temperature, voting weights); `step()` calls the existing engine to play out one trading day and returns the reward (see the sketch after this list).
- **Hierarchical policy**: first establish a baseline with black-box optimization such as bandits or CMA-ES over the tunables (temperature, weights), then move to continuous-action RL such as PPO/SAC; prompts can be encoded as learnable embeddings or a finite candidate set.
- **Game-theoretic coordination**: for inter-department weights and veto policies, adopt centralized training, decentralized execution (CTDE): a shared critic evaluates the global reward while each actor owns a single department's parameters.
- **Safety constraints**: handle position and risk-control constraints with penalty methods or a Lagrangian, ensuring training respects the halt and limit-up/limit-down logic configured for `A_risk`.
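A sketch of that `DecisionEnv` wrapper. The engine hooks used here (`run_single_day`, `total_days`, `state_vector`) are assumptions for illustration; the real `BacktestEngine` API would need its own adapter:

```python
from dataclasses import dataclass, field

@dataclass
class StrategyConfig:
    prompt_template: str                              # which prompt variant to use
    temperature: float                                # LLM sampling temperature
    vote_weights: dict = field(default_factory=dict)  # per-department voting weights

class DecisionEnv:
    """Gym-style wrapper: one step() plays out one simulated trading day."""

    def __init__(self, engine, reward_fn):
        self.engine = engine        # an existing BacktestEngine instance
        self.reward_fn = reward_fn  # e.g. compute_reward(), sketched later in this note
        self.day = 0

    def reset(self):
        self.day = 0
        return self.engine.state_vector(self.day)             # assumed hook

    def step(self, config: StrategyConfig):
        # Apply the high-level config, simulate one day, read back the stats.
        stats = self.engine.run_single_day(self.day, config)  # assumed hook
        self.day += 1
        done = self.day >= self.engine.total_days             # assumed attribute
        obs = None if done else self.engine.state_vector(self.day)
        return obs, self.reward_fn(stats), done, {"stats": stats}
```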
## Reinforcement Learning Framework
- State construction: concatenate market features (a dimensionality-reduced factor matrix), historical actions, and the LLM confidences recorded in the logs; encode with an LSTM/Transformer when needed.
- Action definition: continuous actions control `temperature`/`top_p`/voting weights; discrete actions select the prompt template or collaboration mode (majority vs leader); composite actions can be decomposed into a parameterized policy.
- Reward function: `NAV gain - λ1*drawdown - λ2*transaction cost - λ3*conflict count`, optionally adding LLM confidence consistency or function-call cost as positive/negative terms (spelled out in the sketch after this list).
- Training loop: cycle through "sample a configuration → replay it through BacktestEngine → append to the replay buffer → update the policy"; adopt offline RL (CQL, IQL) where needed to exploit historical trajectories.
- Evaluation: compare the return distribution against the default prompt/temperature settings, track policy stability and counterfactual return gaps, and add experiment-version switching and visualization to `app/ui/streamlit_app.py`.
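The reward bullet written out as code. The λ defaults are arbitrary placeholders, and the stats keys are assumed names for whatever `DecisionEnv` ends up reporting:

```python
def compute_reward(stats, lam_dd=0.5, lam_cost=0.1, lam_conflict=0.05):
    """NAV gain minus weighted drawdown, transaction cost, and conflict count."""
    return (
        stats.get("nav_gain", 0.0)
        - lam_dd * stats.get("max_drawdown", 0.0)
        - lam_cost * stats.get("transaction_cost", 0.0)
        - lam_conflict * stats.get("conflict_count", 0)
    )
```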
## Next Steps
- First fill in the missing log and database fields so that prompt parameters and decision outcomes are recorded end to end.
- Stand up a lightweight bandit tuning experiment to verify how different prompt/temperature combinations affect backtest returns (a minimal sketch follows this list).
- Design the RL environment interface, integrate it with the existing BacktestEngine, and plan the training and evaluation scripts.
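A minimal ε-greedy bandit over candidate configurations, assuming an `evaluate` hook that stands in for one full backtest run (in practice it would wrap `DecisionEnv`/`BacktestEngine`):

```python
import random

def run_bandit(arms, evaluate, rounds=50, epsilon=0.2):
    """Return the best (config, mean_reward) pair after `rounds` trials.

    arms     -- candidate configs, e.g. [("prompt_v1", 0.2), ("prompt_v1", 0.7)]
    evaluate -- callable(config) -> realized backtest return (assumed hook)
    """
    counts = [0] * len(arms)
    means = [0.0] * len(arms)
    for _ in range(rounds):
        if random.random() < epsilon:
            i = random.randrange(len(arms))                    # explore
        else:
            i = max(range(len(arms)), key=lambda j: means[j])  # exploit
        reward = evaluate(arms[i])
        counts[i] += 1
        means[i] += (reward - means[i]) / counts[i]            # running mean
    best = max(range(len(arms)), key=lambda j: means[j])
    return arms[best], means[best]
```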
## Completed Logging Improvements
- The `agent_utils` table gains `_telemetry` and `_department_telemetry` JSON fields (stored inside the `utils` column) that record each department's provider, model, temperature, round count, tool-call list, and token statistics; they can be expanded on the Streamlit "部门意见" (department opinions) detail page (an abridged example payload follows this list).
- `app/data/logs/agent_*.log` now appends `telemetry` lines holding a per-round summary of function calls, making it easy to analyze offline how prompt versions and LLM configuration affect decisions.
- The Streamlit sidebar listens for real-time events from `llm.metrics` and refreshes the "系统监控" (system monitoring) panel at a throttled cadence of ~0.75 s, so updates land promptly after logs arrive without flooding the UI and causing flicker.
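For reference, an abridged per-department telemetry record. The keys mirror the `telemetry` dict assembled in the department agent above; every value (and the tool name) is illustrative:

```json
{
  "provider": "ollama",
  "model": "example-model",
  "temperature": 0.2,
  "rounds": 2,
  "tool_call_count": 1,
  "tool_trace": [
    {"round": 1, "name": "fetch_data", "status": "ok", "results": 1}
  ],
  "usage_by_round": [
    {"round": 1, "prompt_tokens": 812, "completion_tokens": 240}
  ],
  "tokens": {"prompt": 812, "completion": 240, "total": 1052},
  "prompt": {"checksum": "3f2c…", "length": 1843, "preview": "…"},
  "messages_exchanged": 4,
  "supplement_rounds": 1
}
```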