This commit is contained in:
sam 2025-09-28 21:22:15 +08:00
parent 0a2742b869
commit ca7b249c2c
4 changed files with 497 additions and 103 deletions

View File

@ -36,9 +36,10 @@
1. **配置声明角色**:在 `config.json`/`DepartmentSettings` 中补充 `description`、`data_scope`,由 `department_prompt()` 拼接角色指令,实现职责以 Prompt 管理而非硬编码。
2. **统一数据层**:新增 `DataBroker`(或同类工具)封装常用查询,代理与部门通过声明式 JSON 请求所需表/字段/窗口,由服务端执行并返回特征。
3. **双阶段 LLM 工作流**:第一阶段让 LLM 输出结构化 `data_requests`,服务端取数后将摘要回填,第二阶段再生成最终行动与解释,形成闭环。
4. **审计与前端联动**:把角色提示、数据请求与执行摘要写入 `agent_utils` 附加字段,使 Streamlit 能完整呈现“角色 → 请求 → 决策”的链条。
4. **函数式工具调用**:DeepSeek 等 OpenAI 兼容模型已通过 function calling 接口接入 `fetch_data` 工具;LLM 按 schema 返回字段/窗口请求,系统使用 `DataBroker` 校验并补数后回传 tool result,再继续对话生成最终意见,避免字段错误与手写 JSON 校验。
5. **审计与前端联动**:把角色提示、数据请求与执行摘要写入 `agent_utils` 附加字段,使 Streamlit 能完整呈现“角色 → 请求 → 决策”的链条。
目前部门 LLM 已支持通过在 JSON 中返回 `data_requests` 触发追加查询:系统会使用 `DataBroker` 验证字段后补充最近数据窗口,再带着查询结果进入下一轮提示,从而形成“请求→取数→复议”的闭环。
目前部门 LLM 通过 function calling 暴露的 `fetch_data` 工具触发追加查询:模型按 schema 声明需要的 `table.column` 字段与窗口,系统使用 `DataBroker` 验证后补齐数据并作为工具响应回传,再带着查询结果进入下一轮提示,从而形成“请求 → 取数 → 复议”的闭环。
上述调整可在单个部门先行做 PoC验证闭环能力后再推广至全部角色。

View File

