llm-quant/app/agents/departments.py
2025-09-28 11:28:55 +08:00

183 lines
6.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Department-level LLM agents coordinating multi-model decisions."""
from __future__ import annotations

import json
import math
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Mapping, Optional

from app.agents.base import AgentAction
from app.llm.client import run_llm_with_config
from app.llm.prompts import department_prompt
from app.utils.config import AppConfig, DepartmentSettings, LLMConfig
from app.utils.logging import get_logger
# Module logger; the department code paths log through it with LOG_EXTRA.
LOGGER = get_logger(__name__)
# Extra fields passed as ``extra=`` on this module's log calls.
LOG_EXTRA = dict(stage="department")
@dataclass
class DepartmentContext:
    """Bundle of inputs a department agent analyzes for one instrument and date.

    Only ``ts_code`` and ``trade_date`` are required. The three mapping fields
    each default to a fresh, per-instance empty dict.
    """

    # Instrument code (presumably a Tushare-style code such as "600000.SH"; confirm with callers).
    ts_code: str
    # Trade date under evaluation; the string format is chosen by the caller.
    trade_date: str
    # Named feature values made available to the department prompt.
    features: Mapping[str, Any] = field(default_factory=dict)
    # Market data captured for this instrument/date.
    market_snapshot: Mapping[str, Any] = field(default_factory=dict)
    # Any additional unstructured payload supplied by the caller.
    raw: Mapping[str, Any] = field(default_factory=dict)
@dataclass
class DepartmentDecision:
    """Outcome produced by a single department agent for one context."""

    # Code of the department that produced this decision.
    department: str
    # Recommended action; serialized via its ``.value`` in ``to_dict``.
    action: AgentAction
    # Confidence score (the agent clamps parsed values into [0.0, 1.0]).
    confidence: float
    # Human-readable summary of the decision.
    summary: str
    # Raw text returned by the LLM (or the error text when the call failed).
    raw_response: str
    # Supporting signals cited by the department.
    signals: List[str] = field(default_factory=list)
    # Risks or warnings cited by the department.
    risks: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the decision into a plain dict.

        ``action`` is emitted as ``action.value``. ``signals`` and ``risks`` are
        shallow-copied so that mutating the returned dict (e.g. appending to its
        lists) cannot alter this decision instance.

        Returns:
            A dict with keys ``department``, ``action``, ``confidence``,
            ``summary``, ``signals``, ``risks`` and ``raw_response``.
        """
        return {
            "department": self.department,
            "action": self.action.value,
            "confidence": self.confidence,
            "summary": self.summary,
            "signals": list(self.signals),
            "risks": list(self.risks),
            "raw_response": self.raw_response,
        }
class DepartmentAgent:
    """Run the LLM call and response parsing for a single analytical department.

    Each :meth:`analyze` call builds a prompt from the department settings and
    the supplied context, queries the configured LLM once, and converts the
    reply into a :class:`DepartmentDecision`.
    """

    def __init__(
        self,
        settings: DepartmentSettings,
        resolver: Optional[Callable[[DepartmentSettings], LLMConfig]] = None,
    ) -> None:
        """Store the department settings and an optional LLM-config resolver.

        Args:
            settings: Department configuration. ``settings.code`` identifies the
                department in decisions/logs; ``settings.llm`` is the fallback
                LLM configuration.
            resolver: Optional callable mapping ``settings`` to the LLM config to
                use. When not given, ``settings.llm`` is used directly.
        """
        self.settings = settings
        self._resolver = resolver

    def _get_llm_config(self) -> LLMConfig:
        """Return the resolver's config when a resolver is set, else ``settings.llm``."""
        if self._resolver:
            return self._resolver(self.settings)
        return self.settings.llm

    @staticmethod
    def _coerce_text_list(value: Any) -> List[str]:
        """Normalize an LLM-provided field into a list of non-empty strings.

        - falsy values (``None``, ``""``, ``[]``, ``0``) -> ``[]``
        - a string -> a one-element list
        - any other iterable (list, tuple, dict keys, ...) -> its truthy items
        - a non-iterable scalar (e.g. a number) -> a one-element list; this
          case previously raised ``TypeError`` when iterated.
        """
        if not value:
            return []
        if isinstance(value, str):
            return [value]
        try:
            items = list(value)
        except TypeError:
            items = [value]
        return [str(item) for item in items if item]

    def analyze(self, context: DepartmentContext) -> DepartmentDecision:
        """Ask the department's LLM for a decision on ``context``.

        Args:
            context: Instrument/date inputs rendered into the prompt by
                ``department_prompt``.

        Returns:
            The parsed decision. If the LLM call itself raises, a ``HOLD``
            decision with confidence ``0.0`` is returned carrying the error text
            in ``summary`` and ``raw_response``. Errors raised while building the
            prompt are not caught here.
        """
        prompt = department_prompt(self.settings, context)
        system_prompt = (
            "你是一个多智能体量化投研系统中的分部决策者,需要根据提供的结构化信息给出买卖意见。"
        )
        llm_cfg = self._get_llm_config()
        try:
            response = run_llm_with_config(llm_cfg, prompt, system=system_prompt)
        except Exception as exc:  # noqa: BLE001
            LOGGER.exception("部门 %s 调用 LLM 失败:%s", self.settings.code, exc, extra=LOG_EXTRA)
            return DepartmentDecision(
                department=self.settings.code,
                action=AgentAction.HOLD,
                confidence=0.0,
                summary=f"LLM 调用失败:{exc}",
                raw_response=str(exc),
            )

        decision_data = _parse_department_response(response)
        action = _normalize_action(decision_data.get("action"))
        confidence = _clamp_float(decision_data.get("confidence"), default=0.5)

        # "reason"/"rationale"/"warnings" are accepted as alternate key names.
        summary = decision_data.get("summary") or decision_data.get("reason") or ""
        if not isinstance(summary, str):
            # The LLM may return a list/dict here; keep the field a str.
            summary = str(summary)
        signals = self._coerce_text_list(
            decision_data.get("signals") or decision_data.get("rationale")
        )
        risks = self._coerce_text_list(
            decision_data.get("risks") or decision_data.get("warnings")
        )

        decision = DepartmentDecision(
            department=self.settings.code,
            action=action,
            confidence=confidence,
            summary=summary or "未提供摘要",
            signals=signals,
            risks=risks,
            raw_response=response,
        )
        LOGGER.debug(
            "部门 %s 决策action=%s confidence=%.2f",
            self.settings.code,
            decision.action.value,
            decision.confidence,
            extra=LOG_EXTRA,
        )
        return decision
