168 lines
5.6 KiB
Python
168 lines
5.6 KiB
Python
"""Department-level LLM agents coordinating multi-model decisions."""
|
||
from __future__ import annotations

import json
import math
from dataclasses import dataclass, field
from typing import Any, Dict, List, Mapping

from app.agents.base import AgentAction
from app.llm.client import run_llm_with_config
from app.llm.prompts import department_prompt
from app.utils.config import DepartmentSettings
from app.utils.logging import get_logger

# Module-level logger; every log call in this module passes LOG_EXTRA as ``extra=``
# so records can be attributed to the "department" stage.
LOGGER = get_logger(__name__)
LOG_EXTRA = dict(stage="department")


@dataclass
class DepartmentContext:
    """Inputs handed to a department agent for one security on one trade date.

    Only ``ts_code`` and ``trade_date`` are required; each mapping field
    defaults to a fresh, per-instance empty ``dict``.
    """

    # Security identifier (presumably a Tushare-style code such as "000001.SZ" — confirm with callers).
    ts_code: str
    # Trade date as a string; its format is decided by the caller.
    trade_date: str
    # Per-security feature values supplied to the department.
    features: Mapping[str, Any] = field(default_factory=dict)
    # Market-level data supplied alongside the features.
    market_snapshot: Mapping[str, Any] = field(default_factory=dict)
    # Any additional unstructured payload.
    raw: Mapping[str, Any] = field(default_factory=dict)


