"""Unified LLM client supporting Ollama and OpenAI compatible APIs."""
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
from collections import Counter
|
||
from dataclasses import asdict
|
||
from typing import Dict, Iterable, List, Optional
|
||
|
||
import requests
|
||
|
||
from app.utils.config import DEFAULT_LLM_MODELS, LLMEndpoint, get_config
|
||
from app.utils.logging import get_logger
|
||
|
||
LOGGER = get_logger(__name__)
|
||
LOG_EXTRA = {"stage": "llm"}
|
||
|
||


class LLMError(RuntimeError):
    """Raised when an LLM provider returns an error response."""


def _default_base_url(provider: str) -> str:
    if provider == "ollama":
        return "http://localhost:11434"
    if provider == "deepseek":
        return "https://api.deepseek.com"
    if provider == "wenxin":
        return "https://aip.baidubce.com"
    return "https://api.openai.com"


def _default_model(provider: str) -> str:
    provider = (provider or "").lower()
    return DEFAULT_LLM_MODELS.get(provider, DEFAULT_LLM_MODELS["ollama"])


def _build_messages(prompt: str, system: Optional[str] = None) -> List[Dict[str, str]]:
    messages: List[Dict[str, str]] = []
    if system:
        messages.append({"role": "system", "content": system})
    messages.append({"role": "user", "content": prompt})
    return messages


def _request_ollama(
    model: str,
    prompt: str,
    *,
    base_url: str,
    temperature: float,
    timeout: float,
    system: Optional[str],
) -> str:
    url = f"{base_url.rstrip('/')}/api/chat"
    payload = {
        "model": model,
        "messages": _build_messages(prompt, system),
        "stream": False,
        "options": {"temperature": temperature},
    }
    LOGGER.debug("Calling Ollama: %s %s", model, url, extra=LOG_EXTRA)
    response = requests.post(url, json=payload, timeout=timeout)
    if response.status_code != 200:
        raise LLMError(f"Ollama request failed: {response.status_code} {response.text}")
    data = response.json()
    message = data.get("message", {})
    content = message.get("content", "")
    if isinstance(content, list):
        # Some responses deliver content as a list of chunks; join their text.
        return "".join(chunk.get("text", "") or chunk.get("content", "") for chunk in content)
    return str(content)


def _request_openai(
    model: str,
    prompt: str,
    *,
    base_url: str,
    api_key: str,
    temperature: float,
    timeout: float,
    system: Optional[str],
) -> str:
    url = f"{base_url.rstrip('/')}/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "messages": _build_messages(prompt, system),
        "temperature": temperature,
    }
    LOGGER.debug("Calling OpenAI-compatible API: %s %s", model, url, extra=LOG_EXTRA)
    response = requests.post(url, headers=headers, json=payload, timeout=timeout)
    if response.status_code != 200:
        raise LLMError(f"OpenAI API request failed: {response.status_code} {response.text}")
    data = response.json()
    try:
        return data["choices"][0]["message"]["content"].strip()
    except (KeyError, IndexError) as exc:
        raise LLMError(f"Failed to parse OpenAI response: {json.dumps(data, ensure_ascii=False)}") from exc


def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> str:
    provider = (endpoint.provider or "ollama").lower()
    base_url = endpoint.base_url or _default_base_url(provider)
    model = endpoint.model or _default_model(provider)
    # Clamp temperature to [0, 2] and enforce a timeout floor of 5s (default 30s).
    temperature = max(0.0, min(endpoint.temperature, 2.0))
    timeout = max(5.0, endpoint.timeout or 30.0)

    LOGGER.info(
        "Dispatching LLM request: provider=%s model=%s base=%s",
        provider,
        model,
        base_url,
        extra=LOG_EXTRA,
    )

    if provider in {"openai", "deepseek", "wenxin"}:
        api_key = endpoint.api_key
        if not api_key:
            raise LLMError(f"Missing {provider} API key (model={model})")
        return _request_openai(
            model,
            prompt,
            base_url=base_url,
            api_key=api_key,
            temperature=temperature,
            timeout=timeout,
            system=system,
        )
    if provider == "ollama":
        return _request_ollama(
            model,
            prompt,
            base_url=base_url,
            temperature=temperature,
            timeout=timeout,
            system=system,
        )
    raise LLMError(f"Unsupported LLM provider: {endpoint.provider}")


def _normalize_response(text: str) -> str:
    """Collapse whitespace so equivalent answers compare equal during voting."""
    return " ".join(text.strip().split())


def run_llm(prompt: str, *, system: Optional[str] = None) -> str:
    """Execute the configured LLM strategy with the given prompt."""

    settings = get_config().llm
    if settings.strategy == "majority":
        return _run_majority_vote(settings, prompt, system)
    return _call_endpoint(settings.primary, prompt, system)


def _run_majority_vote(config, prompt: str, system: Optional[str]) -> str:
    endpoints: List[LLMEndpoint] = [config.primary] + list(config.ensemble)
    responses: List[Dict[str, str]] = []
    failures: List[str] = []

    for endpoint in endpoints:
        try:
            result = _call_endpoint(endpoint, prompt, system)
            responses.append({
                "provider": endpoint.provider,
                "model": endpoint.model,
                "raw": result,
                "normalized": _normalize_response(result),
            })
        except Exception as exc:  # noqa: BLE001
            summary = f"{endpoint.provider}:{endpoint.model} -> {exc}"
            failures.append(summary)
            LOGGER.warning("LLM call failed: %s", summary, extra=LOG_EXTRA)

    if not responses:
        raise LLMError("All LLM calls failed; no result can be returned.")

    threshold = max(1, config.majority_threshold)
    threshold = min(threshold, len(responses))

    counter = Counter(item["normalized"] for item in responses)
    top_value, top_count = counter.most_common(1)[0]
    if top_count >= threshold:
        chosen_raw = next(item["raw"] for item in responses if item["normalized"] == top_value)
        LOGGER.info(
            "LLM majority vote passed: value=%s votes=%s/%s threshold=%s",
            top_value[:80],
            top_count,
            len(responses),
            threshold,
            extra=LOG_EXTRA,
        )
        return chosen_raw

    LOGGER.info(
        "LLM majority vote below threshold: votes=%s/%s threshold=%s; returning first result",
        top_count,
        len(responses),
        threshold,
        extra=LOG_EXTRA,
    )
    if failures:
        LOGGER.warning("Failed LLM calls: %s", failures, extra=LOG_EXTRA)
    return responses[0]["raw"]


def llm_config_snapshot() -> Dict[str, object]:
    """Return a sanitized snapshot of the current LLM configuration for debugging."""

    settings = get_config().llm
    primary = asdict(settings.primary)
    if primary.get("api_key"):
        primary["api_key"] = "***"
    ensemble = []
    for endpoint in settings.ensemble:
        record = asdict(endpoint)
        if record.get("api_key"):
            record["api_key"] = "***"
        ensemble.append(record)
    return {
        "strategy": settings.strategy,
        "majority_threshold": settings.majority_threshold,
        "primary": primary,
        "ensemble": ensemble,
    }
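

# Minimal usage sketch, not part of the module's public surface: it assumes the
# configuration returned by get_config() already points at a reachable primary
# endpoint (e.g. a local Ollama instance); the prompt and system text below are
# purely illustrative.
if __name__ == "__main__":
    from pprint import pprint

    # Print the sanitized configuration, then run one prompt through the
    # configured strategy (single endpoint or majority vote).
    pprint(llm_config_snapshot())
    print(run_llm("Reply with the single word: ready", system="You are a health-check probe."))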