class DepartmentManager:
    """Own one DepartmentAgent per configured department and run them all."""

    def __init__(self, config: AppConfig) -> None:
        """Build an agent for every entry in ``config.departments``.

        Each agent receives this manager's :meth:`_resolve_llm` as its
        LLM-config resolver.
        """
        self.config = config
        agents: Dict[str, DepartmentAgent] = {}
        for dept_code, dept_settings in config.departments.items():
            agents[dept_code] = DepartmentAgent(dept_settings, self._resolve_llm)
        self.agents = agents

    def evaluate(self, context: DepartmentContext) -> Dict[str, DepartmentDecision]:
        """Run every agent on ``context`` sequentially.

        Returns:
            Decisions keyed by department code, in ``self.agents`` iteration
            order. An exception raised by an agent propagates and stops the
            remaining evaluations.
        """
        return {
            dept_code: dept_agent.analyze(context)
            for dept_code, dept_agent in self.agents.items()
        }

    def _resolve_llm(self, settings: DepartmentSettings) -> LLMConfig:
        """Return the LLM config for a department; currently just ``settings.llm``."""
        llm_config = settings.llm
        return llm_config
def _parse_department_response(text: str) -> Dict[str, Any]:
    """Extract the first JSON object embedded in an LLM response.

    The reply may be bare JSON, JSON inside Markdown fences, or JSON surrounded
    by prose. Each ``{`` is tried in order as the start of a JSON object and the
    first one that decodes successfully is returned. Stray braces before or
    after the payload therefore no longer cause a valid object to be discarded,
    as happened with a first-``{``-to-last-``}`` slice.

    Args:
        text: Raw LLM response. ``None`` is treated as an empty string.

    Returns:
        The decoded JSON object, or ``{"summary": <stripped text>}`` when no
        JSON object can be decoded.
    """
    cleaned = (text or "").strip()
    decoder = json.JSONDecoder()
    start = cleaned.find("{")
    attempted = start != -1
    while start != -1:
        try:
            # raw_decode parses one JSON value starting at `start` and ignores
            # whatever follows it, so trailing prose/braces are harmless.
            parsed, _end = decoder.raw_decode(cleaned, start)
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, dict):
            return parsed
        start = cleaned.find("{", start + 1)
    if attempted:
        LOGGER.debug("部门响应 JSON 解析失败,返回原始文本", extra=LOG_EXTRA)
    return {"summary": cleaned}
def _normalize_action(value: Any) -> AgentAction:
    """Map a free-form action label from the LLM onto an ``AgentAction``.

    Matching is case-insensitive and ignores surrounding whitespace. Runs of
    spaces and/or hyphens inside the label are collapsed to a single underscore
    before lookup. As a result ``"buy-l"`` or ``"BUY L"`` resolve to ``BUY_L``
    instead of falling through to the generic ``BUY`` substring rule.

    Resolution order:
        1. exact match on a known label (a bare ``BUY`` maps to ``BUY_M``);
        2. any label containing ``SELL`` -> ``SELL``;
        3. any label containing ``BUY`` -> ``BUY_M``;
        4. everything else, including non-string values -> ``HOLD``.
    """
    if not isinstance(value, str):
        return AgentAction.HOLD
    # split() also strips leading/trailing whitespace.
    label = "_".join(value.upper().replace("-", " ").split())
    mapping = {
        "BUY": AgentAction.BUY_M,
        "BUY_S": AgentAction.BUY_S,
        "BUY_M": AgentAction.BUY_M,
        "BUY_L": AgentAction.BUY_L,
        "SELL": AgentAction.SELL,
        "HOLD": AgentAction.HOLD,
    }
    if label in mapping:
        return mapping[label]
    # SELL is checked first so labels mentioning both resolve to SELL.
    if "SELL" in label:
        return AgentAction.SELL
    if "BUY" in label:
        return AgentAction.BUY_M
    return AgentAction.HOLD
def _clamp_float(value: Any, default: float = 0.5) -> float:
    """Coerce ``value`` to a float clamped into ``[0.0, 1.0]``.

    Args:
        value: Anything ``float()`` accepts (numbers, numeric strings, ...).
        default: Returned when ``value`` cannot be converted to a float
            (``TypeError``/``ValueError``/``OverflowError``) or is NaN.

    Returns:
        ``value`` clamped to ``[0.0, 1.0]``; infinities clamp to the bounds.
    """
    try:
        num = float(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: float() of an int too large for a double.
        return default
    if math.isnan(num):
        # Comparisons with NaN are always False, so min/max would otherwise
        # turn NaN into 1.0, i.e. maximum confidence.
        return default
    return max(0.0, min(1.0, num))