llm-quant/app/agents/departments.py
2025-09-28 21:43:08 +08:00

749 lines
27 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Department-level LLM agents coordinating multi-model decisions."""
from __future__ import annotations
import json
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple
from app.agents.base import AgentAction
from app.llm.client import call_endpoint_with_messages, run_llm_with_config, LLMError
from app.llm.prompts import department_prompt
from app.utils.config import AppConfig, DepartmentSettings, LLMConfig
from app.utils.logging import get_logger
from app.utils.data_access import DataBroker
# Module-level logger obtained from the project's logging helper.
LOGGER = get_logger(__name__)
# Passed as ``extra=`` on the log calls in this module to tag their stage.
LOG_EXTRA = {"stage": "department"}
@dataclass
class DataRequest:
    """One supplemental-data request: a field name plus a look-back window."""

    # Field identifier as requested by the model (e.g. ``table.column``).
    field: str
    # Number of most recent data points wanted; 1 means only the latest value.
    window: int = 1
@dataclass
class DepartmentContext:
    """Structured input handed to a department agent for one decision.

    The three mapping attributes default to fresh empty dicts, so instances
    never share mutable state through their defaults.
    """

    # Security identifier and trading date the decision refers to.
    ts_code: str
    trade_date: str
    # Named feature values for the security.
    features: Mapping[str, Any] = field(default_factory=dict)
    # Named market-level values.
    market_snapshot: Mapping[str, Any] = field(default_factory=dict)
    # Free-form extras; this module reads ``scope_values``, ``close_series``
    # and ``turnover_series`` from it and writes ``supplement_*`` keys into it.
    raw: Mapping[str, Any] = field(default_factory=dict)
@dataclass
class DepartmentDecision:
    """Outcome produced by a department agent for one context."""

    # Code of the department that produced the decision.
    department: str
    # Normalised trading action and the (clamped) confidence attached to it.
    action: AgentAction
    confidence: float
    # Short human-readable justification and the raw model output behind it.
    summary: str
    raw_response: str
    # Supporting signals and risk notes, each already converted to text.
    signals: List[str] = field(default_factory=list)
    risks: List[str] = field(default_factory=list)
    # Supplemental data payloads gathered while deciding.
    supplements: List[Dict[str, Any]] = field(default_factory=list)
    # Transcript of the exchange with the model.
    dialogue: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return the decision as a plain dict, with ``action`` as its ``.value``.

        Key order is fixed; list attributes are returned by reference, not copied.
        """
        ordered_keys = (
            "department",
            "action",
            "confidence",
            "summary",
            "signals",
            "risks",
            "raw_response",
            "supplements",
            "dialogue",
        )
        snapshot: Dict[str, Any] = {}
        for key in ordered_keys:
            attr = getattr(self, key)
            # The action is an enum-like object; only its value is exported.
            snapshot[key] = attr.value if key == "action" else attr
        return snapshot
