commit ca7b249c2c
parent 0a2742b869

    update
@@ -36,9 +36,10 @@
 1. **Declare roles via config**: add `description` and `data_scope` to `config.json`/`DepartmentSettings`, and have `department_prompt()` splice the role instructions together, so responsibilities are managed as prompts rather than hard-coded.
 2. **Unified data layer**: add a `DataBroker` (or similar utility) wrapping common queries; agents and departments request the tables/fields/windows they need as declarative JSON, and the server executes the queries and returns the features.
 3. **Two-stage LLM workflow**: in stage one the LLM emits structured `data_requests`; the server fetches the data and writes a summary back, and stage two generates the final action and explanation, closing the loop.
-4. **Audit and frontend linkage**: write the role prompt, data requests, and execution summaries into extra `agent_utils` fields so Streamlit can render the full "role → request → decision" chain.
+4. **Function-style tool calls**: OpenAI-compatible models such as DeepSeek now attach to the `fetch_data` tool through the function calling interface; the LLM returns field/window requests following the schema, the system validates and backfills the data with `DataBroker` and sends the tool result back, and the conversation then continues to the final opinion, avoiding field errors and hand-rolled JSON validation.
+5. **Audit and frontend linkage**: write the role prompt, data requests, and execution summaries into extra `agent_utils` fields so Streamlit can render the full "role → request → decision" chain.
 
-The department LLMs currently trigger follow-up queries by returning `data_requests` in their JSON: the system validates the fields with `DataBroker`, supplements the most recent data window, and carries the query results into the next round of prompting, forming a "request → fetch → reconsider" loop.
+The department LLMs now trigger follow-up queries through the `fetch_data` tool exposed via function calling: the model declares the `table.column` fields and windows it needs according to the schema, the system validates them with `DataBroker`, backfills the data, and returns it as the tool response, then carries the query results into the next round of prompting, forming a "request → fetch → reconsider" loop.
 
 The above changes can be piloted as a PoC in a single department first, then rolled out to all roles once the closed loop is proven.
 
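To make the loop concrete, here is a minimal sketch of the "request → fetch → reconsider" cycle against an OpenAI-compatible `/v1/chat/completions` endpoint. The tool schema mirrors the `fetch_data` definition added later in this diff; `lookup()` is a stand-in for the real `DataBroker`, and the endpoint, model, and data values are placeholders:

```python
import json

import requests

FETCH_DATA_TOOL = {
    "type": "function",
    "function": {
        "name": "fetch_data",
        "description": "Fetch the latest value or a recent series for table.column fields.",
        "parameters": {
            "type": "object",
            "properties": {
                "requests": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "field": {"type": "string"},
                            "window": {"type": "integer", "minimum": 1},
                        },
                        "required": ["field"],
                    },
                }
            },
            "required": ["requests"],
        },
    },
}


def lookup(field: str, window: int) -> list:
    """Stand-in for DataBroker: return a fake series for the demo."""
    return [{"trade_date": "20240102", "value": 1.23}][:window]


def run_loop(base_url: str, api_key: str, model: str, messages: list, max_rounds: int = 3) -> str:
    for _ in range(max_rounds):
        resp = requests.post(
            f"{base_url.rstrip('/')}/v1/chat/completions",
            headers={"Authorization": f"Bearer {api_key}"},
            json={
                "model": model,
                "messages": messages,
                "tools": [FETCH_DATA_TOOL],
                "tool_choice": "auto",
            },
            timeout=30,
        ).json()
        message = resp["choices"][0]["message"]
        tool_calls = message.get("tool_calls") or []
        if not tool_calls:
            return message.get("content") or ""  # final opinion, no more data needed
        messages.append(message)  # keep the assistant turn that carries the tool_calls
        for call in tool_calls:
            args = json.loads(call["function"]["arguments"])
            results = [
                {"field": r["field"], "values": lookup(r["field"], int(r.get("window", 1)))}
                for r in args.get("requests", [])
            ]
            messages.append(
                {
                    "role": "tool",
                    "tool_call_id": call["id"],
                    "content": json.dumps({"status": "ok", "results": results}, ensure_ascii=False),
                }
            )
    return ""  # round limit hit without a textual answer
```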
@@ -6,11 +6,11 @@ from dataclasses import dataclass, field
 from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple
 
 from app.agents.base import AgentAction
-from app.llm.client import run_llm_with_config
+from app.llm.client import call_endpoint_with_messages, run_llm_with_config, LLMError
 from app.llm.prompts import department_prompt
 from app.utils.config import AppConfig, DepartmentSettings, LLMConfig
 from app.utils.logging import get_logger
-from app.utils.data_access import DataBroker, parse_field_path
+from app.utils.data_access import DataBroker
 
 LOGGER = get_logger(__name__)
 LOG_EXTRA = {"stage": "department"}
@@ -85,74 +85,95 @@ class DepartmentAgent:
             "你是一个多智能体量化投研系统中的分部决策者,需要根据提供的结构化信息给出买卖意见。"
         )
         llm_cfg = self._get_llm_config()
-        supplement_chunks: List[str] = []
+        if llm_cfg.strategy not in (None, "", "single") or llm_cfg.ensemble:
+            LOGGER.warning(
+                "部门 %s 当前配置不支持函数调用模式,回退至传统提示",
+                self.settings.code,
+                extra=LOG_EXTRA,
+            )
+            return self._analyze_legacy(mutable_context, system_prompt)
+
+        tools = self._build_tools()
+        messages: List[Dict[str, object]] = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append(
+            {
+                "role": "user",
+                "content": department_prompt(self.settings, mutable_context),
+            }
+        )
 
         transcript: List[str] = []
-        delivered_requests = {
+        delivered_requests: set[Tuple[str, int]] = {
             (field, 1)
             for field in (mutable_context.raw.get("scope_values") or {}).keys()
         }
 
-        response = ""
-        decision_data: Dict[str, Any] = {}
+        primary_endpoint = llm_cfg.primary
+        final_message: Optional[Dict[str, Any]] = None
 
         for round_idx in range(self._max_rounds):
-            supplement_text = "\n\n".join(chunk for chunk in supplement_chunks if chunk)
-            prompt = department_prompt(self.settings, mutable_context, supplements=supplement_text)
             try:
-                response = run_llm_with_config(llm_cfg, prompt, system=system_prompt)
-            except Exception as exc:  # noqa: BLE001
-                LOGGER.exception(
-                    "部门 %s 调用 LLM 失败:%s",
+                response = call_endpoint_with_messages(
+                    primary_endpoint,
+                    messages,
+                    tools=tools,
+                    tool_choice="auto",
+                )
+            except LLMError as exc:
+                LOGGER.warning(
+                    "部门 %s 函数调用失败,回退传统提示:%s",
                     self.settings.code,
                     exc,
                     extra=LOG_EXTRA,
                 )
-                return DepartmentDecision(
-                    department=self.settings.code,
-                    action=AgentAction.HOLD,
-                    confidence=0.0,
-                    summary=f"LLM 调用失败:{exc}",
-                    raw_response=str(exc),
-                )
+                return self._analyze_legacy(mutable_context, system_prompt)
 
-            transcript.append(response)
-            decision_data = _parse_department_response(response)
-            data_requests = _parse_data_requests(decision_data)
-            filtered_requests = [
-                req
-                for req in data_requests
-                if (req.field, req.window) not in delivered_requests
-            ]
+            choice = (response.get("choices") or [{}])[0]
+            message = choice.get("message", {})
+            transcript.append(_message_to_text(message))
 
-            if filtered_requests and round_idx < self._max_rounds - 1:
-                lines, payload, delivered = self._fulfill_data_requests(
-                    mutable_context, filtered_requests
-                )
-                if payload:
-                    supplement_chunks.append(
-                        f"回合 {round_idx + 1} 追加数据:\n" + "\n".join(lines)
-                    )
-                    mutable_context.raw.setdefault("supplement_data", []).extend(payload)
-                    mutable_context.raw.setdefault("supplement_rounds", []).append(
-                        {
-                            "round": round_idx + 1,
-                            "requests": [req.__dict__ for req in filtered_requests],
-                            "data": payload,
-                        }
-                    )
-                    delivered_requests.update(delivered)
-                    decision_data.pop("data_requests", None)
-                    continue
-                LOGGER.debug(
-                    "部门 %s 数据请求无结果:%s",
-                    self.settings.code,
-                    filtered_requests,
-                    extra=LOG_EXTRA,
-                )
-                decision_data.pop("data_requests", None)
+            tool_calls = message.get("tool_calls") or []
+            if tool_calls:
+                for call in tool_calls:
+                    tool_response, delivered = self._handle_tool_call(
+                        mutable_context,
+                        call,
+                        delivered_requests,
+                        round_idx,
+                    )
+                    transcript.append(
+                        json.dumps({"tool_response": tool_response}, ensure_ascii=False)
+                    )
+                    messages.append(
+                        {
+                            "role": "tool",
+                            "tool_call_id": call.get("id"),
+                            "name": call.get("function", {}).get("name"),
+                            "content": json.dumps(tool_response, ensure_ascii=False),
+                        }
+                    )
+                    delivered_requests.update(delivered)
+                continue
+
+            final_message = message
             break
 
+        if final_message is None:
+            LOGGER.warning(
+                "部门 %s 函数调用达到轮次上限仍未返回文本,使用最后一次消息",
+                self.settings.code,
+                extra=LOG_EXTRA,
+            )
+            final_message = message
+
         mutable_context.raw["supplement_transcript"] = list(transcript)
+
+        content_text = _extract_message_content(final_message)
+        decision_data = _parse_department_response(content_text)
+
         action = _normalize_action(decision_data.get("action"))
         confidence = _clamp_float(decision_data.get("confidence"), default=0.5)
         summary = decision_data.get("summary") or decision_data.get("reason") or ""
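Schematically, the `messages` list this loop accumulates looks like the following after one tool round (contents are illustrative, not from a real run):

```python
messages = [
    {"role": "system", "content": "你是一个多智能体量化投研系统中的分部决策者……"},
    {"role": "user", "content": "<department_prompt(self.settings, mutable_context) output>"},
    # Round 1: the model answered with tool_calls instead of text, so a tool
    # message carrying the DataBroker results is appended and the loop continues.
    {
        "role": "tool",
        "tool_call_id": "call_0",
        "name": "fetch_data",
        "content": '{"status": "ok", "round": 1, "results": [...], "notes": [...], "skipped": []}',
    },
    # Round 2: the reply has no tool_calls, so it becomes final_message and the loop breaks.
]
```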
@@ -170,7 +191,7 @@ class DepartmentAgent:
             summary=summary or "未提供摘要",
             signals=[str(sig) for sig in signals if sig],
             risks=[str(risk) for risk in risks if risk],
-            raw_response="\n\n".join(transcript) if transcript else response,
+            raw_response=content_text,
             supplements=list(mutable_context.raw.get("supplement_data", [])),
             dialogue=list(transcript),
         )
@@ -208,16 +229,16 @@ class DepartmentAgent:
             field = req.field.strip()
             if not field:
                 continue
+            resolved = self._broker.resolve_field(field)
+            if not resolved:
+                lines.append(f"- {field}: 字段不存在或不可用")
+                continue
             if req.window <= 1:
                 if field not in latest_fields:
                     latest_fields.append(field)
                 delivered.add((field, 1))
                 continue
-            parsed = parse_field_path(field)
-            if not parsed:
-                lines.append(f"- {field}: 字段不合法,已忽略")
-                continue
-            series_requests.append((req, parsed))
+            series_requests.append((req, resolved))
             delivered.add((field, req.window))
 
         if latest_fields:
@@ -251,11 +272,186 @@ class DepartmentAgent:
                 lines.append(
                     f"- {req.field} (window={req.window}): (数据缺失)"
                 )
-            payload.append({"field": req.field, "window": req.window, "values": series})
+            payload.append(
+                {
+                    "field": req.field,
+                    "window": req.window,
+                    "values": [
+                        {"trade_date": dt, "value": val}
+                        for dt, val in series
+                    ],
+                }
+            )
 
         return lines, payload, delivered
 
+    def _handle_tool_call(
+        self,
+        context: DepartmentContext,
+        call: Mapping[str, Any],
+        delivered_requests: set[Tuple[str, int]],
+        round_idx: int,
+    ) -> Tuple[Dict[str, Any], set[Tuple[str, int]]]:
+        function_block = call.get("function") or {}
+        name = function_block.get("name") or ""
+        if name != "fetch_data":
+            LOGGER.warning(
+                "部门 %s 收到未知工具调用:%s",
+                self.settings.code,
+                name,
+                extra=LOG_EXTRA,
+            )
+            return {
+                "status": "error",
+                "message": f"未知工具 {name}",
+            }, set()
+
+        args = _parse_tool_arguments(function_block.get("arguments"))
+        raw_requests = args.get("requests") or []
+        requests: List[DataRequest] = []
+        skipped: List[str] = []
+        for item in raw_requests:
+            field = str(item.get("field", "")).strip()
+            if not field:
+                continue
+            try:
+                window = int(item.get("window", 1))
+            except (TypeError, ValueError):
+                window = 1
+            window = max(1, min(window, getattr(self._broker, "MAX_WINDOW", 120)))
+            key = (field, window)
+            if key in delivered_requests:
+                skipped.append(field)
+                continue
+            requests.append(DataRequest(field=field, window=window))
+
+        if not requests:
+            return {
+                "status": "ok",
+                "round": round_idx + 1,
+                "results": [],
+                "skipped": skipped,
+            }, set()
+
+        lines, payload, delivered = self._fulfill_data_requests(context, requests)
+        if payload:
+            context.raw.setdefault("supplement_data", []).extend(payload)
+            context.raw.setdefault("supplement_rounds", []).append(
+                {
+                    "round": round_idx + 1,
+                    "requests": [req.__dict__ for req in requests],
+                    "data": payload,
+                    "notes": lines,
+                }
+            )
+        if lines:
+            context.raw.setdefault("supplement_notes", []).append(
+                {
+                    "round": round_idx + 1,
+                    "lines": lines,
+                }
+            )
+
+        response_payload = {
+            "status": "ok",
+            "round": round_idx + 1,
+            "results": payload,
+            "notes": lines,
+            "skipped": skipped,
+        }
+        return response_payload, delivered
+
+    def _build_tools(self) -> List[Dict[str, Any]]:
+        max_window = getattr(self._broker, "MAX_WINDOW", 120)
+        return [
+            {
+                "type": "function",
+                "function": {
+                    "name": "fetch_data",
+                    "description": (
+                        "根据字段请求数据库中的最新值或时间序列。支持 table.column 格式的字段,"
+                        "window 表示希望返回的最近数据点数量。"
+                    ),
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "requests": {
+                                "type": "array",
+                                "items": {
+                                    "type": "object",
+                                    "properties": {
+                                        "field": {
+                                            "type": "string",
+                                            "description": "数据字段,格式为 table.column",
+                                        },
+                                        "window": {
+                                            "type": "integer",
+                                            "minimum": 1,
+                                            "maximum": max_window,
+                                            "description": "返回最近多少个数据点,默认为 1",
+                                        },
+                                    },
+                                    "required": ["field"],
+                                },
+                                "minItems": 1,
+                            }
+                        },
+                        "required": ["requests"],
+                    },
+                },
+            }
+        ]
+
+    def _analyze_legacy(
+        self,
+        context: DepartmentContext,
+        system_prompt: str,
+    ) -> DepartmentDecision:
+        prompt = department_prompt(self.settings, context)
+        llm_cfg = self._get_llm_config()
+        try:
+            response = run_llm_with_config(llm_cfg, prompt, system=system_prompt)
+        except Exception as exc:  # noqa: BLE001
+            LOGGER.exception(
+                "部门 %s 调用 LLM 失败:%s",
+                self.settings.code,
+                exc,
+                extra=LOG_EXTRA,
+            )
+            return DepartmentDecision(
+                department=self.settings.code,
+                action=AgentAction.HOLD,
+                confidence=0.0,
+                summary=f"LLM 调用失败:{exc}",
+                raw_response=str(exc),
+            )
+
+        context.raw["supplement_transcript"] = [response]
+        decision_data = _parse_department_response(response)
+        action = _normalize_action(decision_data.get("action"))
+        confidence = _clamp_float(decision_data.get("confidence"), default=0.5)
+        summary = decision_data.get("summary") or decision_data.get("reason") or ""
+        signals = decision_data.get("signals") or decision_data.get("rationale") or []
+        if isinstance(signals, str):
+            signals = [signals]
+        risks = decision_data.get("risks") or decision_data.get("warnings") or []
+        if isinstance(risks, str):
+            risks = [risks]
+
+        decision = DepartmentDecision(
+            department=self.settings.code,
+            action=action,
+            confidence=confidence,
+            summary=summary or "未提供摘要",
+            signals=[str(sig) for sig in signals if sig],
+            risks=[str(risk) for risk in risks if risk],
+            raw_response=response,
+            supplements=list(context.raw.get("supplement_data", [])),
+            dialogue=[response],
+        )
+        return decision
+
 def _ensure_mutable_context(context: DepartmentContext) -> DepartmentContext:
     if not isinstance(context.features, dict):
         context.features = dict(context.features or {})
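As an example of the shapes involved, here is a `fetch_data` call as an OpenAI-compatible model emits it, and the payload `_handle_tool_call` would build for it (field name and values are illustrative):

```python
# Incoming tool call; "arguments" arrives as a JSON string:
call = {
    "id": "call_0",
    "type": "function",
    "function": {
        "name": "fetch_data",
        "arguments": '{"requests": [{"field": "daily.close", "window": 5}]}',
    },
}

# Payload returned to the model as the tool result (values illustrative):
tool_response = {
    "status": "ok",
    "round": 1,
    "results": [
        {
            "field": "daily.close",
            "window": 5,
            "values": [{"trade_date": "20240102", "value": 10.5}],
        }
    ],
    "notes": [],
    "skipped": [],
}
```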
@@ -302,6 +498,54 @@ def _parse_data_requests(payload: Mapping[str, Any]) -> List[DataRequest]:
     return requests
 
 
+def _parse_tool_arguments(payload: Any) -> Dict[str, Any]:
+    if isinstance(payload, dict):
+        return dict(payload)
+    if isinstance(payload, str):
+        try:
+            data = json.loads(payload)
+        except json.JSONDecodeError:
+            LOGGER.debug("工具参数解析失败:%s", payload, extra=LOG_EXTRA)
+            return {}
+        if isinstance(data, dict):
+            return data
+    return {}
+
+
+def _message_to_text(message: Mapping[str, Any]) -> str:
+    content = message.get("content")
+    if isinstance(content, list):
+        parts: List[str] = []
+        for item in content:
+            if isinstance(item, Mapping) and "text" in item:
+                parts.append(str(item.get("text", "")))
+            else:
+                parts.append(str(item))
+        if parts:
+            return "".join(parts)
+    elif isinstance(content, str) and content.strip():
+        return content
+    tool_calls = message.get("tool_calls")
+    if tool_calls:
+        return json.dumps({"tool_calls": tool_calls}, ensure_ascii=False)
+    return ""
+
+
+def _extract_message_content(message: Mapping[str, Any]) -> str:
+    content = message.get("content")
+    if isinstance(content, list):
+        texts = [
+            str(item.get("text", ""))
+            for item in content
+            if isinstance(item, Mapping) and "text" in item
+        ]
+        if texts:
+            return "".join(texts)
+    if isinstance(content, str):
+        return content
+    return json.dumps(message, ensure_ascii=False)
+
+
 class DepartmentManager:
     """Orchestrates all departments defined in configuration."""
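A quick sanity check of how these helpers behave; the expected values follow from the code above:

```python
assert _parse_tool_arguments('{"requests": [{"field": "daily.close"}]}') == {
    "requests": [{"field": "daily.close"}]
}
assert _parse_tool_arguments("not json") == {}  # logged and swallowed

# With no text content, _message_to_text falls back to serializing tool_calls:
msg = {"content": None, "tool_calls": [{"id": "call_0"}]}
assert _message_to_text(msg) == '{"tool_calls": [{"id": "call_0"}]}'

# _extract_message_content keeps plain strings as-is:
assert _extract_message_content({"content": "BUY"}) == "BUY"
```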
@@ -11,6 +11,8 @@ import requests
 from app.utils.config import (
     DEFAULT_LLM_BASE_URLS,
     DEFAULT_LLM_MODELS,
+    DEFAULT_LLM_TEMPERATURES,
+    DEFAULT_LLM_TIMEOUTS,
     LLMConfig,
     LLMEndpoint,
     get_config,
@@ -34,8 +36,8 @@ def _default_model(provider: str) -> str:
     return DEFAULT_LLM_MODELS.get(provider, DEFAULT_LLM_MODELS["ollama"])
 
 
-def _build_messages(prompt: str, system: Optional[str] = None) -> List[Dict[str, str]]:
-    messages: List[Dict[str, str]] = []
+def _build_messages(prompt: str, system: Optional[str] = None) -> List[Dict[str, object]]:
+    messages: List[Dict[str, object]] = []
     if system:
         messages.append({"role": "system", "content": system})
     messages.append({"role": "user", "content": prompt})
@@ -70,38 +72,39 @@ def _request_ollama(
     return str(content)
 
 
-def _request_openai(
-    model: str,
-    prompt: str,
+def _request_openai_chat(
     *,
     base_url: str,
     api_key: str,
+    model: str,
+    messages: List[Dict[str, object]],
     temperature: float,
     timeout: float,
-    system: Optional[str],
-) -> str:
+    tools: Optional[List[Dict[str, object]]] = None,
+    tool_choice: Optional[object] = None,
+) -> Dict[str, object]:
     url = f"{base_url.rstrip('/')}/v1/chat/completions"
     headers = {
         "Authorization": f"Bearer {api_key}",
         "Content-Type": "application/json",
     }
-    payload = {
+    payload: Dict[str, object] = {
         "model": model,
-        "messages": _build_messages(prompt, system),
+        "messages": messages,
         "temperature": temperature,
     }
+    if tools:
+        payload["tools"] = tools
+    if tool_choice is not None:
+        payload["tool_choice"] = tool_choice
     LOGGER.debug("调用 OpenAI 兼容接口: %s %s", model, url, extra=LOG_EXTRA)
     response = requests.post(url, headers=headers, json=payload, timeout=timeout)
     if response.status_code != 200:
         raise LLMError(f"OpenAI API 调用失败: {response.status_code} {response.text}")
-    data = response.json()
-    try:
-        return data["choices"][0]["message"]["content"].strip()
-    except (KeyError, IndexError) as exc:
-        raise LLMError(f"OpenAI 响应解析失败: {json.dumps(data, ensure_ascii=False)}") from exc
+    return response.json()
 
 
-def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> str:
+def _prepare_endpoint(endpoint: LLMEndpoint) -> Dict[str, object]:
     cfg = get_config()
     provider_key = (endpoint.provider or "ollama").lower()
     provider_cfg = cfg.llm_providers.get(provider_key)
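For reference, the body this function POSTs when tools are supplied looks roughly like the following (model name is illustrative, schema elided):

```python
payload = {
    "model": "deepseek-chat",  # illustrative
    "messages": [{"role": "user", "content": "..."}],
    "temperature": 0.2,
    "tools": [{"type": "function", "function": {"name": "fetch_data"}}],  # schema elided
    "tool_choice": "auto",
}
```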
@@ -128,14 +131,30 @@ def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> str:
     else:
         base_url = base_url or _default_base_url(provider_key)
         model = model or _default_model(provider_key)
 
     if temperature is None:
         temperature = DEFAULT_LLM_TEMPERATURES.get(provider_key, 0.2)
     if timeout is None:
         timeout = DEFAULT_LLM_TIMEOUTS.get(provider_key, 30.0)
     mode = "ollama" if provider_key == "ollama" else "openai"
 
-    temperature = max(0.0, min(float(temperature), 2.0))
-    timeout = max(5.0, float(timeout))
+    return {
+        "provider_key": provider_key,
+        "mode": mode,
+        "base_url": base_url,
+        "api_key": api_key,
+        "model": model,
+        "temperature": max(0.0, min(float(temperature), 2.0)),
+        "timeout": max(5.0, float(timeout)),
+        "prompt_template": prompt_template,
+    }
+
+
+def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> str:
+    resolved = _prepare_endpoint(endpoint)
+    provider_key = resolved["provider_key"]
+    mode = resolved["mode"]
+    prompt_template = resolved["prompt_template"]
+
     if prompt_template:
         try:
@@ -143,6 +162,40 @@ def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> str:
         except Exception:  # noqa: BLE001
             LOGGER.warning("Prompt 模板格式化失败,使用原始 prompt", extra=LOG_EXTRA)
 
+    messages = _build_messages(prompt, system)
+    response = call_endpoint_with_messages(
+        endpoint,
+        messages,
+        tools=None,
+    )
+    if mode == "ollama":
+        message = response.get("message") or {}
+        content = message.get("content", "")
+        if isinstance(content, list):
+            return "".join(chunk.get("text", "") or chunk.get("content", "") for chunk in content)
+        return str(content)
+    try:
+        return response["choices"][0]["message"]["content"].strip()
+    except (KeyError, IndexError) as exc:
+        raise LLMError(f"OpenAI 响应解析失败: {json.dumps(response, ensure_ascii=False)}") from exc
+
+
+def call_endpoint_with_messages(
+    endpoint: LLMEndpoint,
+    messages: List[Dict[str, object]],
+    *,
+    tools: Optional[List[Dict[str, object]]] = None,
+    tool_choice: Optional[object] = None,
+) -> Dict[str, object]:
+    resolved = _prepare_endpoint(endpoint)
+    provider_key = resolved["provider_key"]
+    mode = resolved["mode"]
+    base_url = resolved["base_url"]
+    model = resolved["model"]
+    temperature = resolved["temperature"]
+    timeout = resolved["timeout"]
+    api_key = resolved["api_key"]
+
     LOGGER.info(
         "触发 LLM 请求:provider=%s model=%s base=%s",
         provider_key,
@@ -151,28 +204,36 @@ def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> str:
         extra=LOG_EXTRA,
     )
 
-    if mode != "ollama":
-        if not api_key:
-            raise LLMError(f"缺少 {provider_key} API Key (model={model})")
-        return _request_openai(
-            model,
-            prompt,
-            base_url=base_url,
-            api_key=api_key,
-            temperature=temperature,
-            timeout=timeout,
-            system=system,
-        )
-    if base_url:
-        return _request_ollama(
-            model,
-            prompt,
-            base_url=base_url,
-            temperature=temperature,
-            timeout=timeout,
-            system=system,
-        )
-    raise LLMError(f"不支持的 LLM provider: {endpoint.provider}")
+    if mode == "ollama":
+        if tools:
+            raise LLMError("当前 provider 不支持函数调用/工具模式")
+        payload = {
+            "model": model,
+            "messages": messages,
+            "stream": False,
+            "options": {"temperature": temperature},
+        }
+        response = requests.post(
+            f"{base_url.rstrip('/')}/api/chat",
+            json=payload,
+            timeout=timeout,
+        )
+        if response.status_code != 200:
+            raise LLMError(f"Ollama 调用失败: {response.status_code} {response.text}")
+        return response.json()
+
+    if not api_key:
+        raise LLMError(f"缺少 {provider_key} API Key (model={model})")
+    return _request_openai_chat(
+        base_url=base_url,
+        api_key=api_key,
+        model=model,
+        messages=messages,
+        temperature=temperature,
+        timeout=timeout,
+        tools=tools,
+        tool_choice=tool_choice,
+    )
 
 
 def _normalize_response(text: str) -> str:
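A usage sketch of the new entry point; the `LLMEndpoint` construction shown is an assumption about the config shape, not something this diff confirms:

```python
endpoint = LLMEndpoint(provider="deepseek", model="deepseek-chat")  # fields assumed

messages = [
    {"role": "system", "content": "你是分部决策者。"},
    {"role": "user", "content": "给出买卖意见。"},
]

# Plain chat; pass the fetch_data schema via tools= to enable tool calls.
response = call_endpoint_with_messages(endpoint, messages, tools=None)

# OpenAI-compatible providers return the raw chat.completions payload:
text = response["choices"][0]["message"]["content"]
```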
@@ -3,7 +3,7 @@ from __future__ import annotations
 
 import re
 from dataclasses import dataclass
-from typing import Dict, Iterable, List, Sequence, Tuple
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple
 
 from .db import db_session
 from .logging import get_logger
@@ -42,6 +42,29 @@ def parse_field_path(path: str) -> Tuple[str, str] | None:
 class DataBroker:
     """Lightweight data access helper for agent/LLM consumption."""
 
+    FIELD_ALIASES: Dict[str, Dict[str, str]] = {
+        "daily": {
+            "volume": "vol",
+            "vol": "vol",
+            "turnover": "amount",
+        },
+        "daily_basic": {
+            "turnover": "turnover_rate",
+            "turnover_rate": "turnover_rate",
+            "turnover_rate_f": "turnover_rate_f",
+            "volume_ratio": "volume_ratio",
+            "pe": "pe",
+            "pb": "pb",
+            "ps": "ps",
+            "ps_ttm": "ps_ttm",
+        },
+        "stk_limit": {
+            "up": "up_limit",
+            "down": "down_limit",
+        },
+    }
+    MAX_WINDOW: int = 120
+
     def fetch_latest(
         self,
         ts_code: str,
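Given that alias table, field resolution should behave roughly as follows, assuming the listed columns exist in the local schema:

```python
broker = DataBroker()

# "volume" is not a physical column on daily; the alias map rewrites it to "vol":
assert broker.resolve_field("daily.volume") == ("daily", "vol")

# Unaliased names pass through when the column exists:
assert broker.resolve_field("daily_basic.pe") == ("daily_basic", "pe")

# Unknown tables or columns resolve to None instead of raising:
assert broker.resolve_field("daily.not_a_column") is None
```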
@@ -51,16 +74,18 @@ class DataBroker:
         """Fetch the latest value (<= trade_date) for each requested field."""
 
         grouped: Dict[str, List[str]] = {}
+        field_map: Dict[Tuple[str, str], List[str]] = {}
         for item in fields:
             if not item:
                 continue
-            normalized = _safe_split(str(item))
-            if not normalized:
+            resolved = self.resolve_field(str(item))
+            if not resolved:
                 continue
-            table, column = normalized
+            table, column = resolved
             grouped.setdefault(table, [])
             if column not in grouped[table]:
                 grouped[table].append(column)
+            field_map.setdefault((table, column), []).append(str(item))
 
         if not grouped:
             return {}
@@ -91,8 +116,8 @@ class DataBroker:
                 value = row[column]
                 if value is None:
                     continue
-                key = f"{table}.{column}"
-                results[key] = float(value)
+                for original in field_map.get((table, column), [f"{table}.{column}"]):
+                    results[original] = float(value)
         return results
 
     def fetch_series(
@@ -107,10 +132,19 @@ class DataBroker:
 
         if window <= 0:
             return []
-        if not (_is_safe_identifier(table) and _is_safe_identifier(column)):
+        window = min(window, self.MAX_WINDOW)
+        resolved_field = self.resolve_field(f"{table}.{column}")
+        if not resolved_field:
+            LOGGER.debug(
+                "时间序列字段不存在 table=%s column=%s",
+                table,
+                column,
+                extra=LOG_EXTRA,
+            )
             return []
+        table, resolved = resolved_field
         query = (
-            f"SELECT trade_date, {column} FROM {table} "
+            f"SELECT trade_date, {resolved} FROM {table} "
             "WHERE ts_code = ? AND trade_date <= ? "
             "ORDER BY trade_date DESC LIMIT ?"
         )
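So an oversized series request is clamped to `MAX_WINDOW` and alias-resolved before the SQL is built; for instance (parameter order inferred from the query above, not confirmed by the diff):

```python
broker = DataBroker()

# window=500 is clamped to DataBroker.MAX_WINDOW (120) before querying;
# "volume" is resolved to the physical column daily.vol.
series = broker.fetch_series("000001.SZ", "20240105", "daily", "volume", window=500)
# -> at most 120 (trade_date, value) pairs, newest first
```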
@@ -128,7 +162,7 @@ class DataBroker:
             return []
         series: List[Tuple[str, float]] = []
         for row in rows:
-            value = row[column]
+            value = row[resolved]
             if value is None:
                 continue
             series.append((row["trade_date"], float(value)))
@@ -163,3 +197,57 @@ class DataBroker:
             )
             return False
         return row is not None
+
+    def resolve_field(self, field: str) -> Optional[Tuple[str, str]]:
+        normalized = _safe_split(field)
+        if not normalized:
+            return None
+        table, column = normalized
+        resolved = self._resolve_column(table, column)
+        if not resolved:
+            LOGGER.debug(
+                "字段不存在 table=%s column=%s",
+                table,
+                column,
+                extra=LOG_EXTRA,
+            )
+            return None
+        return table, resolved
+
+    def _get_table_columns(self, table: str) -> Optional[set[str]]:
+        if not _is_safe_identifier(table):
+            return None
+        cache = getattr(self, "_column_cache", None)
+        if cache is None:
+            cache = {}
+            self._column_cache = cache
+        if table in cache:
+            return cache[table]
+        try:
+            with db_session(read_only=True) as conn:
+                rows = conn.execute(f"PRAGMA table_info({table})").fetchall()
+        except Exception as exc:  # noqa: BLE001
+            LOGGER.debug("获取表字段失败 table=%s err=%s", table, exc, extra=LOG_EXTRA)
+            cache[table] = None
+            return None
+        if not rows:
+            cache[table] = None
+            return None
+        columns = {row["name"] for row in rows if row["name"]}
+        cache[table] = columns
+        return columns
+
+    def _resolve_column(self, table: str, column: str) -> Optional[str]:
+        columns = self._get_table_columns(table)
+        if columns is None:
+            return None
+        alias_map = self.FIELD_ALIASES.get(table, {})
+        candidate = alias_map.get(column, column)
+        if candidate in columns:
+            return candidate
+        # Try lower-case or fallback alias normalization
+        lowered = candidate.lower()
+        for name in columns:
+            if name.lower() == lowered:
+                return name
+        return None
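The PRAGMA-backed cache means each table's columns are introspected at most once per broker instance, with failures and empty tables cached as `None`:

```python
broker = DataBroker()

cols = broker._get_table_columns("daily")      # first call runs PRAGMA table_info(daily)
again = broker._get_table_columns("daily")     # second call is served from _column_cache
missing = broker._get_table_columns("nope")    # None if the table is absent, cached as None
```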