llm-quant/app/llm/client.py


"""Unified LLM client supporting Ollama and OpenAI compatible APIs."""
from __future__ import annotations
import json
import time
from collections import Counter
from dataclasses import asdict
from typing import Any, Dict, Iterable, List, Optional
import requests
from .context import ContextManager, Message
from .templates import TemplateRegistry
from .cost import configure_cost_limits, get_cost_controller, budget_available
from .cache import build_cache_key, is_cacheable, llm_cache
from .rate_limit import RateLimiter
from app.utils.config import (
DEFAULT_LLM_BASE_URLS,
DEFAULT_LLM_MODELS,
DEFAULT_LLM_TEMPERATURES,
DEFAULT_LLM_TIMEOUTS,
DEFAULT_LLM_MODEL_OPTIONS,
LLMConfig,
LLMEndpoint,
get_config,
)
from app.llm.metrics import record_call, record_cache_hit, record_template_usage
from app.utils.logging import get_logger
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "llm"}
RATE_LIMITER = RateLimiter()
class LLMError(RuntimeError):
"""Raised when LLM provider returns an error response."""
def _default_base_url(provider: str) -> str:
provider = (provider or "openai").lower()
return DEFAULT_LLM_BASE_URLS.get(provider, DEFAULT_LLM_BASE_URLS["openai"])
def _default_model(provider: str) -> str:
provider = (provider or "").lower()
return DEFAULT_LLM_MODELS.get(provider, DEFAULT_LLM_MODELS["ollama"])
def _build_messages(prompt: str, system: Optional[str] = None) -> List[Dict[str, object]]:
messages: List[Dict[str, object]] = []
if system:
messages.append({"role": "system", "content": system})
messages.append({"role": "user", "content": prompt})
return messages
def _request_ollama(
model: str,
prompt: str,
*,
base_url: str,
temperature: float,
timeout: float,
system: Optional[str],
) -> str:
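    """Send a non-streaming chat request to an Ollama server and return the message text."""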
url = f"{base_url.rstrip('/')}/api/chat"
payload = {
"model": model,
"messages": _build_messages(prompt, system),
"stream": False,
"options": {"temperature": temperature},
}
LOGGER.debug("调用 Ollama: %s %s", model, url, extra=LOG_EXTRA)
response = requests.post(url, json=payload, timeout=timeout)
if response.status_code != 200:
raise LLMError(f"Ollama 调用失败: {response.status_code} {response.text}")
data = response.json()
message = data.get("message", {})
content = message.get("content", "")
if isinstance(content, list):
return "".join(chunk.get("text", "") or chunk.get("content", "") for chunk in content)
return str(content)
def _request_openai_chat(
*,
base_url: str,
api_key: str,
model: str,
messages: List[Dict[str, object]],
temperature: float,
timeout: float,
tools: Optional[List[Dict[str, object]]] = None,
tool_choice: Optional[object] = None,
) -> Dict[str, object]:
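    """POST a chat completion to an OpenAI-compatible endpoint and return the parsed JSON response."""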
url = f"{base_url.rstrip('/')}/v1/chat/completions"
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
payload: Dict[str, object] = {
"model": model,
"messages": messages,
"temperature": temperature,
}
if tools:
payload["tools"] = tools
if tool_choice is not None:
payload["tool_choice"] = tool_choice
LOGGER.debug("调用 OpenAI 兼容接口: %s %s", model, url, extra=LOG_EXTRA)
response = requests.post(url, headers=headers, json=payload, timeout=timeout)
if response.status_code != 200:
raise LLMError(f"OpenAI API 调用失败: {response.status_code} {response.text}")
return response.json()
def resolve_endpoint(endpoint: LLMEndpoint) -> Dict[str, object]:
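    """Merge endpoint overrides with the configured provider entry and built-in defaults.

    Values set explicitly on the endpoint win, then the provider configuration, then the
    module-level defaults. Returns the resolved mode, connection, rate-limit and cache
    parameters consumed by the call helpers below.
    """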
cfg = get_config()
provider_key = (endpoint.provider or "ollama").lower()
provider_cfg = cfg.llm_providers.get(provider_key)
base_url = endpoint.base_url
api_key = endpoint.api_key
model = endpoint.model
temperature = endpoint.temperature
timeout = endpoint.timeout
prompt_template = endpoint.prompt_template
def _safe_int(value: object, fallback: int) -> int:
try:
return int(value)
except (TypeError, ValueError):
return fallback
rate_limit_per_minute = 0
rate_limit_burst = 0
cache_enabled = True
cache_ttl_seconds = 0
if provider_cfg:
if not provider_cfg.enabled:
raise LLMError(f"Provider {provider_key} 已被禁用")
base_url = base_url or provider_cfg.base_url or _default_base_url(provider_key)
api_key = api_key or provider_cfg.api_key
model = model or provider_cfg.default_model or (provider_cfg.models[0] if provider_cfg.models else _default_model(provider_key))
if temperature is None:
temperature = provider_cfg.default_temperature
if timeout is None:
timeout = provider_cfg.default_timeout
prompt_template = prompt_template or (provider_cfg.prompt_template or None)
mode = provider_cfg.mode or ("ollama" if provider_key == "ollama" else "openai")
rate_limit_per_minute = max(0, _safe_int(provider_cfg.rate_limit_per_minute, 0))
rate_limit_burst = provider_cfg.rate_limit_burst
rate_limit_burst = _safe_int(rate_limit_burst, rate_limit_per_minute or 0)
if rate_limit_per_minute > 0:
rate_limit_burst = max(1, rate_limit_burst or rate_limit_per_minute)
else:
rate_limit_burst = max(0, rate_limit_burst)
cache_enabled = bool(provider_cfg.cache_enabled)
cache_ttl_seconds = max(0, _safe_int(provider_cfg.cache_ttl_seconds, 0))
else:
base_url = base_url or _default_base_url(provider_key)
model = model or _default_model(provider_key)
if temperature is None:
temperature = DEFAULT_LLM_TEMPERATURES.get(provider_key, 0.2)
if timeout is None:
timeout = DEFAULT_LLM_TIMEOUTS.get(provider_key, 30.0)
mode = "ollama" if provider_key == "ollama" else "openai"
defaults = DEFAULT_LLM_MODEL_OPTIONS.get(provider_key, {})
rate_limit_per_minute = max(0, _safe_int(defaults.get("rate_limit_per_minute"), 0))
rate_limit_burst = _safe_int(defaults.get("rate_limit_burst"), rate_limit_per_minute or 0)
if rate_limit_per_minute > 0:
rate_limit_burst = max(1, rate_limit_burst or rate_limit_per_minute)
else:
rate_limit_burst = max(0, rate_limit_burst)
cache_enabled = bool(defaults.get("cache_enabled", True))
cache_ttl_seconds = max(0, _safe_int(defaults.get("cache_ttl_seconds"), 0))
return {
"provider_key": provider_key,
"mode": mode,
"base_url": base_url,
"api_key": api_key,
"model": model,
"temperature": max(0.0, min(float(temperature), 2.0)),
"timeout": max(5.0, float(timeout)),
"prompt_template": prompt_template,
"rate_limit_per_minute": rate_limit_per_minute,
"rate_limit_burst": rate_limit_burst,
"cache_enabled": cache_enabled,
"cache_ttl_seconds": cache_ttl_seconds,
}
def _call_endpoint(endpoint: LLMEndpoint, prompt: str, system: Optional[str]) -> str:
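    """Resolve the endpoint, apply its prompt template (if any), and return the response text."""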
resolved = resolve_endpoint(endpoint)
provider_key = resolved["provider_key"]
mode = resolved["mode"]
prompt_template = resolved["prompt_template"]
if prompt_template:
try:
prompt = prompt_template.format(prompt=prompt)
except Exception: # noqa: BLE001
LOGGER.warning("Prompt 模板格式化失败,使用原始 prompt", extra=LOG_EXTRA)
messages = _build_messages(prompt, system)
response = call_endpoint_with_messages(
endpoint,
messages,
tools=None,
)
if mode == "ollama":
message = response.get("message") or {}
content = message.get("content", "")
if isinstance(content, list):
return "".join(chunk.get("text", "") or chunk.get("content", "") for chunk in content)
return str(content)
try:
return response["choices"][0]["message"]["content"].strip()
except (KeyError, IndexError) as exc:
raise LLMError(f"OpenAI 响应解析失败: {json.dumps(response, ensure_ascii=False)}") from exc
def call_endpoint_with_messages(
endpoint: LLMEndpoint,
messages: List[Dict[str, object]],
*,
tools: Optional[List[Dict[str, object]]] = None,
tool_choice: Optional[object] = None,
) -> Dict[str, object]:
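    """Send a chat request (optionally with tools) to the resolved endpoint.

    Applies per-provider rate limiting, response caching, and cost-budget checks before
    dispatching, and returns the provider's raw JSON payload so callers can handle either
    the OpenAI ``choices`` structure or Ollama's ``message`` structure.
    """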
resolved = resolve_endpoint(endpoint)
provider_key = resolved["provider_key"]
mode = resolved["mode"]
base_url = resolved["base_url"]
model = resolved["model"]
temperature = resolved["temperature"]
timeout = resolved["timeout"]
api_key = resolved["api_key"]
rate_limit_per_minute = max(0, int(resolved.get("rate_limit_per_minute") or 0))
rate_limit_burst = max(0, int(resolved.get("rate_limit_burst") or 0))
cache_enabled = bool(resolved.get("cache_enabled", True))
cache_ttl_seconds = max(0, int(resolved.get("cache_ttl_seconds") or 0))
if rate_limit_per_minute > 0:
if rate_limit_burst <= 0:
rate_limit_burst = rate_limit_per_minute
wait_time = RATE_LIMITER.acquire(provider_key, rate_limit_per_minute, rate_limit_burst)
if wait_time > 0:
LOGGER.debug(
"LLM 请求触发限速provider=%s wait=%.3fs",
provider_key,
wait_time,
extra=LOG_EXTRA,
)
time.sleep(wait_time)
cache_store = llm_cache()
cache_allowed = (
cache_enabled
and cache_ttl_seconds > 0
and cache_store.enabled
and is_cacheable(resolved, messages, tools)
)
cache_key: Optional[str] = None
if cache_allowed:
cache_key = build_cache_key(provider_key, resolved, messages, tools, tool_choice)
cached_payload = cache_store.get(cache_key)
if cached_payload is not None:
record_cache_hit(provider_key, model)
return cached_payload
cfg = get_config()
cost_cfg = getattr(cfg, "llm_cost", None)
enforce_cost = False
cost_controller = None
if cost_cfg and getattr(cost_cfg, "enabled", False):
try:
limits = cost_cfg.to_cost_limits()
except Exception as exc: # noqa: BLE001
LOGGER.warning(
"成本控制配置解析失败,将忽略限制: %s",
exc,
extra=LOG_EXTRA,
)
else:
configure_cost_limits(limits)
enforce_cost = True
if not budget_available():
raise LLMError("LLM 调用预算已耗尽,请稍后重试。")
cost_controller = get_cost_controller()
LOGGER.info(
"触发 LLM 请求provider=%s model=%s base=%s",
provider_key,
model,
base_url,
extra=LOG_EXTRA,
)
if mode == "ollama":
# Ollama supports function/tool calling via the /api/chat endpoint.
# Include `tools` and optional `tool_choice` in the payload when provided.
payload = {
"model": model,
"messages": messages,
"stream": False,
"options": {"temperature": temperature},
}
if tools:
# Ollama expects `tools` at the top level similar to OpenAI-compatible API
payload["tools"] = tools
if tool_choice is not None:
payload["tool_choice"] = tool_choice
start_time = time.perf_counter()
response = requests.post(
f"{base_url.rstrip('/')}/api/chat",
json=payload,
timeout=timeout,
)
duration = time.perf_counter() - start_time
if response.status_code != 200:
raise LLMError(f"Ollama 调用失败: {response.status_code} {response.text}")
data = response.json()
record_call(provider_key, model, duration=duration)
if enforce_cost and cost_controller:
cost_controller.record_usage(model or provider_key, 0, 0)
# Ollama may return `tool_calls` under message.tool_calls when tools are used.
# Return the raw response so callers can handle either OpenAI-like responses or
# Ollama's message structure with `tool_calls`.
if cache_allowed and cache_key:
cache_store.set(cache_key, data, ttl=cache_ttl_seconds)
return data
if not api_key:
raise LLMError(f"缺少 {provider_key} API Key (model={model})")
start_time = time.perf_counter()
data = _request_openai_chat(
base_url=base_url,
api_key=api_key,
model=model,
messages=messages,
temperature=temperature,
timeout=timeout,
tools=tools,
tool_choice=tool_choice,
)
duration = time.perf_counter() - start_time
usage = data.get("usage", {}) if isinstance(data, dict) else {}
prompt_tokens = usage.get("prompt_tokens") or usage.get("prompt_tokens_total")
completion_tokens = usage.get("completion_tokens") or usage.get("completion_tokens_total")
record_call(
provider_key,
model,
prompt_tokens,
completion_tokens,
duration=duration,
)
if enforce_cost and cost_controller:
prompt_count = int(prompt_tokens or 0)
completion_count = int(completion_tokens or 0)
within_limits = cost_controller.record_usage(model or provider_key, prompt_count, completion_count)
if not within_limits:
LOGGER.warning(
"LLM 成本预算已超限provider=%s model=%s",
provider_key,
model,
extra=LOG_EXTRA,
)
if cache_allowed and cache_key:
cache_store.set(cache_key, data, ttl=cache_ttl_seconds)
return data
def _normalize_response(text: str) -> str:
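    """Collapse whitespace so responses can be compared for majority voting."""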
return " ".join(text.strip().split())
def run_llm(
prompt: str,
*,
system: Optional[str] = None,
context_id: Optional[str] = None,
template_id: Optional[str] = None,
template_vars: Optional[Dict[str, Any]] = None
) -> str:
"""Execute the globally configured LLM strategy with the given prompt.
Args:
        prompt: Raw prompt text; when template_id is provided it is passed to the template as the ``prompt`` variable
system: Optional system message
context_id: Optional context ID for conversation tracking
template_id: Optional template ID to use
template_vars: Variables to use with the template
"""
# Get config and prepare context
cfg = get_config()
if context_id:
context = ContextManager.get_context(context_id)
if not context:
context = ContextManager.create_context(context_id)
else:
context = None
# Apply template if specified
applied_template_version: Optional[str] = None
if template_id:
template = TemplateRegistry.get(template_id)
if not template:
raise ValueError(f"Template {template_id} not found")
applied_template_version = TemplateRegistry.get_active_version(template_id)
vars_dict = template_vars or {}
if isinstance(prompt, str):
vars_dict["prompt"] = prompt
elif isinstance(prompt, dict):
vars_dict.update(prompt)
prompt = template.format(vars_dict)
# Add to context if tracking
if context:
if system:
context.add_message(Message(role="system", content=system))
context.add_message(Message(role="user", content=prompt))
# Execute LLM call
response = run_llm_with_config(cfg.llm, prompt, system=system)
# Update context with response
if context:
context.add_message(Message(role="assistant", content=response))
if template_id:
record_template_usage(
template_id,
version=applied_template_version,
)
return response
def _run_majority_vote(config: LLMConfig, prompt: str, system: Optional[str]) -> str:
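    """Query the primary and ensemble endpoints, then return the answer that wins a majority vote.

    Responses are compared after whitespace normalization; if no answer reaches the
    configured threshold, the first successful response is returned.
    """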
endpoints: List[LLMEndpoint] = [config.primary] + list(config.ensemble)
responses: List[Dict[str, str]] = []
failures: List[str] = []
for idx, endpoint in enumerate(endpoints, start=1):
try:
result = _call_endpoint(endpoint, prompt, system)
responses.append({
"provider": endpoint.provider,
"model": endpoint.model,
"raw": result,
"normalized": _normalize_response(result),
})
except Exception as exc: # noqa: BLE001
summary = f"{endpoint.provider}:{endpoint.model} -> {exc}"
failures.append(summary)
LOGGER.warning("LLM 调用失败:%s", summary, extra=LOG_EXTRA)
if not responses:
raise LLMError("所有 LLM 调用均失败,无法返回结果。")
threshold = max(1, config.majority_threshold)
threshold = min(threshold, len(responses))
counter = Counter(item["normalized"] for item in responses)
top_value, top_count = counter.most_common(1)[0]
if top_count >= threshold:
chosen_raw = next(item["raw"] for item in responses if item["normalized"] == top_value)
LOGGER.info(
"LLM 多模型投票通过value=%s votes=%s/%s threshold=%s",
top_value[:80],
top_count,
len(responses),
threshold,
extra=LOG_EXTRA,
)
return chosen_raw
LOGGER.info(
"LLM 多模型投票未达门槛votes=%s/%s threshold=%s,返回首个结果",
top_count,
len(responses),
threshold,
extra=LOG_EXTRA,
)
if failures:
LOGGER.warning("LLM 调用失败列表:%s", failures, extra=LOG_EXTRA)
return responses[0]["raw"]
def _run_leader_follow(config: LLMConfig, prompt: str, system: Optional[str]) -> str:
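    """Collect advisor responses from the ensemble, then ask the primary model for the final answer.

    Advisor outputs (truncated to 1200 characters each) are prepended to the original prompt;
    if every advisor fails, the primary endpoint is called with the prompt unchanged.
    """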
advisors: List[Dict[str, str]] = []
for endpoint in config.ensemble:
try:
raw = _call_endpoint(endpoint, prompt, system)
advisors.append(
{
"provider": endpoint.provider,
"model": endpoint.model or "",
"raw": raw,
}
)
except Exception as exc: # noqa: BLE001
LOGGER.warning(
"顾问模型调用失败:%s:%s -> %s",
endpoint.provider,
endpoint.model,
exc,
extra=LOG_EXTRA,
)
if not advisors:
LOGGER.info("领导者策略顾问为空,回退至主模型", extra=LOG_EXTRA)
return _call_endpoint(config.primary, prompt, system)
advisor_chunks = []
for idx, record in enumerate(advisors, start=1):
snippet = record["raw"].strip()
if len(snippet) > 1200:
snippet = snippet[:1200] + "..."
advisor_chunks.append(
f"顾问#{idx} ({record['provider']}:{record['model']}):\n{snippet}"
)
advisor_section = "\n\n".join(advisor_chunks)
leader_prompt = (
"【顾问模型意见】\n"
f"{advisor_section}\n\n"
"请在充分参考顾问模型观点的基础上,保持原始指令的输出格式进行最终回答。\n\n"
f"{prompt}"
)
LOGGER.info(
"领导者策略触发:顾问数量=%s",
len(advisors),
extra=LOG_EXTRA,
)
return _call_endpoint(config.primary, leader_prompt, system)
def run_llm_with_config(
config: LLMConfig,
prompt: str,
*,
system: Optional[str] = None,
) -> str:
"""Execute an LLM request using the provided configuration block."""
strategy = (config.strategy or "single").lower()
if strategy == "leader-follower":
strategy = "leader"
if strategy == "majority":
return _run_majority_vote(config, prompt, system)
if strategy == "leader":
return _run_leader_follow(config, prompt, system)
return _call_endpoint(config.primary, prompt, system)
def llm_config_snapshot() -> Dict[str, object]:
"""Return a sanitized snapshot of current LLM configuration for debugging."""
cfg = get_config()
settings = cfg.llm
primary = asdict(settings.primary)
if primary.get("api_key"):
primary["api_key"] = "***"
ensemble = []
for endpoint in settings.ensemble:
record = asdict(endpoint)
if record.get("api_key"):
record["api_key"] = "***"
ensemble.append(record)
return {
"strategy": settings.strategy,
"majority_threshold": settings.majority_threshold,
"primary": primary,
"ensemble": ensemble,
"providers": {
key: {
"base_url": provider.base_url,
"default_model": provider.default_model,
"enabled": provider.enabled,
}
for key, provider in cfg.llm_providers.items()
},
}