llm-quant/app/llm/client.py
"""Unified LLM client supporting Ollama and OpenAI compatible APIs."""
from __future__ import annotations

import json
from collections import Counter
from dataclasses import asdict
from typing import Dict, List, Optional

import requests

from app.utils.config import DEFAULT_LLM_MODELS, LLMEndpoint, get_config
from app.utils.logging import get_logger

LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "llm"}


class LLMError(RuntimeError):
    """Raised when LLM provider returns an error response."""


def _default_base_url(provider: str) -> str:
    if provider == "ollama":
        return "http://localhost:11434"
    if provider == "deepseek":
        return "https://api.deepseek.com"
    if provider == "wenxin":
        return "https://aip.baidubce.com"
    return "https://api.openai.com"


def _default_model(provider: str) -> str:
    provider = (provider or "").lower()
    return DEFAULT_LLM_MODELS.get(provider, DEFAULT_LLM_MODELS["ollama"])


def _build_messages(prompt: str, system: Optional[str] = None) -> List[Dict[str, str]]:
    messages: List[Dict[str, str]] = []
    if system:
        messages.append({"role": "system", "content": system})
    messages.append({"role": "user", "content": prompt})
    return messages
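
# A minimal sketch of the returned message list; the values are hypothetical:
#   _build_messages("ping", system="be terse")
#   -> [{"role": "system", "content": "be terse"},
#       {"role": "user", "content": "ping"}]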


def _request_ollama(
    model: str,
    prompt: str,
    *,
    base_url: str,
    temperature: float,
    timeout: float,
    system: Optional[str],
) -> str:
    url = f"{base_url.rstrip('/')}/api/chat"
    payload = {
        "model": model,
        "messages": _build_messages(prompt, system),
        "stream": False,
        "options": {"temperature": temperature},
    }
    LOGGER.debug("Calling Ollama: %s %s", model, url, extra=LOG_EXTRA)
    response = requests.post(url, json=payload, timeout=timeout)
    if response.status_code != 200:
        raise LLMError(f"Ollama request failed: {response.status_code} {response.text}")
    data = response.json()
    message = data.get("message", {})
    content = message.get("content", "")
    if isinstance(content, list):
        # Some replies split content into chunks; join their text fields.
        return "".join(chunk.get("text", "") or chunk.get("content", "") for chunk in content)
    return str(content)
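
# For orientation, a non-streaming /api/chat reply typically looks like the
# abridged sketch below; treat the exact field set as an assumption:
#   {"model": "llama3", "message": {"role": "assistant", "content": "..."},
#    "done": true}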


def _request_openai(
    model: str,
    prompt: str,
    *,
    base_url: str,
    api_key: str,
    temperature: float,
    timeout: float,
    system: Optional[str],
) -> str:
    url = f"{base_url.rstrip('/')}/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "messages": _build_messages(prompt, system),
        "temperature": temperature,
    }
    LOGGER.debug("Calling OpenAI-compatible endpoint: %s %s", model, url, extra=LOG_EXTRA)
    response = requests.post(url, headers=headers, json=payload, timeout=timeout)
    if response.status_code != 200:
        raise LLMError(f"OpenAI API request failed: {response.status_code} {response.text}")
    data = response.json()
    try:
        return data["choices"][0]["message"]["content"].strip()
    except (KeyError, IndexError) as exc:
        raise LLMError(f"Failed to parse OpenAI response: {json.dumps(data, ensure_ascii=False)}") from exc


def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> str:
    provider = (endpoint.provider or "ollama").lower()
    base_url = endpoint.base_url or _default_base_url(provider)
    model = endpoint.model or _default_model(provider)
    temperature = max(0.0, min(endpoint.temperature, 2.0))
    timeout = max(5.0, endpoint.timeout or 30.0)
    LOGGER.info(
        "Dispatching LLM request: provider=%s model=%s base=%s",
        provider,
        model,
        base_url,
        extra=LOG_EXTRA,
    )
    if provider in {"openai", "deepseek", "wenxin"}:
        api_key = endpoint.api_key
        if not api_key:
            raise LLMError(f"Missing {provider} API key (model={model})")
        return _request_openai(
            model,
            prompt,
            base_url=base_url,
            api_key=api_key,
            temperature=temperature,
            timeout=timeout,
            system=system,
        )
    if provider == "ollama":
        return _request_ollama(
            model,
            prompt,
            base_url=base_url,
            temperature=temperature,
            timeout=timeout,
            system=system,
        )
    raise LLMError(f"Unsupported LLM provider: {endpoint.provider}")


def _normalize_response(text: str) -> str:
    return " ".join(text.strip().split())


def run_llm(prompt: str, *, system: Optional[str] = None) -> str:
    """Execute the configured LLM strategy with the given prompt."""
    settings = get_config().llm
    if settings.strategy == "majority":
        return _run_majority_vote(settings, prompt, system)
    return _call_endpoint(settings.primary, prompt, system)
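
# A minimal usage sketch; the prompt and system strings are hypothetical:
#   answer = run_llm("Summarize today's signals.", system="Reply in one line.")
# Any strategy other than "majority" falls through to the primary endpoint;
# "majority" fans out to the ensemble (see _run_majority_vote below).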


def _run_majority_vote(config, prompt: str, system: Optional[str]) -> str:
    endpoints: List[LLMEndpoint] = [config.primary] + list(config.ensemble)
    responses: List[Dict[str, str]] = []
    failures: List[str] = []
    for endpoint in endpoints:
        try:
            result = _call_endpoint(endpoint, prompt, system)
            responses.append({
                "provider": endpoint.provider,
                "model": endpoint.model,
                "raw": result,
                "normalized": _normalize_response(result),
            })
        except Exception as exc:  # noqa: BLE001
            summary = f"{endpoint.provider}:{endpoint.model} -> {exc}"
            failures.append(summary)
            LOGGER.warning("LLM call failed: %s", summary, extra=LOG_EXTRA)
    if not responses:
        raise LLMError("All LLM calls failed; no result to return.")
    threshold = max(1, config.majority_threshold)
    threshold = min(threshold, len(responses))
    counter = Counter(item["normalized"] for item in responses)
    top_value, top_count = counter.most_common(1)[0]
    if top_count >= threshold:
        chosen_raw = next(item["raw"] for item in responses if item["normalized"] == top_value)
        LOGGER.info(
            "LLM majority vote passed: value=%s votes=%s/%s threshold=%s",
            top_value[:80],
            top_count,
            len(responses),
            threshold,
            extra=LOG_EXTRA,
        )
        return chosen_raw
    LOGGER.info(
        "LLM majority vote below threshold: votes=%s/%s threshold=%s; returning first result",
        top_count,
        len(responses),
        threshold,
        extra=LOG_EXTRA,
    )
    if failures:
        LOGGER.warning("Failed LLM calls: %s", failures, extra=LOG_EXTRA)
    return responses[0]["raw"]
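
# Worked example (hypothetical): three endpoints answer "BUY", "BUY", "HOLD"
# with majority_threshold=2. The normalized counter is {"BUY": 2, "HOLD": 1},
# so top_count (2) meets the threshold and the raw "BUY" response is returned.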


def llm_config_snapshot() -> Dict[str, object]:
    """Return a sanitized snapshot of current LLM configuration for debugging."""
    settings = get_config().llm
    primary = asdict(settings.primary)
    if primary.get("api_key"):
        primary["api_key"] = "***"
    ensemble = []
    for endpoint in settings.ensemble:
        record = asdict(endpoint)
        if record.get("api_key"):
            record["api_key"] = "***"
        ensemble.append(record)
    return {
        "strategy": settings.strategy,
        "majority_threshold": settings.majority_threshold,
        "primary": primary,
        "ensemble": ensemble,
    }
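

# A minimal smoke-test sketch, assuming the app config resolves and the
# configured provider (e.g. a local Ollama daemon) is reachable; the prompt
# below is hypothetical and not part of the module proper.
if __name__ == "__main__":
    import pprint

    pprint.pprint(llm_config_snapshot())
    try:
        print(run_llm("Reply with the single word: pong"))
    except LLMError as exc:
        LOGGER.error("Smoke test failed: %s", exc, extra=LOG_EXTRA)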