llm-quant/app/agents/departments.py
2025-10-05 17:03:21 +08:00

760 lines
28 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Department-level LLM agents coordinating multi-model decisions."""
from __future__ import annotations
import hashlib
import json
from dataclasses import dataclass, field
from typing import Any, Callable, ClassVar, Dict, List, Mapping, Optional, Sequence, Tuple
from app.agents.base import AgentAction
from app.llm.client import (
call_endpoint_with_messages,
resolve_endpoint,
run_llm_with_config,
LLMError,
)
from app.llm.prompts import department_prompt
from app.utils.config import AppConfig, DepartmentSettings, LLMConfig
from app.utils.logging import get_logger, get_conversation_logger
from app.utils.data_access import DataBroker
# Module-level logger; every record emitted here passes LOG_EXTRA as ``extra``.
LOGGER = get_logger(__name__)
LOG_EXTRA = dict(stage="department")
# Separate logger used below to record LLM conversation turns and decisions.
CONV_LOGGER = get_conversation_logger()
@dataclass
class TableRequest:
    """One table fetch requested by the model through the ``fetch_data`` tool.

    Attributes are validated later (table name against the agent's allow-list,
    window clamped to the broker limit); this class only carries the values.
    """

    # Table name as requested by the model.
    name: str
    # Number of rows to look back; defaults to a single row.
    window: int = field(default=1)
    # Optional trade-date override; ``None`` means "use the context's date".
    trade_date: Optional[str] = field(default=None)
@dataclass
class DepartmentContext:
    """Inputs handed to a department agent for one security on one trade date.

    ``features`` and ``market_snapshot`` carry structured inputs; ``raw`` is a
    free-form bag that the agents also write supplementary data into.
    """

    # Security code being analyzed.
    ts_code: str
    # Trade date the decision is made for.
    trade_date: str
    # Structured feature values (each instance gets its own empty dict).
    features: Mapping[str, Any] = field(default_factory=dict)
    # Market-level snapshot values.
    market_snapshot: Mapping[str, Any] = field(default_factory=dict)
    # Unstructured extras, including agent-written supplement data.
    raw: Mapping[str, Any] = field(default_factory=dict)
@dataclass
class DepartmentDecision:
    """Result produced by a department agent.

    ``dialogue`` holds the conversation transcript, ``supplements`` the data
    rows fetched through tool calls, and ``telemetry`` model/usage metadata.
    """

    department: str
    action: AgentAction
    confidence: float
    summary: str
    raw_response: str
    signals: List[str] = field(default_factory=list)
    risks: List[str] = field(default_factory=list)
    supplements: List[Dict[str, Any]] = field(default_factory=list)
    dialogue: List[str] = field(default_factory=list)
    telemetry: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view of the decision.

        ``action`` is flattened to ``action.value``. Container fields are
        shallow-copied so callers that mutate the returned dict (e.g. append
        to ``signals``) do not silently alter this decision.
        """
        return {
            "department": self.department,
            "action": self.action.value,
            "confidence": self.confidence,
            "summary": self.summary,
            "signals": list(self.signals),
            "risks": list(self.risks),
            "raw_response": self.raw_response,
            "supplements": list(self.supplements),
            "dialogue": list(self.dialogue),
            "telemetry": dict(self.telemetry),
        }
