"""Department-level LLM agents coordinating multi-model decisions.""" from __future__ import annotations import json from dataclasses import dataclass, field from typing import Any, Callable, ClassVar, Dict, List, Mapping, Optional, Sequence, Tuple from app.agents.base import AgentAction from app.llm.client import call_endpoint_with_messages, run_llm_with_config, LLMError from app.llm.prompts import department_prompt from app.utils.config import AppConfig, DepartmentSettings, LLMConfig from app.utils.logging import get_logger, get_conversation_logger from app.utils.data_access import DataBroker LOGGER = get_logger(__name__) LOG_EXTRA = {"stage": "department"} CONV_LOGGER = get_conversation_logger() @dataclass class TableRequest: name: str window: int = 1 trade_date: Optional[str] = None @dataclass class DepartmentContext: """Structured data fed into a department for decision making.""" ts_code: str trade_date: str features: Mapping[str, Any] = field(default_factory=dict) market_snapshot: Mapping[str, Any] = field(default_factory=dict) raw: Mapping[str, Any] = field(default_factory=dict) @dataclass class DepartmentDecision: """Result produced by a department agent.""" department: str action: AgentAction confidence: float summary: str raw_response: str signals: List[str] = field(default_factory=list) risks: List[str] = field(default_factory=list) supplements: List[Dict[str, Any]] = field(default_factory=list) dialogue: List[str] = field(default_factory=list) def to_dict(self) -> Dict[str, Any]: return { "department": self.department, "action": self.action.value, "confidence": self.confidence, "summary": self.summary, "signals": self.signals, "risks": self.risks, "raw_response": self.raw_response, "supplements": self.supplements, "dialogue": self.dialogue, } class DepartmentAgent: """Wraps LLM ensemble logic for a single analytical department.""" ALLOWED_TABLES: ClassVar[List[str]] = ["daily", "daily_basic"] def __init__( self, settings: DepartmentSettings, resolver: Optional[Callable[[DepartmentSettings], LLMConfig]] = None, ) -> None: self.settings = settings self._resolver = resolver self._broker = DataBroker() self._max_rounds = 3 def _get_llm_config(self) -> LLMConfig: if self._resolver: return self._resolver(self.settings) return self.settings.llm def analyze(self, context: DepartmentContext) -> DepartmentDecision: mutable_context = _ensure_mutable_context(context) system_prompt = ( "你是一个多智能体量化投研系统中的分部决策者,需要根据提供的结构化信息给出买卖意见。" ) llm_cfg = self._get_llm_config() if llm_cfg.strategy not in (None, "", "single") or llm_cfg.ensemble: LOGGER.warning( "部门 %s 当前配置不支持函数调用模式,回退至传统提示", self.settings.code, extra=LOG_EXTRA, ) return self._analyze_legacy(mutable_context, system_prompt) tools = self._build_tools() messages: List[Dict[str, object]] = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) messages.append( { "role": "user", "content": department_prompt(self.settings, mutable_context), } ) transcript: List[str] = [] delivered_requests: set[Tuple[str, int, str]] = set() primary_endpoint = llm_cfg.primary final_message: Optional[Dict[str, Any]] = None CONV_LOGGER.info( "dept=%s ts_code=%s trade_date=%s start", self.settings.code, context.ts_code, context.trade_date, ) for round_idx in range(self._max_rounds): try: response = call_endpoint_with_messages( primary_endpoint, messages, tools=tools, tool_choice="auto", ) except LLMError as exc: LOGGER.warning( "部门 %s 函数调用失败,回退传统提示:%s", self.settings.code, exc, extra=LOG_EXTRA, ) return self._analyze_legacy(mutable_context, system_prompt) choice = (response.get("choices") or [{}])[0] message = choice.get("message", {}) transcript.append(_message_to_text(message)) assistant_record: Dict[str, Any] = { "role": "assistant", "content": _extract_message_content(message), } if message.get("tool_calls"): assistant_record["tool_calls"] = message.get("tool_calls") messages.append(assistant_record) CONV_LOGGER.info( "dept=%s round=%s assistant=%s", self.settings.code, round_idx + 1, assistant_record, ) tool_calls = message.get("tool_calls") or [] if tool_calls: for call in tool_calls: tool_response, delivered = self._handle_tool_call( mutable_context, call, delivered_requests, round_idx, ) transcript.append( json.dumps({"tool_response": tool_response}, ensure_ascii=False) ) messages.append( { "role": "tool", "tool_call_id": call.get("id"), "content": json.dumps(tool_response, ensure_ascii=False), } ) delivered_requests.update(delivered) CONV_LOGGER.info( "dept=%s round=%s tool_call=%s response=%s", self.settings.code, round_idx + 1, call, tool_response, ) continue final_message = message break if final_message is None: LOGGER.warning( "部门 %s 函数调用达到轮次上限仍未返回文本,使用最后一次消息", self.settings.code, extra=LOG_EXTRA, ) final_message = message CONV_LOGGER.warning( "dept=%s rounds_exhausted last_message=%s", self.settings.code, final_message, ) mutable_context.raw["supplement_transcript"] = list(transcript) content_text = _extract_message_content(final_message) decision_data = _parse_department_response(content_text) action = _normalize_action(decision_data.get("action")) confidence = _clamp_float(decision_data.get("confidence"), default=0.5) summary = decision_data.get("summary") or decision_data.get("reason") or "" signals = decision_data.get("signals") or decision_data.get("rationale") or [] if isinstance(signals, str): signals = [signals] risks = decision_data.get("risks") or decision_data.get("warnings") or [] if isinstance(risks, str): risks = [risks] decision = DepartmentDecision( department=self.settings.code, action=action, confidence=confidence, summary=summary or "未提供摘要", signals=[str(sig) for sig in signals if sig], risks=[str(risk) for risk in risks if risk], raw_response=content_text, supplements=list(mutable_context.raw.get("supplement_data", [])), dialogue=list(transcript), ) LOGGER.debug( "部门 %s 决策:action=%s confidence=%.2f", self.settings.code, decision.action.value, decision.confidence, extra=LOG_EXTRA, ) CONV_LOGGER.info( "dept=%s decision action=%s confidence=%.2f summary=%s", self.settings.code, decision.action.value, decision.confidence, summary or "", ) return decision @staticmethod def _normalize_trade_date(value: str) -> str: if not isinstance(value, str): return str(value) return value.replace("-", "") def _fulfill_data_requests( self, context: DepartmentContext, requests: Sequence[TableRequest], ) -> Tuple[List[str], List[Dict[str, Any]], set[Tuple[str, int, str]]]: lines: List[str] = [] payload: List[Dict[str, Any]] = [] delivered: set[Tuple[str, int, str]] = set() ts_code = context.ts_code default_trade_date = self._normalize_trade_date(context.trade_date) for req in requests: table = (req.name or "").strip().lower() if not table: continue if table not in self.ALLOWED_TABLES: lines.append(f"- {table}: 不在允许的表列表中") continue trade_date = self._normalize_trade_date(req.trade_date or default_trade_date) window = max(1, min(req.window or 1, getattr(self._broker, "MAX_WINDOW", 120))) key = (table, window, trade_date) if key in delivered: lines.append(f"- {table}: 已返回窗口 {window} 的数据,跳过重复请求") continue rows = self._broker.fetch_table_rows(table, ts_code, trade_date, window) if rows: preview = ", ".join( f"{row.get('trade_date', 'NA')}" for row in rows[: min(len(rows), 5)] ) lines.append( f"- {table} (window={window} trade_date<= {trade_date}): 返回 {len(rows)} 行 {preview}" ) else: lines.append( f"- {table} (window={window} trade_date<= {trade_date}): (数据缺失)" ) payload.append( { "table": table, "window": window, "trade_date": trade_date, "rows": rows, } ) delivered.add(key) return lines, payload, delivered def _handle_tool_call( self, context: DepartmentContext, call: Mapping[str, Any], delivered_requests: set[Tuple[str, int, str]], round_idx: int, ) -> Tuple[Dict[str, Any], set[Tuple[str, int, str]]]: function_block = call.get("function") or {} name = function_block.get("name") or "" if name != "fetch_data": LOGGER.warning( "部门 %s 收到未知工具调用:%s", self.settings.code, name, extra=LOG_EXTRA, ) return { "status": "error", "message": f"未知工具 {name}", }, set() args = _parse_tool_arguments(function_block.get("arguments")) base_trade_date = self._normalize_trade_date( args.get("trade_date") or context.trade_date ) raw_requests = args.get("tables") or [] requests: List[TableRequest] = [] skipped: List[str] = [] for item in raw_requests: name = str(item.get("name", "")).strip().lower() if not name: continue window_raw = item.get("window") try: window = int(window_raw) if window_raw is not None else 1 except (TypeError, ValueError): window = 1 window = max(1, min(window, getattr(self._broker, "MAX_WINDOW", 120))) override_date = item.get("trade_date") req_date = self._normalize_trade_date(override_date or base_trade_date) key = (name, window, req_date) if key in delivered_requests: skipped.append(name) continue requests.append(TableRequest(name=name, window=window, trade_date=req_date)) if not requests: return { "status": "ok", "round": round_idx + 1, "results": [], "skipped": skipped, }, set() lines, payload, delivered = self._fulfill_data_requests(context, requests) if payload: context.raw.setdefault("supplement_data", []).extend(payload) context.raw.setdefault("supplement_rounds", []).append( { "round": round_idx + 1, "requests": [req.__dict__ for req in requests], "data": payload, "notes": lines, } ) if lines: context.raw.setdefault("supplement_notes", []).append( { "round": round_idx + 1, "lines": lines, } ) response_payload = { "status": "ok", "round": round_idx + 1, "results": payload, "notes": lines, "skipped": skipped, } return response_payload, delivered def _build_tools(self) -> List[Dict[str, Any]]: max_window = getattr(self._broker, "MAX_WINDOW", 120) return [ { "type": "function", "function": { "name": "fetch_data", "description": ( "根据表名请求指定交易日及窗口的历史数据。当前仅支持 'daily' 与 'daily_basic' 表。" ), "parameters": { "type": "object", "properties": { "tables": { "type": "array", "items": { "type": "object", "properties": { "name": { "type": "string", "enum": self.ALLOWED_TABLES, "description": "表名,例如 daily 或 daily_basic", }, "window": { "type": "integer", "minimum": 1, "maximum": max_window, "description": "向前回溯的记录条数,默认为 1", }, "trade_date": { "type": "string", "pattern": r"^\\d{8}$", "description": "覆盖默认交易日(格式 YYYYMMDD)", }, }, "required": ["name"], }, "minItems": 1, }, "trade_date": { "type": "string", "pattern": r"^\\d{8}$", "description": "默认交易日(格式 YYYYMMDD)", }, }, "required": ["tables"], }, }, } ] def _analyze_legacy( self, context: DepartmentContext, system_prompt: str, ) -> DepartmentDecision: prompt = department_prompt(self.settings, context) llm_cfg = self._get_llm_config() try: response = run_llm_with_config(llm_cfg, prompt, system=system_prompt) except Exception as exc: # noqa: BLE001 LOGGER.exception( "部门 %s 调用 LLM 失败:%s", self.settings.code, exc, extra=LOG_EXTRA, ) CONV_LOGGER.error( "dept=%s legacy_call_failed err=%s", self.settings.code, exc, ) return DepartmentDecision( department=self.settings.code, action=AgentAction.HOLD, confidence=0.0, summary=f"LLM 调用失败:{exc}", raw_response=str(exc), ) context.raw["supplement_transcript"] = [response] CONV_LOGGER.info("dept=%s legacy_response=%s", self.settings.code, response) decision_data = _parse_department_response(response) action = _normalize_action(decision_data.get("action")) confidence = _clamp_float(decision_data.get("confidence"), default=0.5) summary = decision_data.get("summary") or decision_data.get("reason") or "" signals = decision_data.get("signals") or decision_data.get("rationale") or [] if isinstance(signals, str): signals = [signals] risks = decision_data.get("risks") or decision_data.get("warnings") or [] if isinstance(risks, str): risks = [risks] decision = DepartmentDecision( department=self.settings.code, action=action, confidence=confidence, summary=summary or "未提供摘要", signals=[str(sig) for sig in signals if sig], risks=[str(risk) for risk in risks if risk], raw_response=response, supplements=list(context.raw.get("supplement_data", [])), dialogue=[response], ) return decision def _ensure_mutable_context(context: DepartmentContext) -> DepartmentContext: if not isinstance(context.features, dict): context.features = dict(context.features or {}) if not isinstance(context.market_snapshot, dict): context.market_snapshot = dict(context.market_snapshot or {}) raw = dict(context.raw or {}) scope_values = raw.get("scope_values") if scope_values is not None and not isinstance(scope_values, dict): raw["scope_values"] = dict(scope_values) context.raw = raw return context def _parse_tool_arguments(payload: Any) -> Dict[str, Any]: if isinstance(payload, dict): return dict(payload) if isinstance(payload, str): try: data = json.loads(payload) except json.JSONDecodeError: LOGGER.debug("工具参数解析失败:%s", payload, extra=LOG_EXTRA) return {} if isinstance(data, dict): return data return {} def _message_to_text(message: Mapping[str, Any]) -> str: content = message.get("content") if isinstance(content, list): parts: List[str] = [] for item in content: if isinstance(item, Mapping) and "text" in item: parts.append(str(item.get("text", ""))) else: parts.append(str(item)) if parts: return "".join(parts) elif isinstance(content, str) and content.strip(): return content tool_calls = message.get("tool_calls") if tool_calls: return json.dumps({"tool_calls": tool_calls}, ensure_ascii=False) return "" def _extract_message_content(message: Mapping[str, Any]) -> str: content = message.get("content") if isinstance(content, list): texts = [ str(item.get("text", "")) for item in content if isinstance(item, Mapping) and "text" in item ] if texts: return "".join(texts) if isinstance(content, str): return content return json.dumps(message, ensure_ascii=False) class DepartmentManager: """Orchestrates all departments defined in configuration.""" def __init__(self, config: AppConfig) -> None: self.config = config self.agents: Dict[str, DepartmentAgent] = { code: DepartmentAgent(settings, self._resolve_llm) for code, settings in config.departments.items() } def evaluate(self, context: DepartmentContext) -> Dict[str, DepartmentDecision]: results: Dict[str, DepartmentDecision] = {} for code, agent in self.agents.items(): raw_base = dict(context.raw or {}) if "scope_values" in raw_base: raw_base["scope_values"] = dict(raw_base.get("scope_values") or {}) dept_context = DepartmentContext( ts_code=context.ts_code, trade_date=context.trade_date, features=dict(context.features or {}), market_snapshot=dict(context.market_snapshot or {}), raw=raw_base, ) results[code] = agent.analyze(dept_context) return results def _resolve_llm(self, settings: DepartmentSettings) -> LLMConfig: return settings.llm def _parse_department_response(text: str) -> Dict[str, Any]: """Extract a JSON object from the LLM response if possible.""" cleaned = text.strip() candidate = None if cleaned.startswith("{") and cleaned.endswith("}"): candidate = cleaned else: start = cleaned.find("{") end = cleaned.rfind("}") if start != -1 and end != -1 and end > start: candidate = cleaned[start : end + 1] if candidate: try: return json.loads(candidate) except json.JSONDecodeError: LOGGER.debug("部门响应 JSON 解析失败,返回原始文本", extra=LOG_EXTRA) return {"summary": cleaned} def _normalize_action(value: Any) -> AgentAction: if isinstance(value, str): upper = value.strip().upper() mapping = { "BUY": AgentAction.BUY_M, "BUY_S": AgentAction.BUY_S, "BUY_M": AgentAction.BUY_M, "BUY_L": AgentAction.BUY_L, "SELL": AgentAction.SELL, "HOLD": AgentAction.HOLD, } if upper in mapping: return mapping[upper] if "SELL" in upper: return AgentAction.SELL if "BUY" in upper: return AgentAction.BUY_M return AgentAction.HOLD def _clamp_float(value: Any, default: float = 0.5) -> float: try: num = float(value) except (TypeError, ValueError): return default return max(0.0, min(1.0, num))