update
This commit is contained in:
parent
0a2742b869
commit
ca7b249c2c
@ -36,9 +36,10 @@
|
||||
1. **配置声明角色**:在 `config.json`/`DepartmentSettings` 中补充 `description` 与 `data_scope`,`department_prompt()` 拼接角色指令,实现职责以 Prompt 管理而非硬编码。
|
||||
2. **统一数据层**:新增 `DataBroker`(或同类工具)封装常用查询,代理与部门通过声明式 JSON 请求所需表/字段/窗口,由服务端执行并返回特征。
|
||||
3. **双阶段 LLM 工作流**:第一阶段让 LLM 输出结构化 `data_requests`,服务端取数后将摘要回填,第二阶段再生成最终行动与解释,形成闭环。
|
||||
4. **审计与前端联动**:把角色提示、数据请求与执行摘要写入 `agent_utils` 附加字段,使 Streamlit 能完整呈现“角色 → 请求 → 决策”的链条。
|
||||
4. **函数式工具调用**:DeepSeek 等 OpenAI 兼容模型已通过 function calling 接口接入 `fetch_data` 工具,LLM 按 schema 返回字段/窗口请求,系统使用 `DataBroker` 校验并补数后回传 tool result,再继续对话生成最终意见,避免字段错误与手写 JSON 校验。
|
||||
5. **审计与前端联动**:把角色提示、数据请求与执行摘要写入 `agent_utils` 附加字段,使 Streamlit 能完整呈现“角色 → 请求 → 决策”的链条。
|
||||
|
||||
目前部门 LLM 已支持通过在 JSON 中返回 `data_requests` 触发追加查询:系统会使用 `DataBroker` 验证字段后补充最近数据窗口,再带着查询结果进入下一轮提示,从而形成“请求→取数→复议”的闭环。
|
||||
目前部门 LLM 通过 function calling 暴露的 `fetch_data` 工具触发追加查询:模型按 schema 声明需要的 `table.column` 字段与窗口,系统使用 `DataBroker` 验证后补齐数据并作为工具响应回传,再带着查询结果进入下一轮提示,从而形成“请求 → 取数 → 复议”的闭环。
|
||||
|
||||
上述调整可在单个部门先行做 PoC,验证闭环能力后再推广至全部角色。
|
||||
|
||||
|
||||
@ -6,11 +6,11 @@ from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple
|
||||
|
||||
from app.agents.base import AgentAction
|
||||
from app.llm.client import run_llm_with_config
|
||||
from app.llm.client import call_endpoint_with_messages, run_llm_with_config, LLMError
|
||||
from app.llm.prompts import department_prompt
|
||||
from app.utils.config import AppConfig, DepartmentSettings, LLMConfig
|
||||
from app.utils.logging import get_logger
|
||||
from app.utils.data_access import DataBroker, parse_field_path
|
||||
from app.utils.data_access import DataBroker
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
LOG_EXTRA = {"stage": "department"}
|
||||
@ -85,74 +85,95 @@ class DepartmentAgent:
|
||||
"你是一个多智能体量化投研系统中的分部决策者,需要根据提供的结构化信息给出买卖意见。"
|
||||
)
|
||||
llm_cfg = self._get_llm_config()
|
||||
supplement_chunks: List[str] = []
|
||||
|
||||
if llm_cfg.strategy not in (None, "", "single") or llm_cfg.ensemble:
|
||||
LOGGER.warning(
|
||||
"部门 %s 当前配置不支持函数调用模式,回退至传统提示",
|
||||
self.settings.code,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
return self._analyze_legacy(mutable_context, system_prompt)
|
||||
|
||||
tools = self._build_tools()
|
||||
messages: List[Dict[str, object]] = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": department_prompt(self.settings, mutable_context),
|
||||
}
|
||||
)
|
||||
|
||||
transcript: List[str] = []
|
||||
delivered_requests = {
|
||||
delivered_requests: set[Tuple[str, int]] = {
|
||||
(field, 1)
|
||||
for field in (mutable_context.raw.get("scope_values") or {}).keys()
|
||||
}
|
||||
|
||||
response = ""
|
||||
decision_data: Dict[str, Any] = {}
|
||||
primary_endpoint = llm_cfg.primary
|
||||
final_message: Optional[Dict[str, Any]] = None
|
||||
|
||||
for round_idx in range(self._max_rounds):
|
||||
supplement_text = "\n\n".join(chunk for chunk in supplement_chunks if chunk)
|
||||
prompt = department_prompt(self.settings, mutable_context, supplements=supplement_text)
|
||||
try:
|
||||
response = run_llm_with_config(llm_cfg, prompt, system=system_prompt)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
LOGGER.exception(
|
||||
"部门 %s 调用 LLM 失败:%s",
|
||||
response = call_endpoint_with_messages(
|
||||
primary_endpoint,
|
||||
messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
)
|
||||
except LLMError as exc:
|
||||
LOGGER.warning(
|
||||
"部门 %s 函数调用失败,回退传统提示:%s",
|
||||
self.settings.code,
|
||||
exc,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
return DepartmentDecision(
|
||||
department=self.settings.code,
|
||||
action=AgentAction.HOLD,
|
||||
confidence=0.0,
|
||||
summary=f"LLM 调用失败:{exc}",
|
||||
raw_response=str(exc),
|
||||
)
|
||||
return self._analyze_legacy(mutable_context, system_prompt)
|
||||
|
||||
transcript.append(response)
|
||||
decision_data = _parse_department_response(response)
|
||||
data_requests = _parse_data_requests(decision_data)
|
||||
filtered_requests = [
|
||||
req
|
||||
for req in data_requests
|
||||
if (req.field, req.window) not in delivered_requests
|
||||
]
|
||||
choice = (response.get("choices") or [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
transcript.append(_message_to_text(message))
|
||||
|
||||
if filtered_requests and round_idx < self._max_rounds - 1:
|
||||
lines, payload, delivered = self._fulfill_data_requests(
|
||||
mutable_context, filtered_requests
|
||||
tool_calls = message.get("tool_calls") or []
|
||||
if tool_calls:
|
||||
for call in tool_calls:
|
||||
tool_response, delivered = self._handle_tool_call(
|
||||
mutable_context,
|
||||
call,
|
||||
delivered_requests,
|
||||
round_idx,
|
||||
)
|
||||
if payload:
|
||||
supplement_chunks.append(
|
||||
f"回合 {round_idx + 1} 追加数据:\n" + "\n".join(lines)
|
||||
transcript.append(
|
||||
json.dumps({"tool_response": tool_response}, ensure_ascii=False)
|
||||
)
|
||||
mutable_context.raw.setdefault("supplement_data", []).extend(payload)
|
||||
mutable_context.raw.setdefault("supplement_rounds", []).append(
|
||||
messages.append(
|
||||
{
|
||||
"round": round_idx + 1,
|
||||
"requests": [req.__dict__ for req in filtered_requests],
|
||||
"data": payload,
|
||||
"role": "tool",
|
||||
"tool_call_id": call.get("id"),
|
||||
"name": call.get("function", {}).get("name"),
|
||||
"content": json.dumps(tool_response, ensure_ascii=False),
|
||||
}
|
||||
)
|
||||
delivered_requests.update(delivered)
|
||||
decision_data.pop("data_requests", None)
|
||||
continue
|
||||
LOGGER.debug(
|
||||
"部门 %s 数据请求无结果:%s",
|
||||
self.settings.code,
|
||||
filtered_requests,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
decision_data.pop("data_requests", None)
|
||||
|
||||
final_message = message
|
||||
break
|
||||
|
||||
if final_message is None:
|
||||
LOGGER.warning(
|
||||
"部门 %s 函数调用达到轮次上限仍未返回文本,使用最后一次消息",
|
||||
self.settings.code,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
final_message = message
|
||||
|
||||
mutable_context.raw["supplement_transcript"] = list(transcript)
|
||||
|
||||
content_text = _extract_message_content(final_message)
|
||||
decision_data = _parse_department_response(content_text)
|
||||
|
||||
action = _normalize_action(decision_data.get("action"))
|
||||
confidence = _clamp_float(decision_data.get("confidence"), default=0.5)
|
||||
summary = decision_data.get("summary") or decision_data.get("reason") or ""
|
||||
@ -170,7 +191,7 @@ class DepartmentAgent:
|
||||
summary=summary or "未提供摘要",
|
||||
signals=[str(sig) for sig in signals if sig],
|
||||
risks=[str(risk) for risk in risks if risk],
|
||||
raw_response="\n\n".join(transcript) if transcript else response,
|
||||
raw_response=content_text,
|
||||
supplements=list(mutable_context.raw.get("supplement_data", [])),
|
||||
dialogue=list(transcript),
|
||||
)
|
||||
@ -208,16 +229,16 @@ class DepartmentAgent:
|
||||
field = req.field.strip()
|
||||
if not field:
|
||||
continue
|
||||
resolved = self._broker.resolve_field(field)
|
||||
if not resolved:
|
||||
lines.append(f"- {field}: 字段不存在或不可用")
|
||||
continue
|
||||
if req.window <= 1:
|
||||
if field not in latest_fields:
|
||||
latest_fields.append(field)
|
||||
delivered.add((field, 1))
|
||||
continue
|
||||
parsed = parse_field_path(field)
|
||||
if not parsed:
|
||||
lines.append(f"- {field}: 字段不合法,已忽略")
|
||||
continue
|
||||
series_requests.append((req, parsed))
|
||||
series_requests.append((req, resolved))
|
||||
delivered.add((field, req.window))
|
||||
|
||||
if latest_fields:
|
||||
@ -251,11 +272,186 @@ class DepartmentAgent:
|
||||
lines.append(
|
||||
f"- {req.field} (window={req.window}): (数据缺失)"
|
||||
)
|
||||
payload.append({"field": req.field, "window": req.window, "values": series})
|
||||
payload.append(
|
||||
{
|
||||
"field": req.field,
|
||||
"window": req.window,
|
||||
"values": [
|
||||
{"trade_date": dt, "value": val}
|
||||
for dt, val in series
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
return lines, payload, delivered
|
||||
|
||||
|
||||
def _handle_tool_call(
    self,
    context: DepartmentContext,
    call: Mapping[str, Any],
    delivered_requests: set[Tuple[str, int]],
    round_idx: int,
) -> Tuple[Dict[str, Any], set[Tuple[str, int]]]:
    """Execute a single LLM tool call and build the tool-response payload.

    Only the ``fetch_data`` tool is supported. Its ``requests`` argument is
    produced by the model, so it is treated as untrusted: malformed entries
    are ignored instead of raising.

    Args:
        context: Department context; its ``raw`` mapping receives the fetched
            data under ``supplement_data`` / ``supplement_rounds`` and the
            textual notes under ``supplement_notes``.
        call: One entry of the assistant message's ``tool_calls`` list.
        delivered_requests: ``(field, window)`` pairs that were already
            served; matching requests are listed under ``skipped`` and not
            fetched again.
        round_idx: Zero-based dialogue round (reported one-based).

    Returns:
        A ``(response, delivered)`` tuple: the JSON-serialisable payload to
        send back as the tool result, and the ``(field, window)`` pairs
        reported as handled by ``_fulfill_data_requests``.
    """
    function_block = call.get("function") or {}
    name = function_block.get("name") or ""
    if name != "fetch_data":
        LOGGER.warning(
            "部门 %s 收到未知工具调用:%s",
            self.settings.code,
            name,
            extra=LOG_EXTRA,
        )
        return {
            "status": "error",
            "message": f"未知工具 {name}",
        }, set()

    args = _parse_tool_arguments(function_block.get("arguments"))
    raw_requests = args.get("requests") or []
    if isinstance(raw_requests, (Mapping, str)):
        # The model sent a single request instead of a list; wrap it.
        raw_requests = [raw_requests]
    elif not isinstance(raw_requests, (list, tuple)):
        raw_requests = []

    # Loop-invariant upper bound for the requested window size.
    max_window = getattr(self._broker, "MAX_WINDOW", 120)
    requests: List[DataRequest] = []
    skipped: List[str] = []
    for item in raw_requests:
        if isinstance(item, str):
            # A bare "table.column" string means "latest value" (window 1).
            item = {"field": item}
        elif not isinstance(item, Mapping):
            # Anything else cannot describe a request; ignore it.
            continue
        # ``or ""`` keeps an explicit null from turning into the text "None".
        field = str(item.get("field") or "").strip()
        if not field:
            continue
        try:
            window = int(item.get("window", 1))
        except (TypeError, ValueError):
            window = 1
        window = max(1, min(window, max_window))
        key = (field, window)
        if key in delivered_requests:
            skipped.append(field)
            continue
        requests.append(DataRequest(field=field, window=window))

    if not requests:
        return {
            "status": "ok",
            "round": round_idx + 1,
            "results": [],
            "skipped": skipped,
        }, set()

    lines, payload, delivered = self._fulfill_data_requests(context, requests)
    if payload:
        context.raw.setdefault("supplement_data", []).extend(payload)
        context.raw.setdefault("supplement_rounds", []).append(
            {
                "round": round_idx + 1,
                "requests": [req.__dict__ for req in requests],
                "data": payload,
                "notes": lines,
            }
        )
    if lines:
        context.raw.setdefault("supplement_notes", []).append(
            {
                "round": round_idx + 1,
                "lines": lines,
            }
        )

    response_payload = {
        "status": "ok",
        "round": round_idx + 1,
        "results": payload,
        "notes": lines,
        "skipped": skipped,
    }
    return response_payload, delivered
|
||||
|
||||
|
||||
def _build_tools(self) -> List[Dict[str, Any]]:
    """Return the function-calling tool schema exposed to the LLM.

    The list contains a single ``fetch_data`` function whose ``requests``
    argument is a non-empty array of ``{field, window}`` objects. The
    ``window`` upper bound is read from the broker's ``MAX_WINDOW``
    attribute, falling back to 120 when the broker does not define it.
    """
    window_limit = getattr(self._broker, "MAX_WINDOW", 120)

    field_schema = {
        "type": "string",
        "description": "数据字段,格式为 table.column",
    }
    window_schema = {
        "type": "integer",
        "minimum": 1,
        "maximum": window_limit,
        "description": "返回最近多少个数据点,默认为 1",
    }
    request_item_schema = {
        "type": "object",
        "properties": {
            "field": field_schema,
            "window": window_schema,
        },
        "required": ["field"],
    }
    parameters_schema = {
        "type": "object",
        "properties": {
            "requests": {
                "type": "array",
                "items": request_item_schema,
                "minItems": 1,
            }
        },
        "required": ["requests"],
    }
    fetch_data_tool = {
        "type": "function",
        "function": {
            "name": "fetch_data",
            "description": (
                "根据字段请求数据库中的最新值或时间序列。支持 table.column 格式的字段,"
                "window 表示希望返回的最近数据点数量。"
            ),
            "parameters": parameters_schema,
        },
    }
    return [fetch_data_tool]
|
||||
|
||||
def _analyze_legacy(
    self,
    context: DepartmentContext,
    system_prompt: str,
) -> DepartmentDecision:
    """Run the single-shot (no tool calling) department analysis.

    Builds the department prompt, sends it through ``run_llm_with_config``
    and parses the reply into a ``DepartmentDecision``. Any exception from
    the LLM call is logged and converted into a zero-confidence HOLD
    decision carrying the error text.
    """

    def _as_sequence(value):
        # A lone string counts as a one-element list.
        return [value] if isinstance(value, str) else value

    prompt = department_prompt(self.settings, context)
    llm_cfg = self._get_llm_config()
    try:
        response = run_llm_with_config(llm_cfg, prompt, system=system_prompt)
    except Exception as exc:  # noqa: BLE001
        LOGGER.exception(
            "部门 %s 调用 LLM 失败:%s",
            self.settings.code,
            exc,
            extra=LOG_EXTRA,
        )
        return DepartmentDecision(
            department=self.settings.code,
            action=AgentAction.HOLD,
            confidence=0.0,
            summary=f"LLM 调用失败:{exc}",
            raw_response=str(exc),
        )

    # The legacy path has exactly one exchange; record it as the transcript.
    context.raw["supplement_transcript"] = [response]

    parsed = _parse_department_response(response)
    action = _normalize_action(parsed.get("action"))
    confidence = _clamp_float(parsed.get("confidence"), default=0.5)
    summary = parsed.get("summary") or parsed.get("reason") or ""
    signals = _as_sequence(parsed.get("signals") or parsed.get("rationale") or [])
    risks = _as_sequence(parsed.get("risks") or parsed.get("warnings") or [])

    return DepartmentDecision(
        department=self.settings.code,
        action=action,
        confidence=confidence,
        summary=summary or "未提供摘要",
        signals=[str(entry) for entry in signals if entry],
        risks=[str(entry) for entry in risks if entry],
        raw_response=response,
        supplements=list(context.raw.get("supplement_data", [])),
        dialogue=[response],
    )
|
||||
def _ensure_mutable_context(context: DepartmentContext) -> DepartmentContext:
|
||||
if not isinstance(context.features, dict):
|
||||
context.features = dict(context.features or {})
|
||||
@ -302,6 +498,54 @@ def _parse_data_requests(payload: Mapping[str, Any]) -> List[DataRequest]:
|
||||
return requests
|
||||
|
||||
|
||||
def _parse_tool_arguments(payload: Any) -> Dict[str, Any]:
    """Normalise a tool call's ``arguments`` value into a dictionary.

    A dict is shallow-copied; a string is decoded as JSON and accepted only
    when it decodes to an object. Every other input, including undecodable
    JSON (which is logged at debug level), yields an empty dict.
    """
    if isinstance(payload, dict):
        return dict(payload)
    if not isinstance(payload, str):
        return {}
    try:
        decoded = json.loads(payload)
    except json.JSONDecodeError:
        LOGGER.debug("工具参数解析失败:%s", payload, extra=LOG_EXTRA)
        return {}
    return decoded if isinstance(decoded, dict) else {}
|
||||
|
||||
|
||||
def _message_to_text(message: Mapping[str, Any]) -> str:
    """Render a chat message as a single string for the transcript.

    List content is flattened: mapping items contribute their ``text``
    value, everything else its ``str()`` form. Non-blank string content is
    returned unchanged. When no content is usable, the ``tool_calls`` (if
    any) are serialised as JSON; otherwise the result is an empty string.
    """
    body = message.get("content")
    if isinstance(body, list):
        fragments: List[str] = [
            str(piece.get("text", ""))
            if isinstance(piece, Mapping) and "text" in piece
            else str(piece)
            for piece in body
        ]
        if fragments:
            return "".join(fragments)
    elif isinstance(body, str) and body.strip():
        return body

    # No usable content: fall back to describing the requested tool calls.
    calls = message.get("tool_calls")
    return json.dumps({"tool_calls": calls}, ensure_ascii=False) if calls else ""
|
||||
|
||||
|
||||
def _extract_message_content(message: Mapping[str, Any]) -> str:
    """Return the textual content of a chat message.

    For list content only mapping items that carry a ``text`` key are used
    and their values concatenated. String content is returned unchanged
    (even when empty). If neither yields text, the whole message is
    serialised as JSON so the caller still has something to parse.
    """
    body = message.get("content")
    if isinstance(body, list):
        collected: List[str] = []
        for piece in body:
            if isinstance(piece, Mapping) and "text" in piece:
                collected.append(str(piece.get("text", "")))
        if collected:
            return "".join(collected)
    if isinstance(body, str):
        return body
    return json.dumps(message, ensure_ascii=False)
|
||||
|
||||
|
||||
class DepartmentManager:
|
||||
"""Orchestrates all departments defined in configuration."""
|
||||
|
||||
|
||||
@ -11,6 +11,8 @@ import requests
|
||||
from app.utils.config import (
|
||||
DEFAULT_LLM_BASE_URLS,
|
||||
DEFAULT_LLM_MODELS,
|
||||
DEFAULT_LLM_TEMPERATURES,
|
||||
DEFAULT_LLM_TIMEOUTS,
|
||||
LLMConfig,
|
||||
LLMEndpoint,
|
||||
get_config,
|
||||
@ -34,8 +36,8 @@ def _default_model(provider: str) -> str:
|
||||
return DEFAULT_LLM_MODELS.get(provider, DEFAULT_LLM_MODELS["ollama"])
|
||||
|
||||
|
||||
def _build_messages(prompt: str, system: Optional[str] = None) -> List[Dict[str, str]]:
|
||||
messages: List[Dict[str, str]] = []
|
||||
def _build_messages(prompt: str, system: Optional[str] = None) -> List[Dict[str, object]]:
|
||||
messages: List[Dict[str, object]] = []
|
||||
if system:
|
||||
messages.append({"role": "system", "content": system})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
@ -70,38 +72,39 @@ def _request_ollama(
|
||||
return str(content)
|
||||
|
||||
|
||||
def _request_openai(
|
||||
model: str,
|
||||
prompt: str,
|
||||
def _request_openai_chat(
|
||||
*,
|
||||
base_url: str,
|
||||
api_key: str,
|
||||
model: str,
|
||||
messages: List[Dict[str, object]],
|
||||
temperature: float,
|
||||
timeout: float,
|
||||
system: Optional[str],
|
||||
) -> str:
|
||||
tools: Optional[List[Dict[str, object]]] = None,
|
||||
tool_choice: Optional[object] = None,
|
||||
) -> Dict[str, object]:
|
||||
url = f"{base_url.rstrip('/')}/v1/chat/completions"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = {
|
||||
payload: Dict[str, object] = {
|
||||
"model": model,
|
||||
"messages": _build_messages(prompt, system),
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
}
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
if tool_choice is not None:
|
||||
payload["tool_choice"] = tool_choice
|
||||
LOGGER.debug("调用 OpenAI 兼容接口: %s %s", model, url, extra=LOG_EXTRA)
|
||||
response = requests.post(url, headers=headers, json=payload, timeout=timeout)
|
||||
if response.status_code != 200:
|
||||
raise LLMError(f"OpenAI API 调用失败: {response.status_code} {response.text}")
|
||||
data = response.json()
|
||||
try:
|
||||
return data["choices"][0]["message"]["content"].strip()
|
||||
except (KeyError, IndexError) as exc:
|
||||
raise LLMError(f"OpenAI 响应解析失败: {json.dumps(data, ensure_ascii=False)}") from exc
|
||||
return response.json()
|
||||
|
||||
|
||||
def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> str:
|
||||
def _prepare_endpoint(endpoint: LLMEndpoint) -> Dict[str, object]:
|
||||
cfg = get_config()
|
||||
provider_key = (endpoint.provider or "ollama").lower()
|
||||
provider_cfg = cfg.llm_providers.get(provider_key)
|
||||
@ -128,14 +131,30 @@ def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) ->
|
||||
else:
|
||||
base_url = base_url or _default_base_url(provider_key)
|
||||
model = model or _default_model(provider_key)
|
||||
|
||||
if temperature is None:
|
||||
temperature = DEFAULT_LLM_TEMPERATURES.get(provider_key, 0.2)
|
||||
if timeout is None:
|
||||
timeout = DEFAULT_LLM_TIMEOUTS.get(provider_key, 30.0)
|
||||
mode = "ollama" if provider_key == "ollama" else "openai"
|
||||
|
||||
temperature = max(0.0, min(float(temperature), 2.0))
|
||||
timeout = max(5.0, float(timeout))
|
||||
return {
|
||||
"provider_key": provider_key,
|
||||
"mode": mode,
|
||||
"base_url": base_url,
|
||||
"api_key": api_key,
|
||||
"model": model,
|
||||
"temperature": max(0.0, min(float(temperature), 2.0)),
|
||||
"timeout": max(5.0, float(timeout)),
|
||||
"prompt_template": prompt_template,
|
||||
}
|
||||
|
||||
|
||||
def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> str:
|
||||
resolved = _prepare_endpoint(endpoint)
|
||||
provider_key = resolved["provider_key"]
|
||||
mode = resolved["mode"]
|
||||
prompt_template = resolved["prompt_template"]
|
||||
|
||||
if prompt_template:
|
||||
try:
|
||||
@ -143,6 +162,40 @@ def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) ->
|
||||
except Exception: # noqa: BLE001
|
||||
LOGGER.warning("Prompt 模板格式化失败,使用原始 prompt", extra=LOG_EXTRA)
|
||||
|
||||
messages = _build_messages(prompt, system)
|
||||
response = call_endpoint_with_messages(
|
||||
endpoint,
|
||||
messages,
|
||||
tools=None,
|
||||
)
|
||||
if mode == "ollama":
|
||||
message = response.get("message") or {}
|
||||
content = message.get("content", "")
|
||||
if isinstance(content, list):
|
||||
return "".join(chunk.get("text", "") or chunk.get("content", "") for chunk in content)
|
||||
return str(content)
|
||||
try:
|
||||
return response["choices"][0]["message"]["content"].strip()
|
||||
except (KeyError, IndexError) as exc:
|
||||
raise LLMError(f"OpenAI 响应解析失败: {json.dumps(response, ensure_ascii=False)}") from exc
|
||||
|
||||
|
||||
def call_endpoint_with_messages(
|
||||
endpoint: LLMEndpoint,
|
||||
messages: List[Dict[str, object]],
|
||||
*,
|
||||
tools: Optional[List[Dict[str, object]]] = None,
|
||||
tool_choice: Optional[object] = None,
|
||||
) -> Dict[str, object]:
|
||||
resolved = _prepare_endpoint(endpoint)
|
||||
provider_key = resolved["provider_key"]
|
||||
mode = resolved["mode"]
|
||||
base_url = resolved["base_url"]
|
||||
model = resolved["model"]
|
||||
temperature = resolved["temperature"]
|
||||
timeout = resolved["timeout"]
|
||||
api_key = resolved["api_key"]
|
||||
|
||||
LOGGER.info(
|
||||
"触发 LLM 请求:provider=%s model=%s base=%s",
|
||||
provider_key,
|
||||
@ -151,28 +204,36 @@ def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) ->
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
|
||||
if mode != "ollama":
|
||||
if mode == "ollama":
|
||||
if tools:
|
||||
raise LLMError("当前 provider 不支持函数调用/工具模式")
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"stream": False,
|
||||
"options": {"temperature": temperature},
|
||||
}
|
||||
response = requests.post(
|
||||
f"{base_url.rstrip('/')}/api/chat",
|
||||
json=payload,
|
||||
timeout=timeout,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise LLMError(f"Ollama 调用失败: {response.status_code} {response.text}")
|
||||
return response.json()
|
||||
|
||||
if not api_key:
|
||||
raise LLMError(f"缺少 {provider_key} API Key (model={model})")
|
||||
return _request_openai(
|
||||
model,
|
||||
prompt,
|
||||
return _request_openai_chat(
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
timeout=timeout,
|
||||
system=system,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
if base_url:
|
||||
return _request_ollama(
|
||||
model,
|
||||
prompt,
|
||||
base_url=base_url,
|
||||
temperature=temperature,
|
||||
timeout=timeout,
|
||||
system=system,
|
||||
)
|
||||
raise LLMError(f"不支持的 LLM provider: {endpoint.provider}")
|
||||
|
||||
|
||||
def _normalize_response(text: str) -> str:
|
||||
|
||||
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Iterable, List, Sequence, Tuple
|
||||
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
|
||||
|
||||
from .db import db_session
|
||||
from .logging import get_logger
|
||||
@ -42,6 +42,29 @@ def parse_field_path(path: str) -> Tuple[str, str] | None:
|
||||
class DataBroker:
|
||||
"""Lightweight data access helper for agent/LLM consumption."""
|
||||
|
||||
FIELD_ALIASES: Dict[str, Dict[str, str]] = {
|
||||
"daily": {
|
||||
"volume": "vol",
|
||||
"vol": "vol",
|
||||
"turnover": "amount",
|
||||
},
|
||||
"daily_basic": {
|
||||
"turnover": "turnover_rate",
|
||||
"turnover_rate": "turnover_rate",
|
||||
"turnover_rate_f": "turnover_rate_f",
|
||||
"volume_ratio": "volume_ratio",
|
||||
"pe": "pe",
|
||||
"pb": "pb",
|
||||
"ps": "ps",
|
||||
"ps_ttm": "ps_ttm",
|
||||
},
|
||||
"stk_limit": {
|
||||
"up": "up_limit",
|
||||
"down": "down_limit",
|
||||
},
|
||||
}
|
||||
MAX_WINDOW: int = 120
|
||||
|
||||
def fetch_latest(
|
||||
self,
|
||||
ts_code: str,
|
||||
@ -51,16 +74,18 @@ class DataBroker:
|
||||
"""Fetch the latest value (<= trade_date) for each requested field."""
|
||||
|
||||
grouped: Dict[str, List[str]] = {}
|
||||
field_map: Dict[Tuple[str, str], List[str]] = {}
|
||||
for item in fields:
|
||||
if not item:
|
||||
continue
|
||||
normalized = _safe_split(str(item))
|
||||
if not normalized:
|
||||
resolved = self.resolve_field(str(item))
|
||||
if not resolved:
|
||||
continue
|
||||
table, column = normalized
|
||||
table, column = resolved
|
||||
grouped.setdefault(table, [])
|
||||
if column not in grouped[table]:
|
||||
grouped[table].append(column)
|
||||
field_map.setdefault((table, column), []).append(str(item))
|
||||
|
||||
if not grouped:
|
||||
return {}
|
||||
@ -91,8 +116,8 @@ class DataBroker:
|
||||
value = row[column]
|
||||
if value is None:
|
||||
continue
|
||||
key = f"{table}.{column}"
|
||||
results[key] = float(value)
|
||||
for original in field_map.get((table, column), [f"{table}.{column}"]):
|
||||
results[original] = float(value)
|
||||
return results
|
||||
|
||||
def fetch_series(
|
||||
@ -107,10 +132,19 @@ class DataBroker:
|
||||
|
||||
if window <= 0:
|
||||
return []
|
||||
if not (_is_safe_identifier(table) and _is_safe_identifier(column)):
|
||||
window = min(window, self.MAX_WINDOW)
|
||||
resolved_field = self.resolve_field(f"{table}.{column}")
|
||||
if not resolved_field:
|
||||
LOGGER.debug(
|
||||
"时间序列字段不存在 table=%s column=%s",
|
||||
table,
|
||||
column,
|
||||
extra=LOG_EXTRA,
|
||||
)
|
||||
return []
|
||||
table, resolved = resolved_field
|
||||
query = (
|
||||
f"SELECT trade_date, {column} FROM {table} "
|
||||
f"SELECT trade_date, {resolved} FROM {table} "
|
||||
"WHERE ts_code = ? AND trade_date <= ? "
|
||||
"ORDER BY trade_date DESC LIMIT ?"
|
||||
)
|
||||
@ -128,7 +162,7 @@ class DataBroker:
|
||||
return []
|
||||
series: List[Tuple[str, float]] = []
|
||||
for row in rows:
|
||||
value = row[column]
|
||||
value = row[resolved]
|
||||
if value is None:
|
||||
continue
|
||||
series.append((row["trade_date"], float(value)))
|
||||
@ -163,3 +197,57 @@ class DataBroker:
|
||||
)
|
||||
return False
|
||||
return row is not None
|
||||
|
||||
def resolve_field(self, field: str) -> Optional[Tuple[str, str]]:
    """Resolve a ``table.column`` style field to an existing column.

    The field is split with ``_safe_split`` and the column part is passed
    through ``_resolve_column``.

    Returns:
        ``(table, resolved_column)`` on success, or ``None`` when the field
        cannot be split or the column cannot be resolved (the latter is
        logged at debug level).
    """
    parts = _safe_split(field)
    if not parts:
        return None
    table_name, requested = parts
    actual = self._resolve_column(table_name, requested)
    if actual:
        return table_name, actual
    LOGGER.debug(
        "字段不存在 table=%s column=%s",
        table_name,
        requested,
        extra=LOG_EXTRA,
    )
    return None
|
||||
|
||||
def _get_table_columns(self, table: str) -> Optional[set[str]]:
    """Return the column names of ``table``, memoised per instance.

    The names come from ``PRAGMA table_info``. The table name is checked
    with ``_is_safe_identifier`` first because it is interpolated into the
    statement text. Results are stored in ``self._column_cache`` (created on
    first use), including ``None`` for tables that could not be inspected
    or reported no columns.
    """
    if not _is_safe_identifier(table):
        return None

    cache = getattr(self, "_column_cache", None)
    if cache is None:
        cache = self._column_cache = {}

    try:
        return cache[table]
    except KeyError:
        pass  # not inspected yet

    try:
        with db_session(read_only=True) as conn:
            info_rows = conn.execute(f"PRAGMA table_info({table})").fetchall()
    except Exception as exc:  # noqa: BLE001
        LOGGER.debug("获取表字段失败 table=%s err=%s", table, exc, extra=LOG_EXTRA)
        cache[table] = None
        return None

    # An empty result is remembered as None, i.e. "no such table".
    names = (
        {row["name"] for row in info_rows if row["name"]} if info_rows else None
    )
    cache[table] = names
    return names
|
||||
|
||||
def _resolve_column(self, table: str, column: str) -> Optional[str]:
    """Map a requested column name onto a real column of ``table``.

    Candidates are tried in order, each first by exact name and then
    case-insensitively against the table's columns:

    1. the alias registered for ``column`` in ``FIELD_ALIASES`` (or
       ``column`` itself when no alias exists),
    2. ``column`` itself, so a real column is still found when its alias
       target is missing from the table,
    3. the alias registered for the lower-cased ``column``, so aliases are
       matched regardless of the caller's capitalisation.

    Returns:
        The column name as it appears in the table, or ``None`` when the
        table's columns are unavailable or no candidate matches.
    """
    columns = self._get_table_columns(table)
    if columns is None:
        return None

    alias_map = self.FIELD_ALIASES.get(table, {})
    # Lower-cased name -> actual name, for case-insensitive matching.
    by_lower = {name.lower(): name for name in columns}
    lowered_column = column.lower()
    candidates = (
        alias_map.get(column, column),
        column,
        alias_map.get(lowered_column, lowered_column),
    )
    for candidate in candidates:
        if candidate in columns:
            return candidate
        match = by_lower.get(candidate.lower())
        if match is not None:
            return match
    return None
|
||||
|
||||
Loading…
Reference in New Issue
Block a user