class DepartmentAgent:
    """Wraps LLM ensemble logic for a single analytical department.

    ``analyze`` first runs a function-calling loop against the primary
    endpoint, exposing one ``fetch_data`` tool that lets the model pull extra
    rows from ``DataBroker``. When the LLM config uses a non-single strategy or
    an ensemble, or when the endpoint cannot be resolved or called, it falls
    back to ``_analyze_legacy`` (one prompt through ``run_llm_with_config``).
    """

    # Tables the model may request via ``fetch_data``; also advertised to the
    # model as the JSON-schema ``enum`` for the table name.
    ALLOWED_TABLES: ClassVar[List[str]] = [
        "daily",
        "daily_basic",
        "stk_limit",
        "suspend",
        "heat_daily",
        "news",
        "index_daily",
    ]
    # Window cap used when the broker does not expose ``MAX_WINDOW``.
    DEFAULT_MAX_WINDOW: ClassVar[int] = 120

    def __init__(
        self,
        settings: DepartmentSettings,
        resolver: Optional[Callable[[DepartmentSettings], LLMConfig]] = None,
    ) -> None:
        """Create an agent for one department.

        Args:
            settings: Department configuration (code, prompts, LLM config).
            resolver: Optional callback returning the ``LLMConfig`` to use for
                ``settings``; when absent, ``settings.llm`` is used.
        """
        self.settings = settings
        self._resolver = resolver
        self._broker = DataBroker()
        # Maximum number of LLM round-trips in the function-calling loop.
        self._max_rounds = 3

    def _get_llm_config(self) -> LLMConfig:
        """Return the LLM config for this department, honoring the resolver."""
        if self._resolver:
            return self._resolver(self.settings)
        return self.settings.llm

    def _max_window(self) -> int:
        """Largest look-back window accepted for a single table request."""
        return getattr(self._broker, "MAX_WINDOW", self.DEFAULT_MAX_WINDOW)

    def analyze(self, context: DepartmentContext) -> DepartmentDecision:
        """Run the department's LLM analysis for ``context``.

        The model may call ``fetch_data`` for up to ``self._max_rounds``
        round-trips. Fetched rows are recorded on ``context.raw`` under the
        ``supplement_*`` keys and surfaced in the decision's ``supplements``,
        ``dialogue`` and ``telemetry``.

        Args:
            context: Department inputs; normalized in place by
                ``_ensure_mutable_context``.

        Returns:
            The parsed decision. ``LLMError`` from endpoint resolution or a
            call falls back to ``_analyze_legacy``.
        """
        mutable_context = _ensure_mutable_context(context)
        system_prompt = (
            "你是一个多智能体量化投研系统中的分部决策者,需要根据提供的结构化信息给出买卖意见。"
        )
        llm_cfg = self._get_llm_config()
        if llm_cfg.strategy not in (None, "", "single") or llm_cfg.ensemble:
            LOGGER.warning(
                "部门 %s 当前配置不支持函数调用模式,回退至传统提示",
                self.settings.code,
                extra=LOG_EXTRA,
            )
            return self._analyze_legacy(mutable_context, system_prompt)

        tools = self._build_tools()
        messages: List[Dict[str, object]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        prompt_body = department_prompt(self.settings, mutable_context)
        prompt_checksum = hashlib.sha1(prompt_body.encode("utf-8")).hexdigest()
        prompt_preview = prompt_body[:240]
        messages.append({"role": "user", "content": prompt_body})

        primary_endpoint = llm_cfg.primary
        try:
            resolved_primary = resolve_endpoint(primary_endpoint)
        except LLMError as exc:
            LOGGER.warning(
                "部门 %s 无法解析 LLM 端点,回退传统提示:%s",
                self.settings.code,
                exc,
                extra=LOG_EXTRA,
            )
            return self._analyze_legacy(mutable_context, system_prompt)

        transcript: List[str] = []
        delivered_requests: set[Tuple[str, int, str]] = set()
        final_message: Optional[Mapping[str, Any]] = None
        # Last assistant message seen; used when the round budget runs out.
        last_message: Mapping[str, Any] = {}
        usage_records: List[Dict[str, Any]] = []
        tool_call_records: List[Dict[str, Any]] = []
        rounds_executed = 0

        CONV_LOGGER.info(
            "dept=%s ts_code=%s trade_date=%s start",
            self.settings.code,
            context.ts_code,
            context.trade_date,
        )
        for round_idx in range(self._max_rounds):
            round_no = round_idx + 1
            try:
                response = call_endpoint_with_messages(
                    primary_endpoint,
                    messages,
                    tools=tools,
                    tool_choice="auto",
                )
            except LLMError as exc:
                LOGGER.warning(
                    "部门 %s 函数调用失败,回退传统提示:%s",
                    self.settings.code,
                    exc,
                    extra=LOG_EXTRA,
                )
                return self._analyze_legacy(mutable_context, system_prompt)
            rounds_executed = round_no
            if not isinstance(response, Mapping):
                response = {}
            usage = response.get("usage")
            if isinstance(usage, Mapping):
                usage_payload: Dict[str, Any] = {"round": round_no}
                usage_payload.update(dict(usage))
                usage_records.append(usage_payload)

            message = self._first_message(response)
            last_message = message
            transcript.append(_message_to_text(message))
            tool_calls = message.get("tool_calls") or []
            if tool_calls and not message.get("content"):
                # Tool-call turns typically carry empty/null content; echo an
                # empty string instead of a JSON dump of the whole message so
                # the tool_calls payload is not duplicated into the history.
                history_content = ""
            else:
                history_content = _extract_message_content(message)
            assistant_record: Dict[str, Any] = {
                "role": "assistant",
                "content": history_content,
            }
            if tool_calls:
                assistant_record["tool_calls"] = tool_calls
            messages.append(assistant_record)
            CONV_LOGGER.info(
                "dept=%s round=%s assistant=%s",
                self.settings.code,
                round_no,
                assistant_record,
            )

            if not tool_calls:
                final_message = message
                break

            for call in tool_calls:
                if not isinstance(call, Mapping):
                    continue
                function_block = call.get("function") or {}
                tool_response, delivered = self._handle_tool_call(
                    mutable_context,
                    call,
                    delivered_requests,
                    round_idx,
                )
                tool_call_records.append(
                    {
                        "round": round_no,
                        "id": call.get("id"),
                        "name": function_block.get("name"),
                        "arguments": function_block.get("arguments"),
                        "status": tool_response.get("status"),
                        "results": len(tool_response.get("results") or []),
                        "tables": self._summarize_tool_tables(tool_response),
                        "skipped": list(tool_response.get("skipped") or []),
                    }
                )
                transcript.append(
                    json.dumps({"tool_response": tool_response}, ensure_ascii=False)
                )
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": call.get("id"),
                        "content": json.dumps(tool_response, ensure_ascii=False),
                    }
                )
                # Later calls (same or next round) skip keys already served.
                delivered_requests.update(delivered)
                CONV_LOGGER.info(
                    "dept=%s round=%s tool_call=%s response=%s",
                    self.settings.code,
                    round_no,
                    call,
                    tool_response,
                )

        if final_message is None:
            LOGGER.warning(
                "部门 %s 函数调用达到轮次上限仍未返回文本,使用最后一次消息",
                self.settings.code,
                extra=LOG_EXTRA,
            )
            final_message = last_message
            CONV_LOGGER.warning(
                "dept=%s rounds_exhausted last_message=%s",
                self.settings.code,
                final_message,
            )

        mutable_context.raw["supplement_transcript"] = list(transcript)
        content_text = _extract_message_content(final_message)
        action, confidence, summary, signals, risks = self._decision_fields(content_text)
        prompt_tokens, completion_tokens, total_tokens = self._aggregate_tokens(
            usage_records
        )

        telemetry: Dict[str, Any] = {
            "provider": resolved_primary.get("provider_key"),
            "model": resolved_primary.get("model"),
            "temperature": resolved_primary.get("temperature"),
            "timeout": resolved_primary.get("timeout"),
            "endpoint_prompt_template": resolved_primary.get("prompt_template"),
            "rounds": rounds_executed,
            "tool_call_count": len(tool_call_records),
            "tool_trace": tool_call_records,
            "usage_by_round": usage_records,
            "tokens": {
                "prompt": prompt_tokens,
                "completion": completion_tokens,
                "total": total_tokens,
            },
            "prompt": {
                "checksum": prompt_checksum,
                "length": len(prompt_body),
                "preview": prompt_preview,
                "role_description": self.settings.description,
                "instruction": self.settings.prompt,
                "system": system_prompt,
            },
            "messages_exchanged": len(messages),
            "supplement_rounds": len(tool_call_records),
        }
        decision = DepartmentDecision(
            department=self.settings.code,
            action=action,
            confidence=confidence,
            summary=summary or "未提供摘要",
            signals=signals,
            risks=risks,
            raw_response=content_text,
            supplements=list(mutable_context.raw.get("supplement_data") or []),
            dialogue=list(transcript),
            telemetry=telemetry,
        )
        LOGGER.debug(
            "部门 %s 决策action=%s confidence=%.2f",
            self.settings.code,
            decision.action.value,
            decision.confidence,
            extra=LOG_EXTRA,
        )
        CONV_LOGGER.info(
            "dept=%s decision action=%s confidence=%.2f summary=%s",
            self.settings.code,
            decision.action.value,
            decision.confidence,
            summary or "",
        )
        CONV_LOGGER.info(
            "dept=%s telemetry=%s",
            self.settings.code,
            json.dumps(telemetry, ensure_ascii=False),
        )
        return decision

    @staticmethod
    def _first_message(response: Mapping[str, Any]) -> Mapping[str, Any]:
        """Return ``choices[0]["message"]`` from a chat response, or ``{}``."""
        choices = response.get("choices")
        first = choices[0] if isinstance(choices, (list, tuple)) and choices else None
        message = first.get("message") if isinstance(first, Mapping) else None
        return message if isinstance(message, Mapping) else {}

    @staticmethod
    def _summarize_tool_tables(tool_response: Mapping[str, Any]) -> List[Dict[str, Any]]:
        """Condense a tool response's ``results`` into per-table row counts."""
        return [
            {
                "table": item.get("table"),
                "window": item.get("window"),
                "trade_date": item.get("trade_date"),
                "row_count": len(item.get("rows") or []),
            }
            for item in tool_response.get("results") or []
            if isinstance(item, Mapping)
        ]

    @staticmethod
    def _aggregate_tokens(
        usage_records: Sequence[Mapping[str, Any]],
    ) -> Tuple[int, int, int]:
        """Sum token usage across rounds.

        Returns:
            ``(prompt, completion, total)``. A round's total is its reported
            ``total_tokens`` when present, otherwise prompt + completion, so a
            round missing the field is not dropped from the grand total.
        """

        def _safe_int(value: Any) -> int:
            try:
                return int(value)
            except (TypeError, ValueError):
                return 0

        prompt_total = completion_total = grand_total = 0
        for usage in usage_records:
            prompt = _safe_int(
                usage.get("prompt_tokens") or usage.get("prompt_tokens_total")
            )
            completion = _safe_int(
                usage.get("completion_tokens") or usage.get("completion_tokens_total")
            )
            reported = _safe_int(
                usage.get("total_tokens") or usage.get("total_tokens_total")
            )
            prompt_total += prompt
            completion_total += completion
            grand_total += reported or (prompt + completion)
        return prompt_total, completion_total, grand_total

    @staticmethod
    def _as_text_list(value: Any) -> List[str]:
        """Coerce a signals/risks field into a list of non-empty strings.

        A bare string or a non-iterable scalar becomes a one-item list; falsy
        items are dropped.
        """
        if not value:
            return []
        if isinstance(value, str) or not isinstance(value, (list, tuple, set, Mapping)):
            value = [value]
        return [str(item) for item in value if item]

    @classmethod
    def _decision_fields(
        cls, text: str
    ) -> Tuple[AgentAction, float, Any, List[str], List[str]]:
        """Parse an LLM reply into ``(action, confidence, summary, signals, risks)``.

        ``summary`` is returned as found (possibly empty) so callers can apply
        their own placeholder.
        """
        data = _parse_department_response(text)
        if not isinstance(data, Mapping):
            data = {}
        action = _normalize_action(data.get("action"))
        confidence = _clamp_float(data.get("confidence"), default=0.5)
        summary = data.get("summary") or data.get("reason") or ""
        signals = cls._as_text_list(data.get("signals") or data.get("rationale"))
        risks = cls._as_text_list(data.get("risks") or data.get("warnings"))
        return action, confidence, summary, signals, risks

    @staticmethod
    def _normalize_trade_date(value: str) -> str:
        """Strip dashes from a date string; non-strings are passed through ``str``."""
        if not isinstance(value, str):
            return str(value)
        return value.replace("-", "")

    def _fulfill_data_requests(
        self,
        context: DepartmentContext,
        requests: Sequence[TableRequest],
    ) -> Tuple[List[str], List[Dict[str, Any]], set[Tuple[str, int, str]]]:
        """Fetch rows for each request from the data broker.

        Tables outside ``ALLOWED_TABLES`` and repeats of a
        ``(table, window, trade_date)`` key within this batch only add a note.

        Returns:
            ``(lines, payload, delivered)``: human-readable notes, one payload
            entry per fetched table, and the keys that were fetched.
        """
        lines: List[str] = []
        payload: List[Dict[str, Any]] = []
        delivered: set[Tuple[str, int, str]] = set()
        ts_code = context.ts_code
        default_trade_date = self._normalize_trade_date(context.trade_date)
        max_window = self._max_window()
        for req in requests:
            table = (req.name or "").strip().lower()
            if not table:
                continue
            if table not in self.ALLOWED_TABLES:
                lines.append(f"- {table}: 不在允许的表列表中")
                continue
            trade_date = self._normalize_trade_date(req.trade_date or default_trade_date)
            window = max(1, min(req.window or 1, max_window))
            key = (table, window, trade_date)
            if key in delivered:
                lines.append(f"- {table}: 已返回窗口 {window} 的数据,跳过重复请求")
                continue
            rows = self._broker.fetch_table_rows(
                table,
                ts_code,
                trade_date,
                window,
                auto_refresh=False,  # avoid triggering automatic back-fill during backtests
            )
            scope = f"- {table} (window={window} trade_date<= {trade_date})"
            if rows:
                preview = ", ".join(f"{row.get('trade_date', 'NA')}" for row in rows[:5])
                # Keep the row count and the date preview visibly separated.
                lines.append(f"{scope}: 返回 {len(rows)} 条: {preview}")
            else:
                lines.append(f"{scope}: (数据缺失)")
            payload.append(
                {
                    "table": table,
                    "window": window,
                    "trade_date": trade_date,
                    "rows": rows,
                }
            )
            delivered.add(key)
        return lines, payload, delivered

    def _handle_tool_call(
        self,
        context: DepartmentContext,
        call: Mapping[str, Any],
        delivered_requests: set[Tuple[str, int, str]],
        round_idx: int,
    ) -> Tuple[Dict[str, Any], set[Tuple[str, int, str]]]:
        """Execute one ``fetch_data`` tool call.

        Args:
            context: Mutable department context; supplement lists on
                ``context.raw`` are extended with the fetched data.
            call: The tool call from the assistant message.
            delivered_requests: Keys already served earlier; these are skipped.
            round_idx: Zero-based round index (reported one-based).

        Returns:
            ``(tool_response, delivered)``: the JSON payload returned to the
            model and the ``(table, window, trade_date)`` keys fetched now.
        """
        function_block = call.get("function") or {}
        tool_name = function_block.get("name") or ""
        if tool_name != "fetch_data":
            LOGGER.warning(
                "部门 %s 收到未知工具调用:%s",
                self.settings.code,
                tool_name,
                extra=LOG_EXTRA,
            )
            return {
                "status": "error",
                "message": f"未知工具 {tool_name}",
            }, set()

        args = _parse_tool_arguments(function_block.get("arguments"))
        base_trade_date = self._normalize_trade_date(
            args.get("trade_date") or context.trade_date
        )
        raw_requests = args.get("tables") or []
        if not isinstance(raw_requests, list):
            # Tolerate a single table passed as a bare object or name.
            raw_requests = [raw_requests]
        max_window = self._max_window()
        requests: List[TableRequest] = []
        skipped: List[str] = []
        for item in raw_requests:
            if isinstance(item, str):
                # Models sometimes send ["daily"] instead of [{"name": "daily"}].
                item = {"name": item}
            elif not isinstance(item, Mapping):
                continue
            table = str(item.get("name") or "").strip().lower()
            if not table:
                continue
            window_raw = item.get("window")
            try:
                window = int(window_raw) if window_raw is not None else 1
            except (TypeError, ValueError):
                window = 1
            window = max(1, min(window, max_window))
            req_date = self._normalize_trade_date(item.get("trade_date") or base_trade_date)
            if (table, window, req_date) in delivered_requests:
                skipped.append(table)
                continue
            requests.append(TableRequest(name=table, window=window, trade_date=req_date))

        if not requests:
            return {
                "status": "ok",
                "round": round_idx + 1,
                "results": [],
                "skipped": skipped,
            }, set()

        lines, payload, delivered = self._fulfill_data_requests(context, requests)
        raw = context.raw
        # Rebuild the lists instead of extending them in place: ``raw`` is only
        # a shallow copy, so existing lists may be shared with other contexts.
        if payload:
            raw["supplement_data"] = [*(raw.get("supplement_data") or []), *payload]
            raw["supplement_rounds"] = [
                *(raw.get("supplement_rounds") or []),
                {
                    "round": round_idx + 1,
                    "requests": [dict(vars(req)) for req in requests],
                    "data": payload,
                    "notes": lines,
                },
            ]
        if lines:
            raw["supplement_notes"] = [
                *(raw.get("supplement_notes") or []),
                {
                    "round": round_idx + 1,
                    "lines": lines,
                },
            ]
        response_payload = {
            "status": "ok",
            "round": round_idx + 1,
            "results": payload,
            "notes": lines,
            "skipped": skipped,
        }
        return response_payload, delivered

    def _build_tools(self) -> List[Dict[str, Any]]:
        """Return the function-calling schema advertising ``fetch_data``."""
        max_window = self._max_window()
        # JSON-Schema regex for YYYYMMDD. Must be a single ``\d``: the former
        # r"^\\d{8}$" required a literal backslash in the value.
        date_pattern = r"^\d{8}$"
        return [
            {
                "type": "function",
                "function": {
                    "name": "fetch_data",
                    "description": (
                        "根据表名请求指定交易日及窗口的历史数据。当前支持的表: "
                        + ", ".join(self.ALLOWED_TABLES)
                        + "。"
                    ),
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "tables": {
                                "type": "array",
                                "items": {
                                    "type": "object",
                                    "properties": {
                                        "name": {
                                            "type": "string",
                                            # Copy so the schema never aliases the class list.
                                            "enum": list(self.ALLOWED_TABLES),
                                            "description": "表名,例如 daily 或 daily_basic",
                                        },
                                        "window": {
                                            "type": "integer",
                                            "minimum": 1,
                                            "maximum": max_window,
                                            "description": "向前回溯的记录条数,默认为 1",
                                        },
                                        "trade_date": {
                                            "type": "string",
                                            "pattern": date_pattern,
                                            "description": "覆盖默认交易日(格式 YYYYMMDD)",
                                        },
                                    },
                                    "required": ["name"],
                                },
                                "minItems": 1,
                            },
                            "trade_date": {
                                "type": "string",
                                "pattern": date_pattern,
                                "description": "默认交易日(格式 YYYYMMDD)",
                            },
                        },
                        "required": ["tables"],
                    },
                },
            }
        ]

    def _analyze_legacy(
        self,
        context: DepartmentContext,
        system_prompt: str,
    ) -> DepartmentDecision:
        """Single-shot fallback: one prompt through ``run_llm_with_config``.

        Any exception from the LLM call is logged and converted into a HOLD
        decision with zero confidence rather than propagated.
        """
        prompt = department_prompt(self.settings, context)
        llm_cfg = self._get_llm_config()
        try:
            response = run_llm_with_config(llm_cfg, prompt, system=system_prompt)
        except Exception as exc:  # noqa: BLE001 - any provider failure degrades to HOLD
            LOGGER.exception(
                "部门 %s 调用 LLM 失败:%s",
                self.settings.code,
                exc,
                extra=LOG_EXTRA,
            )
            CONV_LOGGER.error(
                "dept=%s legacy_call_failed err=%s",
                self.settings.code,
                exc,
            )
            return DepartmentDecision(
                department=self.settings.code,
                action=AgentAction.HOLD,
                confidence=0.0,
                summary=f"LLM 调用失败:{exc}",
                raw_response=str(exc),
            )
        context.raw["supplement_transcript"] = [response]
        CONV_LOGGER.info("dept=%s legacy_response=%s", self.settings.code, response)
        action, confidence, summary, signals, risks = self._decision_fields(response)
        return DepartmentDecision(
            department=self.settings.code,
            action=action,
            confidence=confidence,
            summary=summary or "未提供摘要",
            signals=signals,
            risks=risks,
            raw_response=response,
            supplements=list(context.raw.get("supplement_data") or []),
            dialogue=[response],
        )
