sam 2025-09-28 21:22:15 +08:00
parent 0a2742b869
commit ca7b249c2c
4 changed files with 497 additions and 103 deletions

View File

@@ -36,9 +36,10 @@
 1. **Declare roles via configuration**: add `description` and `data_scope` to `config.json`/`DepartmentSettings`, and have `department_prompt()` assemble the role instructions, so responsibilities are managed through prompts rather than hard-coded.
 2. **Unified data layer**: introduce a `DataBroker` (or similar utility) that wraps common queries; agents and departments request the tables/fields/windows they need as declarative JSON, and the server executes the queries and returns the features.
 3. **Two-phase LLM workflow**: in the first phase the LLM emits structured `data_requests`; the server fetches the data and feeds a summary back, and the second phase produces the final action and explanation, closing the loop.
-4. **Audit and frontend integration**: write role prompts, data requests, and execution summaries into the extra `agent_utils` fields so Streamlit can render the full "role → request → decision" chain.
+4. **Function-based tool calls**: DeepSeek and other OpenAI-compatible models are now wired to the `fetch_data` tool through the function calling interface; the LLM returns field/window requests that follow the schema, the system validates and fetches the data via `DataBroker`, sends it back as the tool result, and the conversation continues to produce the final opinion, avoiding field errors and hand-written JSON validation.
+5. **Audit and frontend integration**: write role prompts, data requests, and execution summaries into the extra `agent_utils` fields so Streamlit can render the full "role → request → decision" chain.
 
-Department LLMs currently trigger follow-up queries by returning `data_requests` in their JSON output: the system validates the fields with `DataBroker`, supplements the most recent data window, and carries the query results into the next round of prompting, forming a "request → fetch → reconsider" loop.
+Department LLMs now trigger follow-up queries through the `fetch_data` tool exposed via function calling: the model declares the `table.column` fields and windows it needs according to the schema, the system validates them with `DataBroker`, fills in the data and returns it as the tool response, then carries the query results into the next round of prompting, forming the "request → fetch → reconsider" loop.
 
 The above changes can be piloted as a PoC in a single department first and rolled out to all roles once the closed loop is validated.
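
For illustration, one round of this loop with OpenAI-style message shapes looks like the sketch below. The field `daily.close`, the call id, and the returned values are hypothetical; the `fetch_data` argument schema and the `status`/`results`/`notes`/`skipped` response keys follow the implementation in this commit.

```python
# One round of the "request → fetch → reconsider" loop (hypothetical data).
# The assistant asks for data via the fetch_data tool...
assistant_message = {
    "role": "assistant",
    "content": None,
    "tool_calls": [{
        "id": "call_0",
        "type": "function",
        "function": {
            "name": "fetch_data",
            "arguments": '{"requests": [{"field": "daily.close", "window": 5}]}',
        },
    }],
}
# ...and the system replies with a role="tool" message whose content is the
# DataBroker result, after which the conversation continues to a final answer.
tool_message = {
    "role": "tool",
    "tool_call_id": "call_0",
    "name": "fetch_data",
    "content": '{"status": "ok", "round": 1, "results": [{"field": "daily.close", '
               '"window": 5, "values": []}], "notes": [], "skipped": []}',
}
```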

View File

@@ -6,11 +6,11 @@ from dataclasses import dataclass, field
 from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple
 
 from app.agents.base import AgentAction
-from app.llm.client import run_llm_with_config
+from app.llm.client import call_endpoint_with_messages, run_llm_with_config, LLMError
 from app.llm.prompts import department_prompt
 from app.utils.config import AppConfig, DepartmentSettings, LLMConfig
 from app.utils.logging import get_logger
-from app.utils.data_access import DataBroker, parse_field_path
+from app.utils.data_access import DataBroker
 
 LOGGER = get_logger(__name__)
 LOG_EXTRA = {"stage": "department"}
@@ -85,74 +85,95 @@ class DepartmentAgent:
             "你是一个多智能体量化投研系统中的分部决策者,需要根据提供的结构化信息给出买卖意见。"
         )
         llm_cfg = self._get_llm_config()
-        supplement_chunks: List[str] = []
+        if llm_cfg.strategy not in (None, "", "single") or llm_cfg.ensemble:
+            LOGGER.warning(
+                "部门 %s 当前配置不支持函数调用模式,回退至传统提示",
+                self.settings.code,
+                extra=LOG_EXTRA,
+            )
+            return self._analyze_legacy(mutable_context, system_prompt)
+        tools = self._build_tools()
+        messages: List[Dict[str, object]] = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append(
+            {
+                "role": "user",
+                "content": department_prompt(self.settings, mutable_context),
+            }
+        )
         transcript: List[str] = []
-        delivered_requests = {
+        delivered_requests: set[Tuple[str, int]] = {
             (field, 1)
             for field in (mutable_context.raw.get("scope_values") or {}).keys()
         }
-        response = ""
-        decision_data: Dict[str, Any] = {}
+        primary_endpoint = llm_cfg.primary
+        final_message: Optional[Dict[str, Any]] = None
         for round_idx in range(self._max_rounds):
-            supplement_text = "\n\n".join(chunk for chunk in supplement_chunks if chunk)
-            prompt = department_prompt(self.settings, mutable_context, supplements=supplement_text)
             try:
-                response = run_llm_with_config(llm_cfg, prompt, system=system_prompt)
-            except Exception as exc:  # noqa: BLE001
-                LOGGER.exception(
-                    "部门 %s 调用 LLM 失败:%s",
+                response = call_endpoint_with_messages(
+                    primary_endpoint,
+                    messages,
+                    tools=tools,
+                    tool_choice="auto",
+                )
+            except LLMError as exc:
+                LOGGER.warning(
+                    "部门 %s 函数调用失败,回退传统提示:%s",
                     self.settings.code,
                     exc,
                     extra=LOG_EXTRA,
                 )
-                return DepartmentDecision(
-                    department=self.settings.code,
-                    action=AgentAction.HOLD,
-                    confidence=0.0,
-                    summary=f"LLM 调用失败:{exc}",
-                    raw_response=str(exc),
-                )
+                return self._analyze_legacy(mutable_context, system_prompt)
 
-            transcript.append(response)
-            decision_data = _parse_department_response(response)
-            data_requests = _parse_data_requests(decision_data)
-            filtered_requests = [
-                req
-                for req in data_requests
-                if (req.field, req.window) not in delivered_requests
-            ]
-            if filtered_requests and round_idx < self._max_rounds - 1:
-                lines, payload, delivered = self._fulfill_data_requests(
-                    mutable_context, filtered_requests
-                )
-                if payload:
-                    supplement_chunks.append(
-                        f"回合 {round_idx + 1} 追加数据:\n" + "\n".join(lines)
-                    )
-                    mutable_context.raw.setdefault("supplement_data", []).extend(payload)
-                    mutable_context.raw.setdefault("supplement_rounds", []).append(
-                        {
-                            "round": round_idx + 1,
-                            "requests": [req.__dict__ for req in filtered_requests],
-                            "data": payload,
-                        }
-                    )
-                    delivered_requests.update(delivered)
-                    decision_data.pop("data_requests", None)
-                    continue
-                LOGGER.debug(
-                    "部门 %s 数据请求无结果:%s",
-                    self.settings.code,
-                    filtered_requests,
-                    extra=LOG_EXTRA,
-                )
-            decision_data.pop("data_requests", None)
+            choice = (response.get("choices") or [{}])[0]
+            message = choice.get("message", {})
+            transcript.append(_message_to_text(message))
+
+            tool_calls = message.get("tool_calls") or []
+            if tool_calls:
+                # Echo the assistant turn carrying tool_calls back into the
+                # conversation before the role="tool" replies, as the OpenAI
+                # chat format requires.
+                messages.append(message)
+                for call in tool_calls:
+                    tool_response, delivered = self._handle_tool_call(
+                        mutable_context,
+                        call,
+                        delivered_requests,
+                        round_idx,
+                    )
+                    transcript.append(
+                        json.dumps({"tool_response": tool_response}, ensure_ascii=False)
+                    )
+                    messages.append(
+                        {
+                            "role": "tool",
+                            "tool_call_id": call.get("id"),
+                            "name": call.get("function", {}).get("name"),
+                            "content": json.dumps(tool_response, ensure_ascii=False),
+                        }
+                    )
+                    delivered_requests.update(delivered)
+                continue
+
+            final_message = message
             break
+        if final_message is None:
+            LOGGER.warning(
+                "部门 %s 函数调用达到轮次上限仍未返回文本,使用最后一次消息",
+                self.settings.code,
+                extra=LOG_EXTRA,
+            )
+            final_message = message
 
         mutable_context.raw["supplement_transcript"] = list(transcript)
+        content_text = _extract_message_content(final_message)
+        decision_data = _parse_department_response(content_text)
         action = _normalize_action(decision_data.get("action"))
         confidence = _clamp_float(decision_data.get("confidence"), default=0.5)
         summary = decision_data.get("summary") or decision_data.get("reason") or ""
@@ -170,7 +191,7 @@ class DepartmentAgent:
             summary=summary or "未提供摘要",
             signals=[str(sig) for sig in signals if sig],
             risks=[str(risk) for risk in risks if risk],
-            raw_response="\n\n".join(transcript) if transcript else response,
+            raw_response=content_text,
             supplements=list(mutable_context.raw.get("supplement_data", [])),
             dialogue=list(transcript),
         )
@@ -208,16 +229,16 @@ class DepartmentAgent:
             field = req.field.strip()
             if not field:
                 continue
+            resolved = self._broker.resolve_field(field)
+            if not resolved:
+                lines.append(f"- {field}: 字段不存在或不可用")
+                continue
             if req.window <= 1:
                 if field not in latest_fields:
                     latest_fields.append(field)
                 delivered.add((field, 1))
                 continue
-            parsed = parse_field_path(field)
-            if not parsed:
-                lines.append(f"- {field}: 字段不合法,已忽略")
-                continue
-            series_requests.append((req, parsed))
+            series_requests.append((req, resolved))
             delivered.add((field, req.window))
 
         if latest_fields:
@@ -251,11 +272,186 @@ class DepartmentAgent:
                 lines.append(
                     f"- {req.field} (window={req.window}): (数据缺失)"
                 )
-            payload.append({"field": req.field, "window": req.window, "values": series})
+            payload.append(
+                {
+                    "field": req.field,
+                    "window": req.window,
+                    "values": [
+                        {"trade_date": dt, "value": val}
+                        for dt, val in series
+                    ],
+                }
+            )
         return lines, payload, delivered
 
+    def _handle_tool_call(
+        self,
+        context: DepartmentContext,
+        call: Mapping[str, Any],
+        delivered_requests: set[Tuple[str, int]],
+        round_idx: int,
+    ) -> Tuple[Dict[str, Any], set[Tuple[str, int]]]:
+        function_block = call.get("function") or {}
+        name = function_block.get("name") or ""
+        if name != "fetch_data":
+            LOGGER.warning(
+                "部门 %s 收到未知工具调用:%s",
+                self.settings.code,
+                name,
+                extra=LOG_EXTRA,
+            )
+            return {
+                "status": "error",
+                "message": f"未知工具 {name}",
+            }, set()
+        args = _parse_tool_arguments(function_block.get("arguments"))
+        raw_requests = args.get("requests") or []
+        requests: List[DataRequest] = []
+        skipped: List[str] = []
+        for item in raw_requests:
+            field = str(item.get("field", "")).strip()
+            if not field:
+                continue
+            try:
+                window = int(item.get("window", 1))
+            except (TypeError, ValueError):
+                window = 1
+            window = max(1, min(window, getattr(self._broker, "MAX_WINDOW", 120)))
+            key = (field, window)
+            if key in delivered_requests:
+                skipped.append(field)
+                continue
+            requests.append(DataRequest(field=field, window=window))
+        if not requests:
+            return {
+                "status": "ok",
+                "round": round_idx + 1,
+                "results": [],
+                "skipped": skipped,
+            }, set()
+        lines, payload, delivered = self._fulfill_data_requests(context, requests)
+        if payload:
+            context.raw.setdefault("supplement_data", []).extend(payload)
+            context.raw.setdefault("supplement_rounds", []).append(
+                {
+                    "round": round_idx + 1,
+                    "requests": [req.__dict__ for req in requests],
+                    "data": payload,
+                    "notes": lines,
+                }
+            )
+        if lines:
+            context.raw.setdefault("supplement_notes", []).append(
+                {
+                    "round": round_idx + 1,
+                    "lines": lines,
+                }
+            )
+        response_payload = {
+            "status": "ok",
+            "round": round_idx + 1,
+            "results": payload,
+            "notes": lines,
+            "skipped": skipped,
+        }
+        return response_payload, delivered
+
+    def _build_tools(self) -> List[Dict[str, Any]]:
+        max_window = getattr(self._broker, "MAX_WINDOW", 120)
+        return [
+            {
+                "type": "function",
+                "function": {
+                    "name": "fetch_data",
+                    "description": (
+                        "根据字段请求数据库中的最新值或时间序列。支持 table.column 格式的字段,"
+                        "window 表示希望返回的最近数据点数量。"
+                    ),
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "requests": {
+                                "type": "array",
+                                "items": {
+                                    "type": "object",
+                                    "properties": {
+                                        "field": {
+                                            "type": "string",
+                                            "description": "数据字段,格式为 table.column",
+                                        },
+                                        "window": {
+                                            "type": "integer",
+                                            "minimum": 1,
+                                            "maximum": max_window,
+                                            "description": "返回最近多少个数据点,默认为 1",
+                                        },
+                                    },
+                                    "required": ["field"],
+                                },
+                                "minItems": 1,
+                            }
+                        },
+                        "required": ["requests"],
+                    },
+                },
+            }
+        ]
+
+    def _analyze_legacy(
+        self,
+        context: DepartmentContext,
+        system_prompt: str,
+    ) -> DepartmentDecision:
+        prompt = department_prompt(self.settings, context)
+        llm_cfg = self._get_llm_config()
+        try:
+            response = run_llm_with_config(llm_cfg, prompt, system=system_prompt)
+        except Exception as exc:  # noqa: BLE001
+            LOGGER.exception(
+                "部门 %s 调用 LLM 失败:%s",
+                self.settings.code,
+                exc,
+                extra=LOG_EXTRA,
+            )
+            return DepartmentDecision(
+                department=self.settings.code,
+                action=AgentAction.HOLD,
+                confidence=0.0,
+                summary=f"LLM 调用失败:{exc}",
+                raw_response=str(exc),
+            )
+        context.raw["supplement_transcript"] = [response]
+        decision_data = _parse_department_response(response)
+        action = _normalize_action(decision_data.get("action"))
+        confidence = _clamp_float(decision_data.get("confidence"), default=0.5)
+        summary = decision_data.get("summary") or decision_data.get("reason") or ""
+        signals = decision_data.get("signals") or decision_data.get("rationale") or []
+        if isinstance(signals, str):
+            signals = [signals]
+        risks = decision_data.get("risks") or decision_data.get("warnings") or []
+        if isinstance(risks, str):
+            risks = [risks]
+        decision = DepartmentDecision(
+            department=self.settings.code,
+            action=action,
+            confidence=confidence,
+            summary=summary or "未提供摘要",
+            signals=[str(sig) for sig in signals if sig],
+            risks=[str(risk) for risk in risks if risk],
+            raw_response=response,
+            supplements=list(context.raw.get("supplement_data", [])),
+            dialogue=[response],
+        )
+        return decision
+
 
 def _ensure_mutable_context(context: DepartmentContext) -> DepartmentContext:
     if not isinstance(context.features, dict):
         context.features = dict(context.features or {})
@@ -302,6 +498,54 @@ def _parse_data_requests(payload: Mapping[str, Any]) -> List[DataRequest]:
     return requests
 
 
+def _parse_tool_arguments(payload: Any) -> Dict[str, Any]:
+    if isinstance(payload, dict):
+        return dict(payload)
+    if isinstance(payload, str):
+        try:
+            data = json.loads(payload)
+        except json.JSONDecodeError:
+            LOGGER.debug("工具参数解析失败:%s", payload, extra=LOG_EXTRA)
+            return {}
+        if isinstance(data, dict):
+            return data
+    return {}
+
+
+def _message_to_text(message: Mapping[str, Any]) -> str:
+    content = message.get("content")
+    if isinstance(content, list):
+        parts: List[str] = []
+        for item in content:
+            if isinstance(item, Mapping) and "text" in item:
+                parts.append(str(item.get("text", "")))
+            else:
+                parts.append(str(item))
+        if parts:
+            return "".join(parts)
+    elif isinstance(content, str) and content.strip():
+        return content
+    tool_calls = message.get("tool_calls")
+    if tool_calls:
+        return json.dumps({"tool_calls": tool_calls}, ensure_ascii=False)
+    return ""
+
+
+def _extract_message_content(message: Mapping[str, Any]) -> str:
+    content = message.get("content")
+    if isinstance(content, list):
+        texts = [
+            str(item.get("text", ""))
+            for item in content
+            if isinstance(item, Mapping) and "text" in item
+        ]
+        if texts:
+            return "".join(texts)
+    if isinstance(content, str):
+        return content
+    return json.dumps(message, ensure_ascii=False)
+
+
 class DepartmentManager:
     """Orchestrates all departments defined in configuration."""

View File

@@ -11,6 +11,8 @@ import requests
 from app.utils.config import (
     DEFAULT_LLM_BASE_URLS,
     DEFAULT_LLM_MODELS,
+    DEFAULT_LLM_TEMPERATURES,
+    DEFAULT_LLM_TIMEOUTS,
     LLMConfig,
     LLMEndpoint,
     get_config,
@@ -34,8 +36,8 @@ def _default_model(provider: str) -> str:
     return DEFAULT_LLM_MODELS.get(provider, DEFAULT_LLM_MODELS["ollama"])
 
 
-def _build_messages(prompt: str, system: Optional[str] = None) -> List[Dict[str, str]]:
-    messages: List[Dict[str, str]] = []
+def _build_messages(prompt: str, system: Optional[str] = None) -> List[Dict[str, object]]:
+    messages: List[Dict[str, object]] = []
     if system:
         messages.append({"role": "system", "content": system})
     messages.append({"role": "user", "content": prompt})
@@ -70,38 +72,39 @@ def _request_ollama(
     return str(content)
 
 
-def _request_openai(
-    model: str,
-    prompt: str,
+def _request_openai_chat(
     *,
     base_url: str,
     api_key: str,
+    model: str,
+    messages: List[Dict[str, object]],
     temperature: float,
     timeout: float,
-    system: Optional[str],
-) -> str:
+    tools: Optional[List[Dict[str, object]]] = None,
+    tool_choice: Optional[object] = None,
+) -> Dict[str, object]:
     url = f"{base_url.rstrip('/')}/v1/chat/completions"
     headers = {
         "Authorization": f"Bearer {api_key}",
         "Content-Type": "application/json",
     }
-    payload = {
+    payload: Dict[str, object] = {
        "model": model,
-        "messages": _build_messages(prompt, system),
+        "messages": messages,
        "temperature": temperature,
     }
+    if tools:
+        payload["tools"] = tools
+    if tool_choice is not None:
+        payload["tool_choice"] = tool_choice
     LOGGER.debug("调用 OpenAI 兼容接口: %s %s", model, url, extra=LOG_EXTRA)
     response = requests.post(url, headers=headers, json=payload, timeout=timeout)
     if response.status_code != 200:
         raise LLMError(f"OpenAI API 调用失败: {response.status_code} {response.text}")
-    data = response.json()
-    try:
-        return data["choices"][0]["message"]["content"].strip()
-    except (KeyError, IndexError) as exc:
-        raise LLMError(f"OpenAI 响应解析失败: {json.dumps(data, ensure_ascii=False)}") from exc
+    return response.json()
 
 
-def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> str:
+def _prepare_endpoint(endpoint: LLMEndpoint) -> Dict[str, object]:
     cfg = get_config()
     provider_key = (endpoint.provider or "ollama").lower()
     provider_cfg = cfg.llm_providers.get(provider_key)
@@ -128,14 +131,30 @@ def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) ->
     else:
         base_url = base_url or _default_base_url(provider_key)
         model = model or _default_model(provider_key)
     if temperature is None:
         temperature = DEFAULT_LLM_TEMPERATURES.get(provider_key, 0.2)
     if timeout is None:
         timeout = DEFAULT_LLM_TIMEOUTS.get(provider_key, 30.0)
     mode = "ollama" if provider_key == "ollama" else "openai"
-    temperature = max(0.0, min(float(temperature), 2.0))
-    timeout = max(5.0, float(timeout))
+    return {
+        "provider_key": provider_key,
+        "mode": mode,
+        "base_url": base_url,
+        "api_key": api_key,
+        "model": model,
+        "temperature": max(0.0, min(float(temperature), 2.0)),
+        "timeout": max(5.0, float(timeout)),
+        "prompt_template": prompt_template,
+    }
+
+
+def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> str:
+    resolved = _prepare_endpoint(endpoint)
+    provider_key = resolved["provider_key"]
+    mode = resolved["mode"]
+    prompt_template = resolved["prompt_template"]
 
     if prompt_template:
         try:
@@ -143,6 +162,40 @@ def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) ->
         except Exception:  # noqa: BLE001
             LOGGER.warning("Prompt 模板格式化失败,使用原始 prompt", extra=LOG_EXTRA)
 
+    messages = _build_messages(prompt, system)
+    response = call_endpoint_with_messages(
+        endpoint,
+        messages,
+        tools=None,
+    )
+    if mode == "ollama":
+        message = response.get("message") or {}
+        content = message.get("content", "")
+        if isinstance(content, list):
+            return "".join(chunk.get("text", "") or chunk.get("content", "") for chunk in content)
+        return str(content)
+    try:
+        return response["choices"][0]["message"]["content"].strip()
+    except (KeyError, IndexError) as exc:
+        raise LLMError(f"OpenAI 响应解析失败: {json.dumps(response, ensure_ascii=False)}") from exc
+
+
+def call_endpoint_with_messages(
+    endpoint: LLMEndpoint,
+    messages: List[Dict[str, object]],
+    *,
+    tools: Optional[List[Dict[str, object]]] = None,
+    tool_choice: Optional[object] = None,
+) -> Dict[str, object]:
+    resolved = _prepare_endpoint(endpoint)
+    provider_key = resolved["provider_key"]
+    mode = resolved["mode"]
+    base_url = resolved["base_url"]
+    model = resolved["model"]
+    temperature = resolved["temperature"]
+    timeout = resolved["timeout"]
+    api_key = resolved["api_key"]
     LOGGER.info(
         "触发 LLM 请求:provider=%s model=%s base=%s",
         provider_key,
@@ -151,28 +204,36 @@ def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) ->
         extra=LOG_EXTRA,
     )
-    if mode != "ollama":
-        if not api_key:
-            raise LLMError(f"缺少 {provider_key} API Key (model={model})")
-        return _request_openai(
-            model,
-            prompt,
-            base_url=base_url,
-            api_key=api_key,
-            temperature=temperature,
-            timeout=timeout,
-            system=system,
-        )
-    if base_url:
-        return _request_ollama(
-            model,
-            prompt,
-            base_url=base_url,
-            temperature=temperature,
-            timeout=timeout,
-            system=system,
-        )
-    raise LLMError(f"不支持的 LLM provider: {endpoint.provider}")
+    if mode == "ollama":
+        if tools:
+            raise LLMError("当前 provider 不支持函数调用/工具模式")
+        payload = {
+            "model": model,
+            "messages": messages,
+            "stream": False,
+            "options": {"temperature": temperature},
+        }
+        response = requests.post(
+            f"{base_url.rstrip('/')}/api/chat",
+            json=payload,
+            timeout=timeout,
+        )
+        if response.status_code != 200:
+            raise LLMError(f"Ollama 调用失败: {response.status_code} {response.text}")
+        return response.json()
+
+    if not api_key:
+        raise LLMError(f"缺少 {provider_key} API Key (model={model})")
+    return _request_openai_chat(
+        base_url=base_url,
+        api_key=api_key,
+        model=model,
+        messages=messages,
+        temperature=temperature,
+        timeout=timeout,
+        tools=tools,
+        tool_choice=tool_choice,
+    )
 
 
 def _normalize_response(text: str) -> str:
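
A usage sketch for the new entry point. The `call_endpoint_with_messages` signature matches this commit, but the `LLMEndpoint` constructor arguments shown are assumptions about the config model, not verified keys:

```python
from app.llm.client import call_endpoint_with_messages, LLMError
from app.utils.config import LLMEndpoint

# Hypothetical endpoint values; real field names live in config.json.
endpoint = LLMEndpoint(provider="deepseek", model="deepseek-chat")
messages = [{"role": "user", "content": "请给出对 600000.SH 的买卖意见"}]
tools = [{"type": "function", "function": {"name": "fetch_data",
                                           "parameters": {"type": "object"}}}]

try:
    response = call_endpoint_with_messages(endpoint, messages,
                                           tools=tools, tool_choice="auto")
    tool_calls = response["choices"][0]["message"].get("tool_calls") or []
except LLMError:
    tool_calls = []  # fall back to the plain-prompt path, as DepartmentAgent does
```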

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
 
 import re
 from dataclasses import dataclass
-from typing import Dict, Iterable, List, Sequence, Tuple
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple
 
 from .db import db_session
 from .logging import get_logger
@@ -42,6 +42,29 @@ def parse_field_path(path: str) -> Tuple[str, str] | None:
 class DataBroker:
     """Lightweight data access helper for agent/LLM consumption."""
 
+    FIELD_ALIASES: Dict[str, Dict[str, str]] = {
+        "daily": {
+            "volume": "vol",
+            "vol": "vol",
+            "turnover": "amount",
+        },
+        "daily_basic": {
+            "turnover": "turnover_rate",
+            "turnover_rate": "turnover_rate",
+            "turnover_rate_f": "turnover_rate_f",
+            "volume_ratio": "volume_ratio",
+            "pe": "pe",
+            "pb": "pb",
+            "ps": "ps",
+            "ps_ttm": "ps_ttm",
+        },
+        "stk_limit": {
+            "up": "up_limit",
+            "down": "down_limit",
+        },
+    }
+    MAX_WINDOW: int = 120
+
     def fetch_latest(
         self,
         ts_code: str,
@@ -51,16 +74,18 @@ class DataBroker:
         """Fetch the latest value (<= trade_date) for each requested field."""
         grouped: Dict[str, List[str]] = {}
+        field_map: Dict[Tuple[str, str], List[str]] = {}
         for item in fields:
             if not item:
                 continue
-            normalized = _safe_split(str(item))
-            if not normalized:
+            resolved = self.resolve_field(str(item))
+            if not resolved:
                 continue
-            table, column = normalized
+            table, column = resolved
             grouped.setdefault(table, [])
             if column not in grouped[table]:
                 grouped[table].append(column)
+            field_map.setdefault((table, column), []).append(str(item))
         if not grouped:
             return {}
@@ -91,8 +116,8 @@ class DataBroker:
                 value = row[column]
                 if value is None:
                     continue
-                key = f"{table}.{column}"
-                results[key] = float(value)
+                for original in field_map.get((table, column), [f"{table}.{column}"]):
+                    results[original] = float(value)
         return results
 
     def fetch_series(
@@ -107,10 +132,19 @@ class DataBroker:
         if window <= 0:
             return []
-        if not (_is_safe_identifier(table) and _is_safe_identifier(column)):
+        window = min(window, self.MAX_WINDOW)
+        resolved_field = self.resolve_field(f"{table}.{column}")
+        if not resolved_field:
+            LOGGER.debug(
+                "时间序列字段不存在 table=%s column=%s",
+                table,
+                column,
+                extra=LOG_EXTRA,
+            )
             return []
+        table, resolved = resolved_field
         query = (
-            f"SELECT trade_date, {column} FROM {table} "
+            f"SELECT trade_date, {resolved} FROM {table} "
             "WHERE ts_code = ? AND trade_date <= ? "
             "ORDER BY trade_date DESC LIMIT ?"
         )
@@ -128,7 +162,7 @@ class DataBroker:
             return []
         series: List[Tuple[str, float]] = []
         for row in rows:
-            value = row[column]
+            value = row[resolved]
             if value is None:
                 continue
             series.append((row["trade_date"], float(value)))
@@ -163,3 +197,57 @@ class DataBroker:
             )
             return False
         return row is not None
+
+    def resolve_field(self, field: str) -> Optional[Tuple[str, str]]:
+        normalized = _safe_split(field)
+        if not normalized:
+            return None
+        table, column = normalized
+        resolved = self._resolve_column(table, column)
+        if not resolved:
+            LOGGER.debug(
+                "字段不存在 table=%s column=%s",
+                table,
+                column,
+                extra=LOG_EXTRA,
+            )
+            return None
+        return table, resolved
+
+    def _get_table_columns(self, table: str) -> Optional[set[str]]:
+        if not _is_safe_identifier(table):
+            return None
+        cache = getattr(self, "_column_cache", None)
+        if cache is None:
+            cache = {}
+            self._column_cache = cache
+        if table in cache:
+            return cache[table]
+        try:
+            with db_session(read_only=True) as conn:
+                rows = conn.execute(f"PRAGMA table_info({table})").fetchall()
+        except Exception as exc:  # noqa: BLE001
+            LOGGER.debug("获取表字段失败 table=%s err=%s", table, exc, extra=LOG_EXTRA)
+            cache[table] = None
+            return None
+        if not rows:
+            cache[table] = None
+            return None
+        columns = {row["name"] for row in rows if row["name"]}
+        cache[table] = columns
+        return columns
+
+    def _resolve_column(self, table: str, column: str) -> Optional[str]:
+        columns = self._get_table_columns(table)
+        if columns is None:
+            return None
+        alias_map = self.FIELD_ALIASES.get(table, {})
+        candidate = alias_map.get(column, column)
+        if candidate in columns:
+            return candidate
+        # Fall back to a case-insensitive match against the real column names.
+        lowered = candidate.lower()
+        for name in columns:
+            if name.lower() == lowered:
+                return name
+        return None
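
A behavioural sketch of the alias layer, assuming a populated database with the Tushare-style tables named in `FIELD_ALIASES` and the `(ts_code, trade_date, …)` parameter order suggested by the signatures above; the stock code is made up:

```python
from app.utils.data_access import DataBroker

broker = DataBroker()

# "daily.volume" is not a real column; resolve_field consults FIELD_ALIASES
# plus PRAGMA table_info and returns ("daily", "vol") when the table exists,
# or None when it does not.
print(broker.resolve_field("daily.volume"))

# fetch_latest keys results by the *requested* spelling, so a caller asking
# for "daily.volume" gets {"daily.volume": ...} even though column vol was read.
print(broker.fetch_latest("600000.SH", "20250926",
                          ["daily.volume", "daily_basic.pe"]))

# fetch_series clamps window to MAX_WINDOW (120) before querying.
print(broker.fetch_series("600000.SH", "20250926", "daily", "close", window=500))
```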