381 lines
13 KiB
Python
381 lines
13 KiB
Python
"""Department-level LLM agents coordinating multi-model decisions."""
|
||
from __future__ import annotations

import json
import math
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple

from app.agents.base import AgentAction
from app.llm.client import run_llm_with_config
from app.llm.prompts import department_prompt
from app.utils.config import AppConfig, DepartmentSettings, LLMConfig
from app.utils.data_access import DataBroker, parse_field_path
from app.utils.logging import get_logger
|
||
|
||
# Module-wide logger used by every agent and helper below.
LOGGER = get_logger(__name__)
# Passed as ``extra=`` on this module's log calls to tag records with the stage.
LOG_EXTRA = dict(stage="department")
|
||
|
||
|
||
@dataclass
class DataRequest:
    """One extra-data request emitted by a department LLM.

    A request names a data field and how many points of history are wanted.
    """

    # Field identifier as written by the model (stripped by the parser).
    field: str
    # Number of points requested; values <= 1 are treated as "latest value only".
    window: int = 1
|
||
|
||
|
||
@dataclass
class DepartmentContext:
    """Structured data fed into a department for decision making.

    Only ``ts_code`` and ``trade_date`` are required. Each mapping field
    defaults to a fresh, per-instance empty dict.
    """

    # Security identifier for the decision.
    ts_code: str
    # Trade date string; DepartmentAgent removes "-" characters before data lookups.
    trade_date: str
    # Precomputed features for the security.
    features: Mapping[str, Any] = field(default_factory=dict)
    # Market-wide snapshot values.
    market_snapshot: Mapping[str, Any] = field(default_factory=dict)
    # Free-form extras; this module reads/writes keys such as "scope_values",
    # "supplement_data", "supplement_rounds" and "supplement_transcript".
    raw: Mapping[str, Any] = field(default_factory=dict)