def _ensure_mutable_context(context: DepartmentContext) -> DepartmentContext:
if not isinstance(context.features, dict):
context.features = dict(context.features or {})
if not isinstance(context.market_snapshot, dict):
context.market_snapshot = dict(context.market_snapshot or {})
raw = dict(context.raw or {})
scope_values = raw.get("scope_values")
if scope_values is not None and not isinstance(scope_values, dict):
raw["scope_values"] = dict(scope_values)
context.raw = raw
return context
def _parse_tool_arguments(payload: Any) -> Dict[str, Any]:
if isinstance(payload, dict):
return dict(payload)
if isinstance(payload, str):
try:
data = json.loads(payload)
except json.JSONDecodeError:
LOGGER.debug("工具参数解析失败:%s", payload, extra=LOG_EXTRA)
return {}
if isinstance(data, dict):
return data
return {}
def _message_to_text(message: Mapping[str, Any]) -> str:
content = message.get("content")
if isinstance(content, list):
parts: List[str] = []
for item in content:
if isinstance(item, Mapping) and "text" in item:
parts.append(str(item.get("text", "")))
else:
parts.append(str(item))
if parts:
return "".join(parts)
elif isinstance(content, str) and content.strip():
return content
tool_calls = message.get("tool_calls")
if tool_calls:
return json.dumps({"tool_calls": tool_calls}, ensure_ascii=False)
return ""
def _extract_message_content(message: Mapping[str, Any]) -> str:
content = message.get("content")
if isinstance(content, list):
texts = [
str(item.get("text", ""))
for item in content
if isinstance(item, Mapping) and "text" in item
]
if texts:
return "".join(texts)
if isinstance(content, str):
return content
return json.dumps(message, ensure_ascii=False)
class DepartmentManager:
    """Run every department configured in ``AppConfig`` on a shared context."""

    def __init__(self, config: AppConfig) -> None:
        """Create one ``DepartmentAgent`` per entry of ``config.departments``."""
        self.config = config
        self.agents: Dict[str, DepartmentAgent] = {}
        for code, dept_settings in config.departments.items():
            self.agents[code] = DepartmentAgent(dept_settings, self._resolve_llm)

    def evaluate(self, context: DepartmentContext) -> Dict[str, DepartmentDecision]:
        """Return each department's decision, keyed by department code.

        Every agent gets its own context copy: top-level dicts (and
        ``raw["scope_values"]``) are copied, nested objects are still shared.
        """
        return {
            code: agent.analyze(self._isolated_context(context))
            for code, agent in self.agents.items()
        }

    @staticmethod
    def _isolated_context(context: DepartmentContext) -> DepartmentContext:
        """Build a per-department shallow copy of ``context``."""
        raw_copy = dict(context.raw or {})
        if "scope_values" in raw_copy:
            raw_copy["scope_values"] = dict(raw_copy.get("scope_values") or {})
        return DepartmentContext(
            ts_code=context.ts_code,
            trade_date=context.trade_date,
            features=dict(context.features or {}),
            market_snapshot=dict(context.market_snapshot or {}),
            raw=raw_copy,
        )

    def _resolve_llm(self, settings: DepartmentSettings) -> LLMConfig:
        """Resolver handed to each agent; currently just ``settings.llm``."""
        return settings.llm