@ -6,11 +6,11 @@ from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple
from app.agents.base import AgentAction
from app.llm.client import run_llm_with_config
from app.llm.client import call_endpoint_with_messages, run_llm_with_config, LLMError
from app.llm.prompts import department_prompt
from app.utils.config import AppConfig, DepartmentSettings, LLMConfig
from app.utils.logging import get_logger
from app.utils.data_access import DataBroker, parse_field_path
from app.utils.data_access import DataBroker
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "department"}
@ -85,74 +85,95 @@ class DepartmentAgent:
"你是一个多智能体量化投研系统中的分部决策者,需要根据提供的结构化信息给出买卖意见。"
)
llm_cfg = self._get_llm_config()
supplement_chunks: List[str] = []
if llm_cfg.strategy not in (None, "", "single") or llm_cfg.ensemble:
LOGGER.warning(
"部门 %s 当前配置不支持函数调用模式,回退至传统提示",
self.settings.code,
extra=LOG_EXTRA,
)
return self._analyze_legacy(mutable_context, system_prompt)
tools = self._build_tools()
messages: List[Dict[str, object]] = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append(
{
"role": "user",
"content": department_prompt(self.settings, mutable_context),
}
)
transcript: List[str] = []
delivered_requests = {
delivered_requests: set[Tuple[str, int]] = {
(field, 1)
for field in (mutable_context.raw.get("scope_values") or {}).keys()
}
response = ""
decision_data: Dict[str, Any] = {}
primary_endpoint = llm_cfg.primary
final_message: Optional[Dict[str, Any]] = None
for round_idx in range(self._max_rounds):
supplement_text = "\n\n".join(chunk for chunk in supplement_chunks if chunk)
prompt = department_prompt(self.settings, mutable_context, supplements=supplement_text)
try:
response = run_llm_with_config(llm_cfg, prompt, system=system_prompt)
except Exception as exc: # noqa: BLE001
LOGGER.exception(
"部门 %s 调用 LLM 失败:%s",
response = call_endpoint_with_messages(
primary_endpoint,
messages,
tools=tools,
tool_choice="auto",
)
except LLMError as exc:
LOGGER.warning(
"部门 %s 函数调用失败,回退传统提示:%s",
self.settings.code,
exc,
extra=LOG_EXTRA,
)
return DepartmentDecision(
department=self.settings.code,
action=AgentAction.HOLD,
confidence=0.0,
summary=f"LLM 调用失败:{exc}",
raw_response=str(exc),
)
return self._analyze_legacy(mutable_context, system_prompt)
transcript.append(response)
decision_data = _parse_department_response(response)
data_requests = _parse_data_requests(decision_data)
filtered_requests = [
req
for req in data_requests
if (req.field, req.window) not in delivered_requests
]
choice = (response.get("choices") or [{}])[0]
message = choice.get("message", {})
transcript.append(_message_to_text(message))
if filtered_requests and round_idx < self._max_rounds - 1:
lines, payload, delivered = self._fulfill_data_requests(
mutable_context, filtered_requests
tool_calls = message.get("tool_calls") or []
if tool_calls:
for call in tool_calls:
tool_response, delivered = self._handle_tool_call(
mutable_context,
call,
delivered_requests,
round_idx,
)
if payload:
supplement_chunks.append(
f"回合 {round_idx + 1} 追加数据:\n" + "\n".join(lines)
transcript.append(
json.dumps({"tool_response": tool_response}, ensure_ascii=False)
)
mutable_context.raw.setdefault("supplement_data", []).extend(payload)
mutable_context.raw.setdefault("supplement_rounds", []).append(
messages.append(
{
"round": round_idx + 1,
"requests": [req.__dict__ for req in filtered_requests],
"data": payload,
"role": "tool",
"tool_call_id": call.get("id"),
"name": call.get("function", {}).get("name"),
"content": json.dumps(tool_response, ensure_ascii=False),
}
)
delivered_requests.update(delivered)
decision_data.pop("data_requests", None)
continue
LOGGER.debug(
"部门 %s 数据请求无结果:%s",
self.settings.code,
filtered_requests,
extra=LOG_EXTRA,
)
decision_data.pop("data_requests", None)
final_message = message
break
if final_message is None:
LOGGER.warning(
"部门 %s 函数调用达到轮次上限仍未返回文本,使用最后一次消息",
self.settings.code,
extra=LOG_EXTRA,
)
final_message = message
mutable_context.raw["supplement_transcript"] = list(transcript)
content_text = _extract_message_content(final_message)
decision_data = _parse_department_response(content_text)
action = _normalize_action(decision_data.get("action"))
confidence = _clamp_float(decision_data.get("confidence"), default=0.5)
summary = decision_data.get("summary") or decision_data.get("reason") or ""
@ -170,7 +191,7 @@ class DepartmentAgent:
summary=summary or "未提供摘要",
signals=[str(sig) for sig in signals if sig],
risks=[str(risk) for risk in risks if risk],
raw_response="\n\n".join(transcript) if transcript else response,
raw_response=content_text,
supplements=list(mutable_context.raw.get("supplement_data", [])),
dialogue=list(transcript),
)
@ -208,16 +229,16 @@ class DepartmentAgent:
field = req.field.strip()
if not field:
continue
resolved = self._broker.resolve_field(field)
if not resolved:
lines.append(f"- {field}: 字段不存在或不可用")
continue
if req.window <= 1:
if field not in latest_fields:
latest_fields.append(field)
delivered.add((field, 1))
continue
parsed = parse_field_path(field)
if not parsed:
lines.append(f"- {field}: 字段不合法,已忽略")
continue
series_requests.append((req, parsed))
series_requests.append((req, resolved))
delivered.add((field, req.window))
if latest_fields:
@ -251,11 +272,186 @@ class DepartmentAgent:
lines.append(
f"- {req.field} (window={req.window}): (数据缺失)"
)
payload.append({"field": req.field, "window": req.window, "values": series})
payload.append(
{
"field": req.field,
"window": req.window,
"values": [
{"trade_date": dt, "value": val}
for dt, val in series
],
}
)
return lines, payload, delivered
def _handle_tool_call(
    self,
    context: DepartmentContext,
    call: Mapping[str, Any],
    delivered_requests: set[Tuple[str, int]],
    round_idx: int,
) -> Tuple[Dict[str, Any], set[Tuple[str, int]]]:
    """Execute a single LLM tool call and build the tool-role response.

    Only the ``fetch_data`` tool is supported; any other name is logged and
    answered with an error payload. Request entries are validated, their
    windows clamped to the broker's ``MAX_WINDOW``, and keys already served
    in earlier rounds (tracked in *delivered_requests*) are reported under
    ``skipped``. Fulfilled data is also appended to ``context.raw`` for
    auditing (``supplement_data`` / ``supplement_rounds`` / ``supplement_notes``).

    Returns:
        Tuple of (tool response payload, set of newly delivered
        ``(field, window)`` keys).
    """
    function_block = call.get("function") or {}
    name = function_block.get("name") or ""
    if name != "fetch_data":
        LOGGER.warning(
            "部门 %s 收到未知工具调用:%s",
            self.settings.code,
            name,
            extra=LOG_EXTRA,
        )
        return {
            "status": "error",
            "message": f"未知工具 {name}",
        }, set()
    args = _parse_tool_arguments(function_block.get("arguments"))
    raw_requests = args.get("requests") or []
    requests: List[DataRequest] = []
    skipped: List[str] = []
    # Hoisted out of the loop: the cap does not change per item.
    max_window = getattr(self._broker, "MAX_WINDOW", 120)
    for item in raw_requests:
        # Guard against malformed entries (e.g. bare strings in the list):
        # one bad item must not abort the whole call with an AttributeError.
        if not isinstance(item, Mapping):
            continue
        field = str(item.get("field", "")).strip()
        if not field:
            continue
        try:
            window = int(item.get("window", 1))
        except (TypeError, ValueError):
            window = 1
        window = max(1, min(window, max_window))
        key = (field, window)
        if key in delivered_requests:
            skipped.append(field)
            continue
        requests.append(DataRequest(field=field, window=window))
    if not requests:
        # Nothing new to fetch; still answer so the dialogue can continue.
        return {
            "status": "ok",
            "round": round_idx + 1,
            "results": [],
            "skipped": skipped,
        }, set()
    lines, payload, delivered = self._fulfill_data_requests(context, requests)
    if payload:
        context.raw.setdefault("supplement_data", []).extend(payload)
        context.raw.setdefault("supplement_rounds", []).append(
            {
                "round": round_idx + 1,
                "requests": [req.__dict__ for req in requests],
                "data": payload,
                "notes": lines,
            }
        )
    if lines:
        context.raw.setdefault("supplement_notes", []).append(
            {
                "round": round_idx + 1,
                "lines": lines,
            }
        )
    response_payload = {
        "status": "ok",
        "round": round_idx + 1,
        "results": payload,
        "notes": lines,
        "skipped": skipped,
    }
    return response_payload, delivered
def _build_tools(self) -> List[Dict[str, Any]]:
    """Describe the single ``fetch_data`` tool exposed to the LLM.

    Returns an OpenAI-compatible ``tools`` list whose JSON schema bounds
    each request window by the broker's ``MAX_WINDOW`` (default 120).
    """
    window_cap = getattr(self._broker, "MAX_WINDOW", 120)
    request_item_schema: Dict[str, Any] = {
        "type": "object",
        "properties": {
            "field": {
                "type": "string",
                "description": "数据字段,格式为 table.column",
            },
            "window": {
                "type": "integer",
                "minimum": 1,
                "maximum": window_cap,
                "description": "返回最近多少个数据点,默认为 1",
            },
        },
        "required": ["field"],
    }
    fetch_data_function: Dict[str, Any] = {
        "name": "fetch_data",
        "description": (
            "根据字段请求数据库中的最新值或时间序列。支持 table.column 格式的字段,"
            "window 表示希望返回的最近数据点数量。"
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "requests": {
                    "type": "array",
                    "items": request_item_schema,
                    "minItems": 1,
                }
            },
            "required": ["requests"],
        },
    }
    return [{"type": "function", "function": fetch_data_function}]
def _analyze_legacy(
    self,
    context: DepartmentContext,
    system_prompt: str,
) -> DepartmentDecision:
    """Fallback single-shot prompt flow (no function calling).

    Builds the classic department prompt, performs one LLM call and parses
    the reply into a ``DepartmentDecision``. Any LLM failure is reported as
    a zero-confidence HOLD decision instead of propagating.
    """
    legacy_prompt = department_prompt(self.settings, context)
    cfg = self._get_llm_config()
    try:
        reply = run_llm_with_config(cfg, legacy_prompt, system=system_prompt)
    except Exception as exc:  # noqa: BLE001
        LOGGER.exception(
            "部门 %s 调用 LLM 失败:%s",
            self.settings.code,
            exc,
            extra=LOG_EXTRA,
        )
        return DepartmentDecision(
            department=self.settings.code,
            action=AgentAction.HOLD,
            confidence=0.0,
            summary=f"LLM 调用失败:{exc}",
            raw_response=str(exc),
        )
    context.raw["supplement_transcript"] = [reply]
    parsed = _parse_department_response(reply)

    def _as_items(value: Any) -> List[Any]:
        # Accept both a bare string and an iterable of entries.
        if isinstance(value, str):
            return [value]
        return list(value) if value else []

    signal_items = _as_items(parsed.get("signals") or parsed.get("rationale") or [])
    risk_items = _as_items(parsed.get("risks") or parsed.get("warnings") or [])
    summary_text = parsed.get("summary") or parsed.get("reason") or ""
    return DepartmentDecision(
        department=self.settings.code,
        action=_normalize_action(parsed.get("action")),
        confidence=_clamp_float(parsed.get("confidence"), default=0.5),
        summary=summary_text or "未提供摘要",
        signals=[str(item) for item in signal_items if item],
        risks=[str(item) for item in risk_items if item],
        raw_response=reply,
        supplements=list(context.raw.get("supplement_data", [])),
        dialogue=[reply],
    )
def _ensure_mutable_context(context: DepartmentContext) -> DepartmentContext:
if not isinstance(context.features, dict):
context.features = dict(context.features or {})
@ -302,6 +498,54 @@ def _parse_data_requests(payload: Mapping[str, Any]) -> List[DataRequest]:
return requests
def _parse_tool_arguments(payload: Any) -> Dict[str, Any]:
if isinstance(payload, dict):
return dict(payload)
if isinstance(payload, str):
try:
data = json.loads(payload)
except json.JSONDecodeError:
LOGGER.debug("工具参数解析失败:%s", payload, extra=LOG_EXTRA)
return {}
if isinstance(data, dict):
return data
return {}
def _message_to_text(message: Mapping[str, Any]) -> str:
content = message.get("content")
if isinstance(content, list):
parts: List[str] = []
for item in content:
if isinstance(item, Mapping) and "text" in item:
parts.append(str(item.get("text", "")))
else:
parts.append(str(item))
if parts:
return "".join(parts)
elif isinstance(content, str) and content.strip():
return content
tool_calls = message.get("tool_calls")
if tool_calls:
return json.dumps({"tool_calls": tool_calls}, ensure_ascii=False)
return ""
def _extract_message_content(message: Mapping[str, Any]) -> str:
content = message.get("content")
if isinstance(content, list):
texts = [
str(item.get("text", ""))
for item in content
if isinstance(item, Mapping) and "text" in item
]
if texts:
return "".join(texts)
if isinstance(content, str):
return content
return json.dumps(message, ensure_ascii=False)
class DepartmentManager:
"""Orchestrates all departments defined in configuration."""

View File

@ -11,6 +11,8 @@ import requests
from app.utils.config import (
DEFAULT_LLM_BASE_URLS,
DEFAULT_LLM_MODELS,
DEFAULT_LLM_TEMPERATURES,
DEFAULT_LLM_TIMEOUTS,
LLMConfig,
LLMEndpoint,
get_config,
@ -34,8 +36,8 @@ def _default_model(provider: str) -> str:
return DEFAULT_LLM_MODELS.get(provider, DEFAULT_LLM_MODELS["ollama"])
def _build_messages(prompt: str, system: Optional[str] = None) -> List[Dict[str, str]]:
messages: List[Dict[str, str]] = []
def _build_messages(prompt: str, system: Optional[str] = None) -> List[Dict[str, object]]:
messages: List[Dict[str, object]] = []
if system:
messages.append({"role": "system", "content": system})
messages.append({"role": "user", "content": prompt})
@ -70,38 +72,39 @@ def _request_ollama(
return str(content)
def _request_openai(
model: str,
prompt: str,
def _request_openai_chat(
*,
base_url: str,
api_key: str,
model: str,
messages: List[Dict[str, object]],
temperature: float,
timeout: float,
system: Optional[str],
) -> str:
tools: Optional[List[Dict[str, object]]] = None,
tool_choice: Optional[object] = None,
) -> Dict[str, object]:
url = f"{base_url.rstrip('/')}/v1/chat/completions"
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
payload = {
payload: Dict[str, object] = {
"model": model,
"messages": _build_messages(prompt, system),
"messages": messages,
"temperature": temperature,
}
if tools:
payload["tools"] = tools
if tool_choice is not None:
payload["tool_choice"] = tool_choice
LOGGER.debug("调用 OpenAI 兼容接口: %s %s", model, url, extra=LOG_EXTRA)
response = requests.post(url, headers=headers, json=payload, timeout=timeout)
if response.status_code != 200:
raise LLMError(f"OpenAI API 调用失败: {response.status_code} {response.text}")
data = response.json()
try:
return data["choices"][0]["message"]["content"].strip()
except (KeyError, IndexError) as exc:
raise LLMError(f"OpenAI 响应解析失败: {json.dumps(data, ensure_ascii=False)}") from exc
return response.json()
def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> str:
def _prepare_endpoint(endpoint: LLMEndpoint) -> Dict[str, object]:
cfg = get_config()
provider_key = (endpoint.provider or "ollama").lower()
provider_cfg = cfg.llm_providers.get(provider_key)
@ -128,14 +131,30 @@ def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) ->
else:
base_url = base_url or _default_base_url(provider_key)
model = model or _default_model(provider_key)
if temperature is None:
temperature = DEFAULT_LLM_TEMPERATURES.get(provider_key, 0.2)
if timeout is None:
timeout = DEFAULT_LLM_TIMEOUTS.get(provider_key, 30.0)
mode = "ollama" if provider_key == "ollama" else "openai"
temperature = max(0.0, min(float(temperature), 2.0))
timeout = max(5.0, float(timeout))
return {
"provider_key": provider_key,
"mode": mode,
"base_url": base_url,
"api_key": api_key,
"model": model,
"temperature": max(0.0, min(float(temperature), 2.0)),
"timeout": max(5.0, float(timeout)),
"prompt_template": prompt_template,
}
def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> str:
resolved = _prepare_endpoint(endpoint)
provider_key = resolved["provider_key"]
mode = resolved["mode"]
prompt_template = resolved["prompt_template"]
if prompt_template:
try:
@ -143,6 +162,40 @@ def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) ->
except Exception: # noqa: BLE001
LOGGER.warning("Prompt 模板格式化失败,使用原始 prompt", extra=LOG_EXTRA)
messages = _build_messages(prompt, system)
response = call_endpoint_with_messages(
endpoint,
messages,
tools=None,
)
if mode == "ollama":
message = response.get("message") or {}
content = message.get("content", "")
if isinstance(content, list):
return "".join(chunk.get("text", "") or chunk.get("content", "") for chunk in content)
return str(content)
try:
return response["choices"][0]["message"]["content"].strip()
except (KeyError, IndexError) as exc:
raise LLMError(f"OpenAI 响应解析失败: {json.dumps(response, ensure_ascii=False)}") from exc
def call_endpoint_with_messages(
endpoint: LLMEndpoint,
messages: List[Dict[str, object]],
*,
tools: Optional[List[Dict[str, object]]] = None,
tool_choice: Optional[object] = None,
) -> Dict[str, object]:
resolved = _prepare_endpoint(endpoint)
provider_key = resolved["provider_key"]
mode = resolved["mode"]
base_url = resolved["base_url"]
model = resolved["model"]
temperature = resolved["temperature"]
timeout = resolved["timeout"]
api_key = resolved["api_key"]
LOGGER.info(
"触发 LLM 请求provider=%s model=%s base=%s",
provider_key,
@ -151,28 +204,36 @@ def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) ->
extra=LOG_EXTRA,
)
if mode != "ollama":
if mode == "ollama":
if tools:
raise LLMError("当前 provider 不支持函数调用/工具模式")
payload = {
"model": model,
"messages": messages,
"stream": False,
"options": {"temperature": temperature},
}
response = requests.post(
f"{base_url.rstrip('/')}/api/chat",
json=payload,
timeout=timeout,
)
if response.status_code != 200:
raise LLMError(f"Ollama 调用失败: {response.status_code} {response.text}")
return response.json()
if not api_key:
raise LLMError(f"缺少 {provider_key} API Key (model={model})")
return _request_openai(
model,
prompt,
return _request_openai_chat(
base_url=base_url,
api_key=api_key,
model=model,
messages=messages,
temperature=temperature,
timeout=timeout,
system=system,
tools=tools,
tool_choice=tool_choice,
)
if base_url:
return _request_ollama(
model,
prompt,
base_url=base_url,
temperature=temperature,
timeout=timeout,
system=system,
)
raise LLMError(f"不支持的 LLM provider: {endpoint.provider}")
def _normalize_response(text: str) -> str:

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Dict, Iterable, List, Sequence, Tuple
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
from .db import db_session
from .logging import get_logger
@ -42,6 +42,29 @@ def parse_field_path(path: str) -> Tuple[str, str] | None:
class DataBroker:
"""Lightweight data access helper for agent/LLM consumption."""
FIELD_ALIASES: Dict[str, Dict[str, str]] = {
"daily": {
"volume": "vol",
"vol": "vol",
"turnover": "amount",
},
"daily_basic": {
"turnover": "turnover_rate",
"turnover_rate": "turnover_rate",
"turnover_rate_f": "turnover_rate_f",
"volume_ratio": "volume_ratio",
"pe": "pe",
"pb": "pb",
"ps": "ps",
"ps_ttm": "ps_ttm",
},
"stk_limit": {
"up": "up_limit",
"down": "down_limit",
},
}
MAX_WINDOW: int = 120
def fetch_latest(
self,
ts_code: str,
@ -51,16 +74,18 @@ class DataBroker:
"""Fetch the latest value (<= trade_date) for each requested field."""
grouped: Dict[str, List[str]] = {}
field_map: Dict[Tuple[str, str], List[str]] = {}
for item in fields:
if not item:
continue
normalized = _safe_split(str(item))
if not normalized:
resolved = self.resolve_field(str(item))
if not resolved:
continue
table, column = normalized
table, column = resolved
grouped.setdefault(table, [])
if column not in grouped[table]:
grouped[table].append(column)
field_map.setdefault((table, column), []).append(str(item))
if not grouped:
return {}
@ -91,8 +116,8 @@ class DataBroker:
value = row[column]
if value is None:
continue
key = f"{table}.{column}"
results[key] = float(value)
for original in field_map.get((table, column), [f"{table}.{column}"]):
results[original] = float(value)
return results
def fetch_series(
@ -107,10 +132,19 @@ class DataBroker:
if window <= 0:
return []
if not (_is_safe_identifier(table) and _is_safe_identifier(column)):
window = min(window, self.MAX_WINDOW)
resolved_field = self.resolve_field(f"{table}.{column}")
if not resolved_field:
LOGGER.debug(
"时间序列字段不存在 table=%s column=%s",
table,
column,
extra=LOG_EXTRA,
)
return []
table, resolved = resolved_field
query = (
f"SELECT trade_date, {column} FROM {table} "
f"SELECT trade_date, {resolved} FROM {table} "
"WHERE ts_code = ? AND trade_date <= ? "
"ORDER BY trade_date DESC LIMIT ?"
)
@ -128,7 +162,7 @@ class DataBroker:
return []
series: List[Tuple[str, float]] = []
for row in rows:
value = row[column]
value = row[resolved]
if value is None:
continue
series.append((row["trade_date"], float(value)))
@ -163,3 +197,57 @@ class DataBroker:
)
return False
return row is not None
def resolve_field(self, field: str) -> Optional[Tuple[str, str]]:
    """Map a raw ``table.column`` path to a validated (table, column) pair.

    Returns ``None`` when the path is malformed or the column cannot be
    resolved against the live table schema (aliases included); failures
    are logged at debug level only.
    """
    parts = _safe_split(field)
    if not parts:
        return None
    table_name, raw_column = parts
    column_name = self._resolve_column(table_name, raw_column)
    if not column_name:
        LOGGER.debug(
            "字段不存在 table=%s column=%s",
            table_name,
            raw_column,
            extra=LOG_EXTRA,
        )
        return None
    return table_name, column_name
def _get_table_columns(self, table: str) -> Optional[set[str]]:
    """Return the column names of *table*, memoized per broker instance.

    ``None`` marks an unusable table (unsafe identifier, PRAGMA failure or
    an empty result); negative outcomes are cached too so each bad table is
    probed at most once.
    """
    if not _is_safe_identifier(table):
        return None
    # Lazy cache init: older instances may predate the attribute.
    if getattr(self, "_column_cache", None) is None:
        self._column_cache = {}
    cache = self._column_cache
    if table not in cache:
        try:
            with db_session(read_only=True) as conn:
                pragma_rows = conn.execute(f"PRAGMA table_info({table})").fetchall()
        except Exception as exc:  # noqa: BLE001
            LOGGER.debug("获取表字段失败 table=%s err=%s", table, exc, extra=LOG_EXTRA)
            pragma_rows = None
        if not pragma_rows:
            cache[table] = None
        else:
            cache[table] = {row["name"] for row in pragma_rows if row["name"]}
    return cache[table]
def _resolve_column(self, table: str, column: str) -> Optional[str]:
    """Resolve *column* against *table*'s real schema, honoring aliases.

    Lookup order: alias map -> exact name -> case-insensitive match.
    Returns ``None`` when the table is unusable or nothing matches.
    """
    known = self._get_table_columns(table)
    if known is None:
        return None
    candidate = self.FIELD_ALIASES.get(table, {}).get(column, column)
    if candidate in known:
        return candidate
    target = candidate.lower()
    return next((name for name in known if name.lower() == target), None)