@dataclass
class DepartmentDecision:
    """Result produced by a department agent."""

    # Department code that produced this decision.
    department: str
    # Trading action chosen by the department.
    action: AgentAction
    # Confidence score; DepartmentAgent in this module clamps it to [0, 1].
    confidence: float
    # Short human-readable explanation of the decision.
    summary: str
    # Unparsed LLM reply, or the error text when the LLM call failed.
    raw_response: str
    # Supporting signals cited by the department.
    signals: List[str] = field(default_factory=list)
    # Risks or warnings cited by the department.
    risks: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict snapshot of this decision.

        ``action`` is reduced to ``action.value``. The ``signals`` and
        ``risks`` lists are shallow-copied so that mutating the returned dict
        cannot alter this decision.

        Returns:
            Dict with keys ``department``, ``action``, ``confidence``,
            ``summary``, ``signals``, ``risks`` and ``raw_response``.
        """
        return {
            "department": self.department,
            "action": self.action.value,
            "confidence": self.confidence,
            "summary": self.summary,
            # Copies: the previous version leaked the dataclass's own lists.
            "signals": list(self.signals),
            "risks": list(self.risks),
            "raw_response": self.raw_response,
        }


class DepartmentAgent:
    """Wraps the LLM call and response parsing for a single analytical department.

    The agent renders a prompt from its settings and a ``DepartmentContext``,
    sends it to the configured LLM, and turns the (possibly free-form) reply
    into a ``DepartmentDecision``.
    """

    def __init__(self, settings: DepartmentSettings) -> None:
        """Store the department configuration.

        Args:
            settings: Department configuration; ``settings.code`` labels the
                decision and ``settings.llm`` is passed to ``run_llm_with_config``.
        """
        self.settings = settings

    def analyze(self, context: DepartmentContext) -> DepartmentDecision:
        """Ask the LLM for this department's decision on ``context``.

        Args:
            context: Structured inputs for the security/date being evaluated.

        Returns:
            The parsed decision. If the LLM call raises, the error is logged
            and a ``HOLD`` decision with confidence ``0.0`` is returned instead.
        """
        prompt = department_prompt(self.settings, context)
        system_prompt = (
            "你是一个多智能体量化投研系统中的分部决策者,需要根据提供的结构化信息给出买卖意见。"
        )
        try:
            response = run_llm_with_config(self.settings.llm, prompt, system=system_prompt)
        except Exception as exc:  # noqa: BLE001
            # Degrade to a neutral HOLD so one failing department does not
            # abort the caller.
            LOGGER.exception("部门 %s 调用 LLM 失败:%s", self.settings.code, exc, extra=LOG_EXTRA)
            return DepartmentDecision(
                department=self.settings.code,
                action=AgentAction.HOLD,
                confidence=0.0,
                summary=f"LLM 调用失败:{exc}",
                raw_response=str(exc),
            )

        decision_data = _parse_department_response(response)
        action = _normalize_action(decision_data.get("action"))
        confidence = _clamp_float(decision_data.get("confidence"), default=0.5)

        # Accept alternative key names the model may use for each field.
        summary = decision_data.get("summary") or decision_data.get("reason") or ""
        if not isinstance(summary, str):
            # Model output is untrusted: a nested object/list would otherwise
            # end up in a field typed ``str``.
            summary = str(summary)
        signals = self._coerce_str_list(
            decision_data.get("signals") or decision_data.get("rationale")
        )
        risks = self._coerce_str_list(
            decision_data.get("risks") or decision_data.get("warnings")
        )

        decision = DepartmentDecision(
            department=self.settings.code,
            action=action,
            confidence=confidence,
            summary=summary or "未提供摘要",
            signals=signals,
            risks=risks,
            raw_response=response,
        )
        LOGGER.debug(
            "部门 %s 决策:action=%s confidence=%.2f",
            self.settings.code,
            decision.action.value,
            decision.confidence,
            extra=LOG_EXTRA,
        )
        return decision

    @staticmethod
    def _coerce_str_list(value: Any) -> List[str]:
        """Normalise an LLM-supplied ``signals``/``risks`` value to ``List[str]``.

        Falsy values give ``[]``; a single string becomes a one-item list;
        lists, tuples, sets and dicts (their keys) are iterated with falsy
        items dropped and the rest converted with ``str``; any other truthy
        scalar (e.g. a number) becomes a one-item list rather than raising
        ``TypeError`` when iterated.
        """
        if not value:
            return []
        if isinstance(value, str):
            return [value]
        if isinstance(value, (list, tuple, set, frozenset, dict)):
            return [str(item) for item in value if item]
        return [str(value)]


class DepartmentManager:
    """Run every configured department against the same context."""

    def __init__(self, departments: Mapping[str, DepartmentSettings]) -> None:
        """Create one ``DepartmentAgent`` per entry in ``departments``.

        Args:
            departments: Department settings keyed by department code; the
                same keys index the result of ``evaluate``.
        """
        agents: Dict[str, DepartmentAgent] = {}
        for dept_code, dept_settings in departments.items():
            agents[dept_code] = DepartmentAgent(dept_settings)
        self.agents = agents

    def evaluate(self, context: DepartmentContext) -> Dict[str, DepartmentDecision]:
        """Collect every department's decision for ``context``.

        Agents are queried one after another in ``self.agents`` order; an
        exception raised by an agent propagates to the caller.

        Returns:
            Mapping of department code to that department's decision.
        """
        return {
            dept_code: agent.analyze(context)
            for dept_code, agent in self.agents.items()
        }


def _parse_department_response(text: str) -> Dict[str, Any]:
|
||
"""Extract a JSON object from the LLM response if possible."""
|
||
|
||
cleaned = text.strip()
|
||
candidate = None
|
||
if cleaned.startswith("{") and cleaned.endswith("}"):
|
||
candidate = cleaned
|
||
else:
|
||
start = cleaned.find("{")
|
||
end = cleaned.rfind("}")
|
||
if start != -1 and end != -1 and end > start:
|
||
candidate = cleaned[start : end + 1]
|
||
if candidate:
|
||
try:
|
||
return json.loads(candidate)
|
||
except json.JSONDecodeError:
|
||
LOGGER.debug("部门响应 JSON 解析失败,返回原始文本", extra=LOG_EXTRA)
|
||
return {"summary": cleaned}
|
||
|
||
|
||
def _normalize_action(value: Any) -> AgentAction:
    """Map a free-form action value from the LLM onto an ``AgentAction``.

    Strings are upper-cased and any run of ``-``/whitespace is collapsed to
    ``_``, so ``"buy-s"`` and ``"Buy S"`` resolve to ``BUY_S`` just like
    ``"BUY_S"``. Exact aliases are tried first; otherwise text containing
    ``SELL`` maps to ``SELL`` and text containing ``BUY`` maps to ``BUY_M``
    (checked in that order). An ``AgentAction`` member is returned unchanged.
    Everything else, including other non-string values, falls back to
    ``HOLD``.
    """
    if isinstance(value, AgentAction):
        return value
    if not isinstance(value, str):
        return AgentAction.HOLD

    # Normalise separators so "BUY-L" / "buy l" hit the exact aliases instead
    # of falling through to the generic BUY -> BUY_M substring rule.
    token = "_".join(value.upper().replace("-", " ").split())
    aliases = {
        "BUY": AgentAction.BUY_M,
        "BUY_S": AgentAction.BUY_S,
        "BUY_M": AgentAction.BUY_M,
        "BUY_L": AgentAction.BUY_L,
        "SELL": AgentAction.SELL,
        "HOLD": AgentAction.HOLD,
    }
    if token in aliases:
        return aliases[token]
    if "SELL" in token:
        return AgentAction.SELL
    if "BUY" in token:
        return AgentAction.BUY_M
    return AgentAction.HOLD


def _clamp_float(value: Any, default: float = 0.5) -> float:
|
||
try:
|
||
num = float(value)
|
||
except (TypeError, ValueError):
|
||
return default
|
||
return max(0.0, min(1.0, num))
|