|
||
|
||
|
||
@dataclass
class DepartmentDecision:
    """Result produced by a department agent.

    Holds the final action/confidence pair plus the supporting text, the
    supplementary data fetched during the dialogue and the raw LLM turns.
    """

    department: str
    action: AgentAction
    confidence: float
    summary: str
    raw_response: str
    signals: List[str] = field(default_factory=list)
    risks: List[str] = field(default_factory=list)
    supplements: List[Dict[str, Any]] = field(default_factory=list)
    dialogue: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view of the decision.

        ``action`` is flattened to its ``.value``. The list fields are returned
        as the same list objects held by this instance, not copies.
        """
        return dict(
            department=self.department,
            action=self.action.value,
            confidence=self.confidence,
            summary=self.summary,
            signals=self.signals,
            risks=self.risks,
            raw_response=self.raw_response,
            supplements=self.supplements,
            dialogue=self.dialogue,
        )
|
||
|
||
|
||
class DepartmentAgent:
    """Wraps LLM ensemble logic for a single analytical department.

    The agent runs a bounded multi-round dialogue with its LLM. In each round
    the model may list ``data_requests``; any new ones are fetched through a
    ``DataBroker`` and fed back as supplement text in the next prompt.
    """

    # Maximum number of points per series shown in the supplement preview text.
    _SERIES_PREVIEW_LIMIT = 5

    def __init__(
        self,
        settings: DepartmentSettings,
        resolver: Optional[Callable[[DepartmentSettings], LLMConfig]] = None,
    ) -> None:
        """Create an agent for one department.

        Args:
            settings: Department configuration (code, LLM config, prompt inputs).
            resolver: Optional hook mapping ``settings`` to the LLM config to
                use; when absent, ``settings.llm`` is used directly.
        """
        self.settings = settings
        self._resolver = resolver
        self._broker = DataBroker()
        self._max_rounds = 3

    def _get_llm_config(self) -> LLMConfig:
        """Return the LLM config, preferring the injected resolver if set."""
        if self._resolver:
            return self._resolver(self.settings)
        return self.settings.llm

    def analyze(self, context: DepartmentContext) -> DepartmentDecision:
        """Run the LLM dialogue for ``context`` and return the department's decision.

        Each round renders ``department_prompt`` with the supplement text
        gathered so far and calls the LLM. If the parsed reply asks for data
        that has not been delivered yet, it is fetched and appended as a
        supplement for the next round. The loop stops when the reply asks for
        nothing new, a fetch yields no payload, or ``_max_rounds`` is reached.

        Side effects: ``context`` is normalised in place by
        ``_ensure_mutable_context`` (its ``raw`` becomes a fresh dict), and
        ``supplement_data``, ``supplement_rounds`` and ``supplement_transcript``
        are recorded in ``context.raw``.

        Args:
            context: Inputs for this department's decision.

        Returns:
            The parsed decision. If the LLM call raises, a HOLD decision with
            confidence 0.0 carrying the error text is returned instead.
        """
        mutable_context = _ensure_mutable_context(context)
        system_prompt = (
            "你是一个多智能体量化投研系统中的分部决策者,需要根据提供的结构化信息给出买卖意见。"
        )
        llm_cfg = self._get_llm_config()
        supplement_chunks: List[str] = []
        transcript: List[str] = []
        # Fields already present in scope_values count as delivered at window=1,
        # so the model cannot trigger a redundant fetch for them.
        delivered_requests = {
            (field, 1)
            for field in (mutable_context.raw.get("scope_values") or {}).keys()
        }

        response = ""
        decision_data: Dict[str, Any] = {}
        for round_idx in range(self._max_rounds):
            supplement_text = "\n\n".join(chunk for chunk in supplement_chunks if chunk)
            prompt = department_prompt(self.settings, mutable_context, supplements=supplement_text)
            try:
                response = run_llm_with_config(llm_cfg, prompt, system=system_prompt)
            except Exception as exc:  # noqa: BLE001
                LOGGER.exception(
                    "部门 %s 调用 LLM 失败:%s",
                    self.settings.code,
                    exc,
                    extra=LOG_EXTRA,
                )
                return DepartmentDecision(
                    department=self.settings.code,
                    action=AgentAction.HOLD,
                    confidence=0.0,
                    summary=f"LLM 调用失败:{exc}",
                    raw_response=str(exc),
                )

            transcript.append(response)
            decision_data = _parse_department_response(response)
            data_requests = _parse_data_requests(decision_data)
            filtered_requests = [
                req
                for req in data_requests
                if (req.field, req.window) not in delivered_requests
            ]
            # data_requests is dialogue control metadata, not part of the decision.
            decision_data.pop("data_requests", None)

            if not filtered_requests:
                # Nothing new was asked for: this reply is the final answer.
                break

            if round_idx >= self._max_rounds - 1:
                # No round left to deliver the data in; keep the current reply.
                LOGGER.debug(
                    "部门 %s 已达最大轮次,忽略数据请求:%s",
                    self.settings.code,
                    filtered_requests,
                    extra=LOG_EXTRA,
                )
                break

            lines, payload, delivered = self._fulfill_data_requests(
                mutable_context, filtered_requests
            )
            if not payload:
                LOGGER.debug(
                    "部门 %s 数据请求无结果:%s",
                    self.settings.code,
                    filtered_requests,
                    extra=LOG_EXTRA,
                )
                break

            supplement_chunks.append(
                f"回合 {round_idx + 1} 追加数据:\n" + "\n".join(lines)
            )
            mutable_context.raw.setdefault("supplement_data", []).extend(payload)
            mutable_context.raw.setdefault("supplement_rounds", []).append(
                {
                    "round": round_idx + 1,
                    "requests": [req.__dict__ for req in filtered_requests],
                    "data": payload,
                }
            )
            delivered_requests.update(delivered)

        mutable_context.raw["supplement_transcript"] = list(transcript)

        action = _normalize_action(decision_data.get("action"))
        confidence = _clamp_float(decision_data.get("confidence"), default=0.5)
        summary = decision_data.get("summary") or decision_data.get("reason") or ""
        signals = decision_data.get("signals") or decision_data.get("rationale") or []
        if isinstance(signals, str):
            signals = [signals]
        risks = decision_data.get("risks") or decision_data.get("warnings") or []
        if isinstance(risks, str):
            risks = [risks]

        decision = DepartmentDecision(
            department=self.settings.code,
            action=action,
            confidence=confidence,
            summary=summary or "未提供摘要",
            signals=[str(sig) for sig in signals if sig],
            risks=[str(risk) for risk in risks if risk],
            raw_response="\n\n".join(transcript) if transcript else response,
            supplements=list(mutable_context.raw.get("supplement_data", [])),
            dialogue=list(transcript),
        )
        LOGGER.debug(
            "部门 %s 决策:action=%s confidence=%.2f",
            self.settings.code,
            decision.action.value,
            decision.confidence,
            extra=LOG_EXTRA,
        )
        return decision

    @staticmethod
    def _normalize_trade_date(value: str) -> str:
        """Return ``value`` with "-" removed; non-strings are just ``str()``-ed."""
        if not isinstance(value, str):
            return str(value)
        return value.replace("-", "")

    @staticmethod
    def _format_series_value(value: Any) -> str:
        """Render one series point with 4 decimals, falling back to ``str()``.

        The broker's value type is not visible here; ``None`` or text values
        would make a bare ``:.4f`` format raise and abort the whole analysis.
        """
        try:
            return format(value, ".4f")
        except (TypeError, ValueError):
            return str(value)

    def _fulfill_data_requests(
        self,
        context: DepartmentContext,
        requests: Sequence[DataRequest],
    ) -> Tuple[List[str], List[Dict[str, Any]], set[Tuple[str, int]]]:
        """Fetch requested data through the broker and render it for the prompt.

        Requests with ``window <= 1`` are batched into a single
        ``fetch_latest`` call. Larger windows are resolved with
        ``parse_field_path`` and fetched one at a time via ``fetch_series``.

        Returns:
            ``(lines, payload, delivered)``:
            ``lines`` are bullet lines for the next prompt. ``payload`` holds
            one dict per fetched request (``values`` may be ``None`` or empty
            when data is missing). ``delivered`` holds the ``(field, window)``
            keys treated as served, including ones whose data came back
            missing. Fields rejected by ``parse_field_path`` appear only in
            ``lines`` and are not marked delivered.
        """
        lines: List[str] = []
        payload: List[Dict[str, Any]] = []
        delivered: set[Tuple[str, int]] = set()

        ts_code = context.ts_code
        trade_date = self._normalize_trade_date(context.trade_date)

        latest_fields: List[str] = []
        series_requests: List[Tuple[DataRequest, Tuple[str, str]]] = []

        for req in requests:
            field = req.field.strip()
            if not field:
                continue
            if req.window <= 1:
                if field not in latest_fields:
                    latest_fields.append(field)
                delivered.add((field, 1))
                continue
            parsed = parse_field_path(field)
            if not parsed:
                lines.append(f"- {field}: 字段不合法,已忽略")
                continue
            series_requests.append((req, parsed))
            delivered.add((field, req.window))

        if latest_fields:
            latest_values = self._broker.fetch_latest(ts_code, trade_date, latest_fields)
            for field in latest_fields:
                value = latest_values.get(field)
                if value is None:
                    lines.append(f"- {field}: (数据缺失)")
                else:
                    lines.append(f"- {field}: {value}")
                payload.append({"field": field, "window": 1, "values": value})

        for req, parsed in series_requests:
            table, column = parsed
            series = self._broker.fetch_series(
                table,
                column,
                ts_code,
                trade_date,
                window=req.window,
            )
            if series:
                preview = ", ".join(
                    f"{dt}:{self._format_series_value(val)}"
                    for dt, val in series[: self._SERIES_PREVIEW_LIMIT]
                )
                lines.append(
                    f"- {req.field} (window={req.window}): {preview}"
                )
            else:
                lines.append(
                    f"- {req.field} (window={req.window}): (数据缺失)"
                )
            payload.append({"field": req.field, "window": req.window, "values": series})

        return lines, payload, delivered
|
||
|
||
|
||
def _ensure_mutable_context(context: DepartmentContext) -> DepartmentContext:
|
||
if not isinstance(context.features, dict):
|
||
context.features = dict(context.features or {})
|
||
if not isinstance(context.market_snapshot, dict):
|
||
context.market_snapshot = dict(context.market_snapshot or {})
|
||
raw = dict(context.raw or {})
|
||
scope_values = raw.get("scope_values")
|
||
if scope_values is not None and not isinstance(scope_values, dict):
|
||
raw["scope_values"] = dict(scope_values)
|
||
context.raw = raw
|
||
return context
|
||
|
||
|
||
def _parse_data_requests(payload: Mapping[str, Any]) -> List[DataRequest]:
|
||
raw_requests = payload.get("data_requests")
|
||
requests: List[DataRequest] = []
|
||
if not isinstance(raw_requests, list):
|
||
return requests
|
||
seen: set[Tuple[str, int]] = set()
|
||
for item in raw_requests:
|
||
field = ""
|
||
window = 1
|
||
if isinstance(item, str):
|
||
field = item.strip()
|
||
elif isinstance(item, Mapping):
|
||
candidate = item.get("field")
|
||
if candidate is None:
|
||
continue
|
||
field = str(candidate).strip()
|
||
try:
|
||
window = int(item.get("window", 1))
|
||
except (TypeError, ValueError):
|
||
window = 1
|
||
else:
|
||
continue
|
||
if not field:
|
||
continue
|
||
window = max(1, window)
|
||
key = (field, window)
|
||
if key in seen:
|
||
continue
|
||
seen.add(key)
|
||
requests.append(DataRequest(field=field, window=window))
|
||
return requests
|
||
|
||
|
||
class DepartmentManager:
    """Orchestrates all departments defined in configuration.

    One ``DepartmentAgent`` is built per entry of ``config.departments``. Each
    ``evaluate`` call hands every agent its own copy of the incoming context.
    """

    def __init__(self, config: AppConfig) -> None:
        self.config = config
        self.agents: Dict[str, DepartmentAgent] = {
            dept_code: DepartmentAgent(dept_settings, self._resolve_llm)
            for dept_code, dept_settings in config.departments.items()
        }

    def evaluate(self, context: DepartmentContext) -> Dict[str, DepartmentDecision]:
        """Run every department on ``context`` and collect decisions by code."""
        decisions: Dict[str, DepartmentDecision] = {}
        for dept_code, agent in self.agents.items():
            decisions[dept_code] = agent.analyze(self._clone_context(context))
        return decisions

    @staticmethod
    def _clone_context(context: DepartmentContext) -> DepartmentContext:
        """Build a per-department context with its own top-level dicts.

        ``features``, ``market_snapshot``, ``raw`` and (when present)
        ``raw["scope_values"]`` are copied one level deep; any other nested
        values inside ``raw`` are still shared with the caller's context.
        """
        raw_copy = dict(context.raw or {})
        if "scope_values" in raw_copy:
            raw_copy["scope_values"] = dict(raw_copy.get("scope_values") or {})
        return DepartmentContext(
            ts_code=context.ts_code,
            trade_date=context.trade_date,
            features=dict(context.features or {}),
            market_snapshot=dict(context.market_snapshot or {}),
            raw=raw_copy,
        )

    def _resolve_llm(self, settings: DepartmentSettings) -> LLMConfig:
        """Return the department's own LLM config (no overrides applied)."""
        return settings.llm