class DepartmentAgent:
    """Wraps LLM ensemble logic for a single analytical department.

    ``analyze`` first runs a function-calling conversation in which the model
    may ask for extra data through the ``fetch_data`` tool. It falls back to a
    single plain prompt (``_analyze_legacy``) when the LLM configuration uses a
    non-single strategy or an ensemble, or when the endpoint call raises
    ``LLMError``.
    """

    # Maximum number of points rendered in the human-readable series preview.
    _PREVIEW_LIMIT = 5

    def __init__(
        self,
        settings: DepartmentSettings,
        resolver: Optional[Callable[[DepartmentSettings], LLMConfig]] = None,
    ) -> None:
        """Store the department settings and an optional LLM-config resolver."""
        self.settings = settings
        self._resolver = resolver
        self._broker = DataBroker()
        # Upper bound on model round-trips in the function-calling loop.
        self._max_rounds = 3

    def _get_llm_config(self) -> LLMConfig:
        """Return the resolver's config if a resolver is set, else ``settings.llm``."""
        if self._resolver:
            return self._resolver(self.settings)
        return self.settings.llm

    def analyze(self, context: DepartmentContext) -> DepartmentDecision:
        """Run the department's decision process for ``context``.

        Side effects: ``context.raw`` is replaced by a mutable copy and gains
        ``supplement_transcript`` (plus the ``supplement_*`` keys written while
        serving tool calls).
        """
        mutable_context = _ensure_mutable_context(context)
        system_prompt = (
            "你是一个多智能体量化投研系统中的分部决策者,需要根据提供的结构化信息给出买卖意见。"
        )
        llm_cfg = self._get_llm_config()
        if llm_cfg.strategy not in (None, "", "single") or llm_cfg.ensemble:
            LOGGER.warning(
                "部门 %s 当前配置不支持函数调用模式,回退至传统提示",
                self.settings.code,
                extra=LOG_EXTRA,
            )
            return self._analyze_legacy(mutable_context, system_prompt)

        tools = self._build_tools()
        messages: List[Dict[str, object]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append(
            {
                "role": "user",
                "content": department_prompt(self.settings, mutable_context),
            }
        )

        transcript: List[str] = []
        # Values already present in the prompt count as delivered with window 1.
        delivered_requests: set[Tuple[str, int]] = {
            (name, 1)
            for name in (mutable_context.raw.get("scope_values") or {})
        }
        primary_endpoint = llm_cfg.primary
        final_message: Optional[Dict[str, Any]] = None
        # Pre-bound so the post-loop fallback is defined even with zero rounds.
        message: Dict[str, Any] = {}

        for round_idx in range(self._max_rounds):
            try:
                response = call_endpoint_with_messages(
                    primary_endpoint,
                    messages,
                    tools=tools,
                    tool_choice="auto",
                )
            except LLMError as exc:
                LOGGER.warning(
                    "部门 %s 函数调用失败,回退传统提示:%s",
                    self.settings.code,
                    exc,
                    extra=LOG_EXTRA,
                )
                return self._analyze_legacy(mutable_context, system_prompt)

            choice = (response.get("choices") or [{}])[0]
            # ``or {}`` also covers an explicit ``"message": null``.
            message = choice.get("message") or {}
            transcript.append(_message_to_text(message))

            tool_calls = message.get("tool_calls") or []
            if not tool_calls:
                final_message = message
                break

            # A ``tool`` message must answer an assistant message carrying the
            # matching ``tool_calls``; without this echo the next request is
            # malformed for OpenAI-compatible endpoints.
            assistant_content = message.get("content")
            messages.append(
                {
                    "role": "assistant",
                    "content": "" if assistant_content is None else assistant_content,
                    "tool_calls": tool_calls,
                }
            )
            for call in tool_calls:
                tool_response, delivered = self._handle_tool_call(
                    mutable_context,
                    call,
                    delivered_requests,
                    round_idx,
                )
                transcript.append(
                    json.dumps({"tool_response": tool_response}, ensure_ascii=False)
                )
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": call.get("id"),
                        "name": (call.get("function") or {}).get("name"),
                        "content": json.dumps(tool_response, ensure_ascii=False),
                    }
                )
                delivered_requests.update(delivered)

        if final_message is None:
            LOGGER.warning(
                "部门 %s 函数调用达到轮次上限仍未返回文本,使用最后一次消息",
                self.settings.code,
                extra=LOG_EXTRA,
            )
            final_message = message

        mutable_context.raw["supplement_transcript"] = list(transcript)
        content_text = _extract_message_content(final_message)
        decision = self._build_decision(
            _parse_department_response(content_text),
            raw_response=content_text,
            supplements=mutable_context.raw.get("supplement_data", []),
            dialogue=transcript,
        )
        LOGGER.debug(
            "部门 %s 决策action=%s confidence=%.2f",
            self.settings.code,
            decision.action.value,
            decision.confidence,
            extra=LOG_EXTRA,
        )
        return decision

    @staticmethod
    def _as_text_list(value: Any) -> List[str]:
        """Coerce a model-supplied field into a list of non-empty strings.

        A bare string or any non-sequence scalar becomes a one-element list;
        falsy input and falsy items are dropped.
        """
        if not value:
            return []
        if isinstance(value, (list, tuple, set)):
            items = value
        else:
            items = [value]
        return [str(item) for item in items if item]

    def _build_decision(
        self,
        decision_data: Mapping[str, Any],
        *,
        raw_response: str,
        supplements: Sequence[Dict[str, Any]],
        dialogue: Sequence[str],
    ) -> DepartmentDecision:
        """Turn the parsed model reply into a ``DepartmentDecision``.

        Shared by the function-calling and legacy paths so both interpret the
        reply identically (``reason``/``rationale``/``warnings`` are accepted
        as alternates for ``summary``/``signals``/``risks``).
        """
        summary = decision_data.get("summary") or decision_data.get("reason") or ""
        signals = decision_data.get("signals") or decision_data.get("rationale")
        risks = decision_data.get("risks") or decision_data.get("warnings")
        return DepartmentDecision(
            department=self.settings.code,
            action=_normalize_action(decision_data.get("action")),
            confidence=_clamp_float(decision_data.get("confidence"), default=0.5),
            summary=summary or "未提供摘要",
            signals=self._as_text_list(signals),
            risks=self._as_text_list(risks),
            raw_response=raw_response,
            supplements=list(supplements),
            dialogue=list(dialogue),
        )

    @staticmethod
    def _normalize_trade_date(value: str) -> str:
        """Strip dashes from a date string; non-strings are returned as ``str(value)``."""
        if not isinstance(value, str):
            return str(value)
        return value.replace("-", "")

    @staticmethod
    def _format_preview(series: Sequence[Tuple[Any, Any]], limit: int = 5) -> str:
        """Render up to ``limit`` ``(date, value)`` pairs as ``date:value`` text.

        Values are shown with four decimals when they support that format and
        with ``str()`` otherwise, so a missing (``None``) or textual value no
        longer aborts the whole request.
        """
        rendered: List[str] = []
        for dt, val in series[:limit]:
            try:
                text = f"{val:.4f}"
            except (TypeError, ValueError):
                text = str(val)
            rendered.append(f"{dt}:{text}")
        return ", ".join(rendered)

    def _fulfill_data_requests(
        self,
        context: DepartmentContext,
        requests: Sequence[DataRequest],
    ) -> Tuple[List[str], List[Dict[str, Any]], set[Tuple[str, int]]]:
        """Serve data requests from the broker or from the context itself.

        Returns ``(lines, payload, delivered)``: human-readable note lines,
        structured result records, and the ``(field, window)`` keys that are
        now considered delivered.
        """
        lines: List[str] = []
        payload: List[Dict[str, Any]] = []
        delivered: set[Tuple[str, int]] = set()
        ts_code = context.ts_code
        trade_date = self._normalize_trade_date(context.trade_date)
        # canonical "table.column" -> names under which it was requested.
        latest_groups: Dict[str, List[str]] = {}
        series_requests: List[Tuple[DataRequest, Tuple[str, str]]] = []
        values_map, db_alias_map, series_map = _build_context_lookup(context)

        for req in requests:
            name = req.field.strip()
            if not name:
                continue
            window = req.window

            resolved: Optional[Tuple[str, str]] = None
            if "." in name:
                resolved = self._broker.resolve_field(name)
            elif name in db_alias_map:
                resolved = db_alias_map[name]

            if resolved:
                # Database-backed field: queue it; the fetches happen after the loop.
                table, column = resolved
                canonical = f"{table}.{column}"
                if window <= 1:
                    latest_groups.setdefault(canonical, []).append(name)
                    delivered.add((name, 1))
                    delivered.add((canonical, 1))
                else:
                    series_requests.append((req, resolved))
                    delivered.add((name, window))
                    delivered.add((canonical, window))
                continue

            if name in values_map:
                # Not resolvable in the database, but the context carries a value.
                value = values_map[name]
                current_point = [{"trade_date": context.trade_date, "value": value}]
                if window <= 1:
                    payload.append(
                        {
                            "field": name,
                            "window": 1,
                            "source": "context",
                            "values": current_point,
                        }
                    )
                    lines.append(f"- {name}: {value} (来自上下文)")
                else:
                    series = series_map.get(name)
                    if series:
                        trimmed = series[:window]
                        payload.append(
                            {
                                "field": name,
                                "window": window,
                                "source": "context_series",
                                "values": [
                                    {"trade_date": dt, "value": val}
                                    for dt, val in trimmed
                                ],
                            }
                        )
                        preview = self._format_preview(trimmed, self._PREVIEW_LIMIT)
                        lines.append(
                            f"- {name} (window={window} 来自上下文序列): {preview}"
                        )
                    else:
                        payload.append(
                            {
                                "field": name,
                                "window": window,
                                "source": "context",
                                "values": current_point,
                                "warning": "仅提供当前值,缺少历史序列",
                            }
                        )
                        lines.append(
                            f"- {name} (window={window}): 仅有当前值 {value}, 无历史序列"
                        )
                delivered.add((name, window))
                if name in db_alias_map:
                    alias_table, alias_column = db_alias_map[name]
                    delivered.add((f"{alias_table}.{alias_column}", window))
                continue

            lines.append(f"- {name}: 字段不存在或不可用")

        if latest_groups:
            latest_values = self._broker.fetch_latest(
                ts_code, trade_date, list(latest_groups.keys())
            )
            for canonical, aliases in latest_groups.items():
                value = latest_values.get(canonical)
                if value is None:
                    lines.append(f"- {canonical}: (数据缺失)")
                else:
                    lines.append(f"- {canonical}: {value}")
                # One record per requested name, even when the value is missing.
                for alias in aliases:
                    payload.append(
                        {
                            "field": alias,
                            "window": 1,
                            "source": "database",
                            "values": [
                                {
                                    "trade_date": trade_date,
                                    "value": value,
                                }
                            ],
                        }
                    )

        for req, resolved in series_requests:
            table, column = resolved
            series = self._broker.fetch_series(
                table,
                column,
                ts_code,
                trade_date,
                window=req.window,
            )
            if series:
                preview = self._format_preview(series, self._PREVIEW_LIMIT)
                lines.append(f"- {req.field} (window={req.window}): {preview}")
            else:
                lines.append(f"- {req.field} (window={req.window}): (数据缺失)")
            payload.append(
                {
                    "field": req.field,
                    "window": req.window,
                    "source": "database",
                    "values": [
                        {"trade_date": dt, "value": val}
                        for dt, val in (series or [])
                    ],
                }
            )

        return lines, payload, delivered

    def _handle_tool_call(
        self,
        context: DepartmentContext,
        call: Mapping[str, Any],
        delivered_requests: set[Tuple[str, int]],
        round_idx: int,
    ) -> Tuple[Dict[str, Any], set[Tuple[str, int]]]:
        """Execute one model tool call and return ``(response, delivered_keys)``.

        Only ``fetch_data`` is supported. Requests already delivered (or
        repeated within the same call) are reported under ``skipped``. Results
        are also recorded in ``context.raw`` under the ``supplement_*`` keys.
        """
        function_block = call.get("function") or {}
        name = function_block.get("name") or ""
        if name != "fetch_data":
            LOGGER.warning(
                "部门 %s 收到未知工具调用:%s",
                self.settings.code,
                name,
                extra=LOG_EXTRA,
            )
            return {
                "status": "error",
                "message": f"未知工具 {name}",
            }, set()

        args = _parse_tool_arguments(function_block.get("arguments"))
        raw_requests = args.get("requests") or []
        if not isinstance(raw_requests, (list, tuple)):
            # Models occasionally send a single request instead of an array.
            raw_requests = [raw_requests]

        max_window = getattr(self._broker, "MAX_WINDOW", 120)
        requests: List[DataRequest] = []
        skipped: List[str] = []
        queued: set[Tuple[str, int]] = set()
        for item in raw_requests:
            if isinstance(item, str):
                field_name = item.strip()
                raw_window: Any = 1
            elif isinstance(item, Mapping):
                field_name = str(item.get("field") or "").strip()
                raw_window = item.get("window", 1)
            else:
                # Anything else cannot name a field; ignore it rather than crash.
                continue
            if not field_name:
                continue
            try:
                window = int(raw_window)
            except (TypeError, ValueError):
                window = 1
            window = max(1, min(window, max_window))
            key = (field_name, window)
            if key in delivered_requests or key in queued:
                skipped.append(field_name)
                continue
            queued.add(key)
            requests.append(DataRequest(field=field_name, window=window))

        if not requests:
            return {
                "status": "ok",
                "round": round_idx + 1,
                "results": [],
                "skipped": skipped,
            }, set()

        lines, payload, delivered = self._fulfill_data_requests(context, requests)
        if payload:
            context.raw.setdefault("supplement_data", []).extend(payload)
            context.raw.setdefault("supplement_rounds", []).append(
                {
                    "round": round_idx + 1,
                    "requests": [req.__dict__ for req in requests],
                    "data": payload,
                    "notes": lines,
                }
            )
        if lines:
            context.raw.setdefault("supplement_notes", []).append(
                {
                    "round": round_idx + 1,
                    "lines": lines,
                }
            )
        response_payload = {
            "status": "ok",
            "round": round_idx + 1,
            "results": payload,
            "notes": lines,
            "skipped": skipped,
        }
        return response_payload, delivered

    def _build_tools(self) -> List[Dict[str, Any]]:
        """Return the tool schema advertising ``fetch_data`` to the model.

        The window upper bound is the broker's ``MAX_WINDOW`` (120 if absent).
        """
        max_window = getattr(self._broker, "MAX_WINDOW", 120)
        return [
            {
                "type": "function",
                "function": {
                    "name": "fetch_data",
                    "description": (
                        "根据字段请求数据库中的最新值或时间序列。支持 table.column 格式的字段,"
                        "window 表示希望返回的最近数据点数量。"
                    ),
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "requests": {
                                "type": "array",
                                "items": {
                                    "type": "object",
                                    "properties": {
                                        "field": {
                                            "type": "string",
                                            "description": "数据字段,格式为 table.column",
                                        },
                                        "window": {
                                            "type": "integer",
                                            "minimum": 1,
                                            "maximum": max_window,
                                            "description": "返回最近多少个数据点,默认为 1",
                                        },
                                    },
                                    "required": ["field"],
                                },
                                "minItems": 1,
                            }
                        },
                        "required": ["requests"],
                    },
                },
            }
        ]

    def _analyze_legacy(
        self,
        context: DepartmentContext,
        system_prompt: str,
    ) -> DepartmentDecision:
        """Single-prompt fallback without tool calls.

        Any exception from the LLM call is logged and converted into a HOLD
        decision with zero confidence, so one failing department does not
        abort the caller.
        """
        prompt = department_prompt(self.settings, context)
        llm_cfg = self._get_llm_config()
        try:
            response = run_llm_with_config(llm_cfg, prompt, system=system_prompt)
        except Exception as exc:  # noqa: BLE001
            LOGGER.exception(
                "部门 %s 调用 LLM 失败:%s",
                self.settings.code,
                exc,
                extra=LOG_EXTRA,
            )
            return DepartmentDecision(
                department=self.settings.code,
                action=AgentAction.HOLD,
                confidence=0.0,
                summary=f"LLM 调用失败:{exc}",
                raw_response=str(exc),
            )

        context.raw["supplement_transcript"] = [response]
        return self._build_decision(
            _parse_department_response(response),
            raw_response=response,
            supplements=context.raw.get("supplement_data", []),
            dialogue=[response],
        )
