From ca7b249c2c8eae8e21b86d002b3eaa51948af7a3 Mon Sep 17 00:00:00 2001 From: sam Date: Sun, 28 Sep 2025 21:22:15 +0800 Subject: [PATCH] update --- README.md | 5 +- app/agents/departments.py | 354 ++++++++++++++++++++++++++++++++------ app/llm/client.py | 135 +++++++++++---- app/utils/data_access.py | 106 +++++++++++- 4 files changed, 497 insertions(+), 103 deletions(-) diff --git a/README.md b/README.md index 00e4454..4136666 100644 --- a/README.md +++ b/README.md @@ -36,9 +36,10 @@ 1. **配置声明角色**:在 `config.json`/`DepartmentSettings` 中补充 `description` 与 `data_scope`,`department_prompt()` 拼接角色指令,实现职责以 Prompt 管理而非硬编码。 2. **统一数据层**:新增 `DataBroker`(或同类工具)封装常用查询,代理与部门通过声明式 JSON 请求所需表/字段/窗口,由服务端执行并返回特征。 3. **双阶段 LLM 工作流**:第一阶段让 LLM 输出结构化 `data_requests`,服务端取数后将摘要回填,第二阶段再生成最终行动与解释,形成闭环。 -4. **审计与前端联动**:把角色提示、数据请求与执行摘要写入 `agent_utils` 附加字段,使 Streamlit 能完整呈现“角色 → 请求 → 决策”的链条。 +4. **函数式工具调用**:DeepSeek 等 OpenAI 兼容模型已通过 function calling 接口接入 `fetch_data` 工具,LLM 按 schema 返回字段/窗口请求,系统使用 `DataBroker` 校验并补数后回传 tool result,再继续对话生成最终意见,避免字段错误与手写 JSON 校验。 +5. **审计与前端联动**:把角色提示、数据请求与执行摘要写入 `agent_utils` 附加字段,使 Streamlit 能完整呈现“角色 → 请求 → 决策”的链条。 -目前部门 LLM 已支持通过在 JSON 中返回 `data_requests` 触发追加查询:系统会使用 `DataBroker` 验证字段后补充最近数据窗口,再带着查询结果进入下一轮提示,从而形成“请求→取数→复议”的闭环。 +目前部门 LLM 通过 function calling 暴露的 `fetch_data` 工具触发追加查询:模型按 schema 声明需要的 `table.column` 字段与窗口,系统使用 `DataBroker` 验证后补齐数据并作为工具响应回传,再带着查询结果进入下一轮提示,从而形成“请求 → 取数 → 复议”的闭环。 上述调整可在单个部门先行做 PoC,验证闭环能力后再推广至全部角色。 diff --git a/app/agents/departments.py b/app/agents/departments.py index 1ad11b9..2c13276 100644 --- a/app/agents/departments.py +++ b/app/agents/departments.py @@ -6,11 +6,11 @@ from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple from app.agents.base import AgentAction -from app.llm.client import run_llm_with_config +from app.llm.client import call_endpoint_with_messages, run_llm_with_config, LLMError from app.llm.prompts import department_prompt from app.utils.config import AppConfig, DepartmentSettings, LLMConfig from app.utils.logging import get_logger -from app.utils.data_access import DataBroker, parse_field_path +from app.utils.data_access import DataBroker LOGGER = get_logger(__name__) LOG_EXTRA = {"stage": "department"} @@ -85,74 +85,95 @@ class DepartmentAgent: "你是一个多智能体量化投研系统中的分部决策者,需要根据提供的结构化信息给出买卖意见。" ) llm_cfg = self._get_llm_config() - supplement_chunks: List[str] = [] + + if llm_cfg.strategy not in (None, "", "single") or llm_cfg.ensemble: + LOGGER.warning( + "部门 %s 当前配置不支持函数调用模式,回退至传统提示", + self.settings.code, + extra=LOG_EXTRA, + ) + return self._analyze_legacy(mutable_context, system_prompt) + + tools = self._build_tools() + messages: List[Dict[str, object]] = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.append( + { + "role": "user", + "content": department_prompt(self.settings, mutable_context), + } + ) + transcript: List[str] = [] - delivered_requests = { + delivered_requests: set[Tuple[str, int]] = { (field, 1) for field in (mutable_context.raw.get("scope_values") or {}).keys() } - response = "" - decision_data: Dict[str, Any] = {} + primary_endpoint = llm_cfg.primary + final_message: Optional[Dict[str, Any]] = None + for round_idx in range(self._max_rounds): - supplement_text = "\n\n".join(chunk for chunk in supplement_chunks if chunk) - prompt = department_prompt(self.settings, mutable_context, supplements=supplement_text) try: - response = run_llm_with_config(llm_cfg, prompt, system=system_prompt) - except Exception as exc: # noqa: BLE001 - LOGGER.exception( - "部门 %s 调用 LLM 失败:%s", + response = call_endpoint_with_messages( + primary_endpoint, + messages, + tools=tools, + tool_choice="auto", + ) + except LLMError as exc: + LOGGER.warning( + "部门 %s 函数调用失败,回退传统提示:%s", self.settings.code, exc, extra=LOG_EXTRA, ) - return DepartmentDecision( - department=self.settings.code, - action=AgentAction.HOLD, - confidence=0.0, - summary=f"LLM 调用失败:{exc}", - raw_response=str(exc), - ) + return self._analyze_legacy(mutable_context, system_prompt) - transcript.append(response) - decision_data = _parse_department_response(response) - data_requests = _parse_data_requests(decision_data) - filtered_requests = [ - req - for req in data_requests - if (req.field, req.window) not in delivered_requests - ] + choice = (response.get("choices") or [{}])[0] + message = choice.get("message", {}) + transcript.append(_message_to_text(message)) - if filtered_requests and round_idx < self._max_rounds - 1: - lines, payload, delivered = self._fulfill_data_requests( - mutable_context, filtered_requests - ) - if payload: - supplement_chunks.append( - f"回合 {round_idx + 1} 追加数据:\n" + "\n".join(lines) + tool_calls = message.get("tool_calls") or [] + if tool_calls: + for call in tool_calls: + tool_response, delivered = self._handle_tool_call( + mutable_context, + call, + delivered_requests, + round_idx, ) - mutable_context.raw.setdefault("supplement_data", []).extend(payload) - mutable_context.raw.setdefault("supplement_rounds", []).append( + transcript.append( + json.dumps({"tool_response": tool_response}, ensure_ascii=False) + ) + messages.append( { - "round": round_idx + 1, - "requests": [req.__dict__ for req in filtered_requests], - "data": payload, + "role": "tool", + "tool_call_id": call.get("id"), + "name": call.get("function", {}).get("name"), + "content": json.dumps(tool_response, ensure_ascii=False), } ) delivered_requests.update(delivered) - decision_data.pop("data_requests", None) - continue - LOGGER.debug( - "部门 %s 数据请求无结果:%s", - self.settings.code, - filtered_requests, - extra=LOG_EXTRA, - ) - decision_data.pop("data_requests", None) + continue + + final_message = message break + if final_message is None: + LOGGER.warning( + "部门 %s 函数调用达到轮次上限仍未返回文本,使用最后一次消息", + self.settings.code, + extra=LOG_EXTRA, + ) + final_message = message + mutable_context.raw["supplement_transcript"] = list(transcript) + content_text = _extract_message_content(final_message) + decision_data = _parse_department_response(content_text) + action = _normalize_action(decision_data.get("action")) confidence = _clamp_float(decision_data.get("confidence"), default=0.5) summary = decision_data.get("summary") or decision_data.get("reason") or "" @@ -170,7 +191,7 @@ class DepartmentAgent: summary=summary or "未提供摘要", signals=[str(sig) for sig in signals if sig], risks=[str(risk) for risk in risks if risk], - raw_response="\n\n".join(transcript) if transcript else response, + raw_response=content_text, supplements=list(mutable_context.raw.get("supplement_data", [])), dialogue=list(transcript), ) @@ -208,16 +229,16 @@ class DepartmentAgent: field = req.field.strip() if not field: continue + resolved = self._broker.resolve_field(field) + if not resolved: + lines.append(f"- {field}: 字段不存在或不可用") + continue if req.window <= 1: if field not in latest_fields: latest_fields.append(field) delivered.add((field, 1)) continue - parsed = parse_field_path(field) - if not parsed: - lines.append(f"- {field}: 字段不合法,已忽略") - continue - series_requests.append((req, parsed)) + series_requests.append((req, resolved)) delivered.add((field, req.window)) if latest_fields: @@ -251,11 +272,186 @@ class DepartmentAgent: lines.append( f"- {req.field} (window={req.window}): (数据缺失)" ) - payload.append({"field": req.field, "window": req.window, "values": series}) + payload.append( + { + "field": req.field, + "window": req.window, + "values": [ + {"trade_date": dt, "value": val} + for dt, val in series + ], + } + ) return lines, payload, delivered + def _handle_tool_call( + self, + context: DepartmentContext, + call: Mapping[str, Any], + delivered_requests: set[Tuple[str, int]], + round_idx: int, + ) -> Tuple[Dict[str, Any], set[Tuple[str, int]]]: + function_block = call.get("function") or {} + name = function_block.get("name") or "" + if name != "fetch_data": + LOGGER.warning( + "部门 %s 收到未知工具调用:%s", + self.settings.code, + name, + extra=LOG_EXTRA, + ) + return { + "status": "error", + "message": f"未知工具 {name}", + }, set() + + args = _parse_tool_arguments(function_block.get("arguments")) + raw_requests = args.get("requests") or [] + requests: List[DataRequest] = [] + skipped: List[str] = [] + for item in raw_requests: + field = str(item.get("field", "")).strip() + if not field: + continue + try: + window = int(item.get("window", 1)) + except (TypeError, ValueError): + window = 1 + window = max(1, min(window, getattr(self._broker, "MAX_WINDOW", 120))) + key = (field, window) + if key in delivered_requests: + skipped.append(field) + continue + requests.append(DataRequest(field=field, window=window)) + + if not requests: + return { + "status": "ok", + "round": round_idx + 1, + "results": [], + "skipped": skipped, + }, set() + + lines, payload, delivered = self._fulfill_data_requests(context, requests) + if payload: + context.raw.setdefault("supplement_data", []).extend(payload) + context.raw.setdefault("supplement_rounds", []).append( + { + "round": round_idx + 1, + "requests": [req.__dict__ for req in requests], + "data": payload, + "notes": lines, + } + ) + if lines: + context.raw.setdefault("supplement_notes", []).append( + { + "round": round_idx + 1, + "lines": lines, + } + ) + + response_payload = { + "status": "ok", + "round": round_idx + 1, + "results": payload, + "notes": lines, + "skipped": skipped, + } + return response_payload, delivered + + + def _build_tools(self) -> List[Dict[str, Any]]: + max_window = getattr(self._broker, "MAX_WINDOW", 120) + return [ + { + "type": "function", + "function": { + "name": "fetch_data", + "description": ( + "根据字段请求数据库中的最新值或时间序列。支持 table.column 格式的字段," + "window 表示希望返回的最近数据点数量。" + ), + "parameters": { + "type": "object", + "properties": { + "requests": { + "type": "array", + "items": { + "type": "object", + "properties": { + "field": { + "type": "string", + "description": "数据字段,格式为 table.column", + }, + "window": { + "type": "integer", + "minimum": 1, + "maximum": max_window, + "description": "返回最近多少个数据点,默认为 1", + }, + }, + "required": ["field"], + }, + "minItems": 1, + } + }, + "required": ["requests"], + }, + }, + } + ] + + def _analyze_legacy( + self, + context: DepartmentContext, + system_prompt: str, + ) -> DepartmentDecision: + prompt = department_prompt(self.settings, context) + llm_cfg = self._get_llm_config() + try: + response = run_llm_with_config(llm_cfg, prompt, system=system_prompt) + except Exception as exc: # noqa: BLE001 + LOGGER.exception( + "部门 %s 调用 LLM 失败:%s", + self.settings.code, + exc, + extra=LOG_EXTRA, + ) + return DepartmentDecision( + department=self.settings.code, + action=AgentAction.HOLD, + confidence=0.0, + summary=f"LLM 调用失败:{exc}", + raw_response=str(exc), + ) + + context.raw["supplement_transcript"] = [response] + decision_data = _parse_department_response(response) + action = _normalize_action(decision_data.get("action")) + confidence = _clamp_float(decision_data.get("confidence"), default=0.5) + summary = decision_data.get("summary") or decision_data.get("reason") or "" + signals = decision_data.get("signals") or decision_data.get("rationale") or [] + if isinstance(signals, str): + signals = [signals] + risks = decision_data.get("risks") or decision_data.get("warnings") or [] + if isinstance(risks, str): + risks = [risks] + + decision = DepartmentDecision( + department=self.settings.code, + action=action, + confidence=confidence, + summary=summary or "未提供摘要", + signals=[str(sig) for sig in signals if sig], + risks=[str(risk) for risk in risks if risk], + raw_response=response, + supplements=list(context.raw.get("supplement_data", [])), + dialogue=[response], + ) + return decision def _ensure_mutable_context(context: DepartmentContext) -> DepartmentContext: if not isinstance(context.features, dict): context.features = dict(context.features or {}) @@ -302,6 +498,54 @@ def _parse_data_requests(payload: Mapping[str, Any]) -> List[DataRequest]: return requests +def _parse_tool_arguments(payload: Any) -> Dict[str, Any]: + if isinstance(payload, dict): + return dict(payload) + if isinstance(payload, str): + try: + data = json.loads(payload) + except json.JSONDecodeError: + LOGGER.debug("工具参数解析失败:%s", payload, extra=LOG_EXTRA) + return {} + if isinstance(data, dict): + return data + return {} + + +def _message_to_text(message: Mapping[str, Any]) -> str: + content = message.get("content") + if isinstance(content, list): + parts: List[str] = [] + for item in content: + if isinstance(item, Mapping) and "text" in item: + parts.append(str(item.get("text", ""))) + else: + parts.append(str(item)) + if parts: + return "".join(parts) + elif isinstance(content, str) and content.strip(): + return content + tool_calls = message.get("tool_calls") + if tool_calls: + return json.dumps({"tool_calls": tool_calls}, ensure_ascii=False) + return "" + + +def _extract_message_content(message: Mapping[str, Any]) -> str: + content = message.get("content") + if isinstance(content, list): + texts = [ + str(item.get("text", "")) + for item in content + if isinstance(item, Mapping) and "text" in item + ] + if texts: + return "".join(texts) + if isinstance(content, str): + return content + return json.dumps(message, ensure_ascii=False) + + class DepartmentManager: """Orchestrates all departments defined in configuration.""" diff --git a/app/llm/client.py b/app/llm/client.py index 5b11f3c..596c003 100644 --- a/app/llm/client.py +++ b/app/llm/client.py @@ -11,6 +11,8 @@ import requests from app.utils.config import ( DEFAULT_LLM_BASE_URLS, DEFAULT_LLM_MODELS, + DEFAULT_LLM_TEMPERATURES, + DEFAULT_LLM_TIMEOUTS, LLMConfig, LLMEndpoint, get_config, @@ -34,8 +36,8 @@ def _default_model(provider: str) -> str: return DEFAULT_LLM_MODELS.get(provider, DEFAULT_LLM_MODELS["ollama"]) -def _build_messages(prompt: str, system: Optional[str] = None) -> List[Dict[str, str]]: - messages: List[Dict[str, str]] = [] +def _build_messages(prompt: str, system: Optional[str] = None) -> List[Dict[str, object]]: + messages: List[Dict[str, object]] = [] if system: messages.append({"role": "system", "content": system}) messages.append({"role": "user", "content": prompt}) @@ -70,38 +72,39 @@ def _request_ollama( return str(content) -def _request_openai( - model: str, - prompt: str, +def _request_openai_chat( *, base_url: str, api_key: str, + model: str, + messages: List[Dict[str, object]], temperature: float, timeout: float, - system: Optional[str], -) -> str: + tools: Optional[List[Dict[str, object]]] = None, + tool_choice: Optional[object] = None, +) -> Dict[str, object]: url = f"{base_url.rstrip('/')}/v1/chat/completions" headers = { "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", } - payload = { + payload: Dict[str, object] = { "model": model, - "messages": _build_messages(prompt, system), + "messages": messages, "temperature": temperature, } + if tools: + payload["tools"] = tools + if tool_choice is not None: + payload["tool_choice"] = tool_choice LOGGER.debug("调用 OpenAI 兼容接口: %s %s", model, url, extra=LOG_EXTRA) response = requests.post(url, headers=headers, json=payload, timeout=timeout) if response.status_code != 200: raise LLMError(f"OpenAI API 调用失败: {response.status_code} {response.text}") - data = response.json() - try: - return data["choices"][0]["message"]["content"].strip() - except (KeyError, IndexError) as exc: - raise LLMError(f"OpenAI 响应解析失败: {json.dumps(data, ensure_ascii=False)}") from exc + return response.json() -def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> str: +def _prepare_endpoint(endpoint: LLMEndpoint) -> Dict[str, object]: cfg = get_config() provider_key = (endpoint.provider or "ollama").lower() provider_cfg = cfg.llm_providers.get(provider_key) @@ -128,14 +131,30 @@ def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> else: base_url = base_url or _default_base_url(provider_key) model = model or _default_model(provider_key) + if temperature is None: temperature = DEFAULT_LLM_TEMPERATURES.get(provider_key, 0.2) if timeout is None: timeout = DEFAULT_LLM_TIMEOUTS.get(provider_key, 30.0) mode = "ollama" if provider_key == "ollama" else "openai" - temperature = max(0.0, min(float(temperature), 2.0)) - timeout = max(5.0, float(timeout)) + return { + "provider_key": provider_key, + "mode": mode, + "base_url": base_url, + "api_key": api_key, + "model": model, + "temperature": max(0.0, min(float(temperature), 2.0)), + "timeout": max(5.0, float(timeout)), + "prompt_template": prompt_template, + } + + +def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> str: + resolved = _prepare_endpoint(endpoint) + provider_key = resolved["provider_key"] + mode = resolved["mode"] + prompt_template = resolved["prompt_template"] if prompt_template: try: @@ -143,6 +162,40 @@ def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> except Exception: # noqa: BLE001 LOGGER.warning("Prompt 模板格式化失败,使用原始 prompt", extra=LOG_EXTRA) + messages = _build_messages(prompt, system) + response = call_endpoint_with_messages( + endpoint, + messages, + tools=None, + ) + if mode == "ollama": + message = response.get("message") or {} + content = message.get("content", "") + if isinstance(content, list): + return "".join(chunk.get("text", "") or chunk.get("content", "") for chunk in content) + return str(content) + try: + return response["choices"][0]["message"]["content"].strip() + except (KeyError, IndexError) as exc: + raise LLMError(f"OpenAI 响应解析失败: {json.dumps(response, ensure_ascii=False)}") from exc + + +def call_endpoint_with_messages( + endpoint: LLMEndpoint, + messages: List[Dict[str, object]], + *, + tools: Optional[List[Dict[str, object]]] = None, + tool_choice: Optional[object] = None, +) -> Dict[str, object]: + resolved = _prepare_endpoint(endpoint) + provider_key = resolved["provider_key"] + mode = resolved["mode"] + base_url = resolved["base_url"] + model = resolved["model"] + temperature = resolved["temperature"] + timeout = resolved["timeout"] + api_key = resolved["api_key"] + LOGGER.info( "触发 LLM 请求:provider=%s model=%s base=%s", provider_key, @@ -151,28 +204,36 @@ def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> extra=LOG_EXTRA, ) - if mode != "ollama": - if not api_key: - raise LLMError(f"缺少 {provider_key} API Key (model={model})") - return _request_openai( - model, - prompt, - base_url=base_url, - api_key=api_key, - temperature=temperature, + if mode == "ollama": + if tools: + raise LLMError("当前 provider 不支持函数调用/工具模式") + payload = { + "model": model, + "messages": messages, + "stream": False, + "options": {"temperature": temperature}, + } + response = requests.post( + f"{base_url.rstrip('/')}/api/chat", + json=payload, timeout=timeout, - system=system, ) - if base_url: - return _request_ollama( - model, - prompt, - base_url=base_url, - temperature=temperature, - timeout=timeout, - system=system, - ) - raise LLMError(f"不支持的 LLM provider: {endpoint.provider}") + if response.status_code != 200: + raise LLMError(f"Ollama 调用失败: {response.status_code} {response.text}") + return response.json() + + if not api_key: + raise LLMError(f"缺少 {provider_key} API Key (model={model})") + return _request_openai_chat( + base_url=base_url, + api_key=api_key, + model=model, + messages=messages, + temperature=temperature, + timeout=timeout, + tools=tools, + tool_choice=tool_choice, + ) def _normalize_response(text: str) -> str: diff --git a/app/utils/data_access.py b/app/utils/data_access.py index c63d345..a12930f 100644 --- a/app/utils/data_access.py +++ b/app/utils/data_access.py @@ -3,7 +3,7 @@ from __future__ import annotations import re from dataclasses import dataclass -from typing import Dict, Iterable, List, Sequence, Tuple +from typing import Dict, Iterable, List, Optional, Sequence, Tuple from .db import db_session from .logging import get_logger @@ -42,6 +42,29 @@ def parse_field_path(path: str) -> Tuple[str, str] | None: class DataBroker: """Lightweight data access helper for agent/LLM consumption.""" + FIELD_ALIASES: Dict[str, Dict[str, str]] = { + "daily": { + "volume": "vol", + "vol": "vol", + "turnover": "amount", + }, + "daily_basic": { + "turnover": "turnover_rate", + "turnover_rate": "turnover_rate", + "turnover_rate_f": "turnover_rate_f", + "volume_ratio": "volume_ratio", + "pe": "pe", + "pb": "pb", + "ps": "ps", + "ps_ttm": "ps_ttm", + }, + "stk_limit": { + "up": "up_limit", + "down": "down_limit", + }, + } + MAX_WINDOW: int = 120 + def fetch_latest( self, ts_code: str, @@ -51,16 +74,18 @@ class DataBroker: """Fetch the latest value (<= trade_date) for each requested field.""" grouped: Dict[str, List[str]] = {} + field_map: Dict[Tuple[str, str], List[str]] = {} for item in fields: if not item: continue - normalized = _safe_split(str(item)) - if not normalized: + resolved = self.resolve_field(str(item)) + if not resolved: continue - table, column = normalized + table, column = resolved grouped.setdefault(table, []) if column not in grouped[table]: grouped[table].append(column) + field_map.setdefault((table, column), []).append(str(item)) if not grouped: return {} @@ -91,8 +116,8 @@ class DataBroker: value = row[column] if value is None: continue - key = f"{table}.{column}" - results[key] = float(value) + for original in field_map.get((table, column), [f"{table}.{column}"]): + results[original] = float(value) return results def fetch_series( @@ -107,10 +132,19 @@ class DataBroker: if window <= 0: return [] - if not (_is_safe_identifier(table) and _is_safe_identifier(column)): + window = min(window, self.MAX_WINDOW) + resolved_field = self.resolve_field(f"{table}.{column}") + if not resolved_field: + LOGGER.debug( + "时间序列字段不存在 table=%s column=%s", + table, + column, + extra=LOG_EXTRA, + ) return [] + table, resolved = resolved_field query = ( - f"SELECT trade_date, {column} FROM {table} " + f"SELECT trade_date, {resolved} FROM {table} " "WHERE ts_code = ? AND trade_date <= ? " "ORDER BY trade_date DESC LIMIT ?" ) @@ -128,7 +162,7 @@ class DataBroker: return [] series: List[Tuple[str, float]] = [] for row in rows: - value = row[column] + value = row[resolved] if value is None: continue series.append((row["trade_date"], float(value))) @@ -163,3 +197,57 @@ class DataBroker: ) return False return row is not None + + def resolve_field(self, field: str) -> Optional[Tuple[str, str]]: + normalized = _safe_split(field) + if not normalized: + return None + table, column = normalized + resolved = self._resolve_column(table, column) + if not resolved: + LOGGER.debug( + "字段不存在 table=%s column=%s", + table, + column, + extra=LOG_EXTRA, + ) + return None + return table, resolved + + def _get_table_columns(self, table: str) -> Optional[set[str]]: + if not _is_safe_identifier(table): + return None + cache = getattr(self, "_column_cache", None) + if cache is None: + cache = {} + self._column_cache = cache + if table in cache: + return cache[table] + try: + with db_session(read_only=True) as conn: + rows = conn.execute(f"PRAGMA table_info({table})").fetchall() + except Exception as exc: # noqa: BLE001 + LOGGER.debug("获取表字段失败 table=%s err=%s", table, exc, extra=LOG_EXTRA) + cache[table] = None + return None + if not rows: + cache[table] = None + return None + columns = {row["name"] for row in rows if row["name"]} + cache[table] = columns + return columns + + def _resolve_column(self, table: str, column: str) -> Optional[str]: + columns = self._get_table_columns(table) + if columns is None: + return None + alias_map = self.FIELD_ALIASES.get(table, {}) + candidate = alias_map.get(column, column) + if candidate in columns: + return candidate + # Try lower-case or fallback alias normalization + lowered = candidate.lower() + for name in columns: + if name.lower() == lowered: + return name + return None