|
||
|
||
|
||
def _parse_department_response(text: str) -> Dict[str, Any]:
|
||
"""Extract a JSON object from the LLM response if possible."""
|
||
|
||
cleaned = text.strip()
|
||
candidate = None
|
||
if cleaned.startswith("{") and cleaned.endswith("}"):
|
||
candidate = cleaned
|
||
else:
|
||
start = cleaned.find("{")
|
||
end = cleaned.rfind("}")
|
||
if start != -1 and end != -1 and end > start:
|
||
candidate = cleaned[start : end + 1]
|
||
if candidate:
|
||
try:
|
||
return json.loads(candidate)
|
||
except json.JSONDecodeError:
|
||
LOGGER.debug("部门响应 JSON 解析失败,返回原始文本", extra=LOG_EXTRA)
|
||
return {"summary": cleaned}
|
||
|
||
|
||
def _normalize_action(value: Any) -> AgentAction:
    """Map a free-form action label from the LLM onto an ``AgentAction``.

    The label is trimmed and upper-cased, then looked up exactly. Failing
    that, any label containing "SELL" becomes SELL, otherwise any containing
    "BUY" becomes BUY_M. Non-strings and unrecognised labels yield HOLD.
    """
    if not isinstance(value, str):
        return AgentAction.HOLD
    label = value.strip().upper()
    exact_labels = {
        "BUY": AgentAction.BUY_M,
        "BUY_S": AgentAction.BUY_S,
        "BUY_M": AgentAction.BUY_M,
        "BUY_L": AgentAction.BUY_L,
        "SELL": AgentAction.SELL,
        "HOLD": AgentAction.HOLD,
    }
    matched = exact_labels.get(label)
    if matched is not None:
        return matched
    # SELL is checked before BUY, so mixed labels resolve to SELL.
    if "SELL" in label:
        return AgentAction.SELL
    if "BUY" in label:
        return AgentAction.BUY_M
    return AgentAction.HOLD
|
||
|
||
|
||
def _clamp_float(value: Any, default: float = 0.5) -> float:
|
||
try:
|
||
num = float(value)
|
||
except (TypeError, ValueError):
|
||
return default
|
||
return max(0.0, min(1.0, num))
|