def _ensure_mutable_context(context: DepartmentContext) -> DepartmentContext:
if not isinstance(context.features, dict):
context.features = dict(context.features or {})
if not isinstance(context.market_snapshot, dict):
context.market_snapshot = dict(context.market_snapshot or {})
raw = dict(context.raw or {})
scope_values = raw.get("scope_values")
if scope_values is not None and not isinstance(scope_values, dict):
raw["scope_values"] = dict(scope_values)
context.raw = raw
return context
def _parse_data_requests(payload: Mapping[str, Any]) -> List[DataRequest]:
raw_requests = payload.get("data_requests")
requests: List[DataRequest] = []
if not isinstance(raw_requests, list):
return requests
seen: set[Tuple[str, int]] = set()
for item in raw_requests:
field = ""
window = 1
if isinstance(item, str):
field = item.strip()
elif isinstance(item, Mapping):
candidate = item.get("field")
if candidate is None:
continue
field = str(candidate).strip()
try:
window = int(item.get("window", 1))
except (TypeError, ValueError):
window = 1
else:
continue
if not field:
continue
window = max(1, window)
key = (field, window)
if key in seen:
continue
seen.add(key)
requests.append(DataRequest(field=field, window=window))
return requests
def _parse_tool_arguments(payload: Any) -> Dict[str, Any]:
if isinstance(payload, dict):
return dict(payload)
if isinstance(payload, str):
try:
data = json.loads(payload)
except json.JSONDecodeError:
LOGGER.debug("工具参数解析失败:%s", payload, extra=LOG_EXTRA)
return {}
if isinstance(data, dict):
return data
return {}
def _message_to_text(message: Mapping[str, Any]) -> str:
content = message.get("content")
if isinstance(content, list):
parts: List[str] = []
for item in content:
if isinstance(item, Mapping) and "text" in item:
parts.append(str(item.get("text", "")))
else:
parts.append(str(item))
if parts:
return "".join(parts)
elif isinstance(content, str) and content.strip():
return content
tool_calls = message.get("tool_calls")
if tool_calls:
return json.dumps({"tool_calls": tool_calls}, ensure_ascii=False)
return ""
def _extract_message_content(message: Mapping[str, Any]) -> str:
content = message.get("content")
if isinstance(content, list):
texts = [
str(item.get("text", ""))
for item in content
if isinstance(item, Mapping) and "text" in item
]
if texts:
return "".join(texts)
if isinstance(content, str):
return content
return json.dumps(message, ensure_ascii=False)
def _build_context_lookup(
context: DepartmentContext,
) -> Tuple[Dict[str, Any], Dict[str, Tuple[str, str]], Dict[str, List[Tuple[str, float]]]]:
values: Dict[str, Any] = {}
db_alias: Dict[str, Tuple[str, str]] = {}
series_map: Dict[str, List[Tuple[str, float]]] = {}
for source in (context.features or {}, context.market_snapshot or {}):
for key, value in source.items():
values[str(key)] = value
scope_values = context.raw.get("scope_values") or {}
for key, value in scope_values.items():
key_str = str(key)
values[key_str] = value
if "." in key_str:
table, column = key_str.split(".", 1)
db_alias.setdefault(column, (table, column))
db_alias.setdefault(key_str, (table, column))
values.setdefault(column, value)
close_series = context.raw.get("close_series") or []
if isinstance(close_series, list) and close_series:
series_map["close"] = close_series
series_map["daily.close"] = close_series
turnover_series = context.raw.get("turnover_series") or []
if isinstance(turnover_series, list) and turnover_series:
series_map["turnover_rate"] = turnover_series
series_map["daily_basic.turnover_rate"] = turnover_series
return values, db_alias, series_map
class DepartmentManager:
    """Orchestrates all departments defined in configuration."""

    def __init__(self, config: AppConfig) -> None:
        """Create one ``DepartmentAgent`` per entry in ``config.departments``."""
        self.config = config
        agents: Dict[str, DepartmentAgent] = {}
        for dept_code, dept_settings in config.departments.items():
            agents[dept_code] = DepartmentAgent(dept_settings, self._resolve_llm)
        self.agents = agents

    def evaluate(self, context: DepartmentContext) -> Dict[str, DepartmentDecision]:
        """Run every department on its own copy of ``context``.

        Each agent receives shallow copies of ``features``, ``market_snapshot``,
        ``raw`` and ``raw["scope_values"]``, so what one department writes into
        those containers is not seen by the next.
        """
        decisions: Dict[str, DepartmentDecision] = {}
        for dept_code, dept_agent in self.agents.items():
            isolated_raw = dict(context.raw or {})
            if "scope_values" in isolated_raw:
                isolated_raw["scope_values"] = dict(
                    isolated_raw.get("scope_values") or {}
                )
            decisions[dept_code] = dept_agent.analyze(
                DepartmentContext(
                    ts_code=context.ts_code,
                    trade_date=context.trade_date,
                    features=dict(context.features or {}),
                    market_snapshot=dict(context.market_snapshot or {}),
                    raw=isolated_raw,
                )
            )
        return decisions

    def _resolve_llm(self, settings: DepartmentSettings) -> LLMConfig:
        """Return the LLM configuration stored on the department settings."""
        return settings.llm