def _parse_department_response(text: str) -> Dict[str, Any]:
"""Extract a JSON object from the LLM response if possible."""
cleaned = text.strip()
candidate = None
if cleaned.startswith("{") and cleaned.endswith("}"):
candidate = cleaned
else:
start = cleaned.find("{")
end = cleaned.rfind("}")
if start != -1 and end != -1 and end > start:
candidate = cleaned[start : end + 1]
if candidate:
try:
return json.loads(candidate)
except json.JSONDecodeError:
LOGGER.debug("部门响应 JSON 解析失败,返回原始文本", extra=LOG_EXTRA)
return {"summary": cleaned}
def _normalize_action(value: Any) -> AgentAction:
    """Map a free-form LLM action label onto an ``AgentAction``.

    The label is trimmed and upper-cased, and hyphens and whitespace runs are
    folded into ``_`` (so ``"buy-l"`` or ``"Buy L"`` match ``BUY_L`` instead
    of being downgraded to ``BUY_M``). Exact labels map directly. Otherwise a
    label containing SELL maps to SELL, then one containing BUY maps to BUY_M.
    Everything else, including non-strings, is HOLD.
    """
    if not isinstance(value, str):
        return AgentAction.HOLD
    token = "_".join(value.replace("-", " ").split()).upper()
    exact = {
        "BUY": AgentAction.BUY_M,
        "BUY_S": AgentAction.BUY_S,
        "BUY_M": AgentAction.BUY_M,
        "BUY_L": AgentAction.BUY_L,
        "SELL": AgentAction.SELL,
        "HOLD": AgentAction.HOLD,
    }
    if token in exact:
        return exact[token]
    # SELL is checked first so labels mentioning both lean conservative.
    if "SELL" in token:
        return AgentAction.SELL
    if "BUY" in token:
        return AgentAction.BUY_M
    return AgentAction.HOLD
def _clamp_float(value: Any, default: float = 0.5) -> float:
try:
num = float(value)
except (TypeError, ValueError):
return default
return max(0.0, min(1.0, num))