def _parse_department_response(text: str) -> Dict[str, Any]:
"""Extract a JSON object from the LLM response if possible."""
cleaned = text.strip()
candidate = None
if cleaned.startswith("{") and cleaned.endswith("}"):
candidate = cleaned
else:
start = cleaned.find("{")
end = cleaned.rfind("}")
if start != -1 and end != -1 and end > start:
candidate = cleaned[start : end + 1]
if candidate:
try:
return json.loads(candidate)
except json.JSONDecodeError:
LOGGER.debug("部门响应 JSON 解析失败,返回原始文本", extra=LOG_EXTRA)
return {"summary": cleaned}
def _normalize_action(value: Any) -> AgentAction:
    """Map a model-supplied action label onto an ``AgentAction``.

    Matching is case-insensitive and ignores surrounding whitespace. Exact
    labels are tried first, then any label containing ``SELL`` (checked before
    ``BUY``), then any containing ``BUY``. Everything else, including
    non-string input, is ``HOLD``. A plain ``BUY`` means ``BUY_M``.
    """
    if not isinstance(value, str):
        return AgentAction.HOLD

    token = value.strip().upper()
    exact_labels = {
        "BUY": AgentAction.BUY_M,
        "BUY_S": AgentAction.BUY_S,
        "BUY_M": AgentAction.BUY_M,
        "BUY_L": AgentAction.BUY_L,
        "SELL": AgentAction.SELL,
        "HOLD": AgentAction.HOLD,
    }
    exact = exact_labels.get(token)
    if exact is not None:
        return exact

    # Substring fallbacks, in priority order.
    for keyword, fallback in (("SELL", AgentAction.SELL), ("BUY", AgentAction.BUY_M)):
        if keyword in token:
            return fallback
    return AgentAction.HOLD
def _clamp_float(value: Any, default: float = 0.5) -> float:
try:
num = float(value)
except (TypeError, ValueError):
return default
return max(0.0, min(1.0, num))