"""In-memory response cache for LLM calls."""

from __future__ import annotations

import hashlib
import json
import os
from collections import OrderedDict
from copy import deepcopy
from threading import Lock
from time import monotonic
from typing import Any, Callable, Mapping, Optional, Sequence

# Limits are read from the environment; an empty or zero value disables caching.
DEFAULT_CACHE_MAX_SIZE = int(os.getenv("LLM_CACHE_MAX_SIZE", "512") or 0)
DEFAULT_CACHE_TTL = float(os.getenv("LLM_CACHE_DEFAULT_TTL", "180") or 0.0)

_GLOBAL_CACHE: Optional["LLMResponseCache"] = None


def _normalize(obj: Any) -> Any:
    if isinstance(obj, Mapping):
        return {
            str(key): _normalize(value)
            for key, value in sorted(obj.items(), key=lambda item: str(item[0]))
        }
    if isinstance(obj, (list, tuple)):
        return [_normalize(item) for item in obj]
    if isinstance(obj, (str, int, float, bool)) or obj is None:
        return obj
    return str(obj)
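
# Example (illustrative): mappings are key-sorted, sequences become lists, and
# anything that is not JSON-friendly is stringified.
#
#     _normalize({"b": 2, "a": (1, object)})
#     -> {'a': [1, "<class 'object'>"], 'b': 2}

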
class LLMResponseCache:
    """Simple thread-safe LRU cache with TTL support."""

    def __init__(
        self,
        max_size: int = DEFAULT_CACHE_MAX_SIZE,
        default_ttl: float = DEFAULT_CACHE_TTL,
        *,
        time_func: Callable[[], float] = monotonic,
    ) -> None:
        self._max_size = max(0, int(max_size))
        self._default_ttl = max(0.0, float(default_ttl))
        self._time = time_func
        self._lock = Lock()
        # Maps key -> (expires_at, value); insertion order doubles as LRU order.
        self._store: OrderedDict[str, tuple[float, Any]] = OrderedDict()

    @property
    def enabled(self) -> bool:
        return self._max_size > 0 and self._default_ttl > 0

    def get(self, key: str) -> Optional[Any]:
        if not key or not self.enabled:
            return None
        with self._lock:
            entry = self._store.get(key)
            if not entry:
                return None
            expires_at, value = entry
            if expires_at <= self._time():
                # Expired entries are evicted lazily on lookup.
                self._store.pop(key, None)
                return None
            self._store.move_to_end(key)
            # Return a copy so callers cannot mutate the cached value in place.
            return deepcopy(value)

    def set(self, key: str, value: Any, *, ttl: Optional[float] = None) -> None:
        if not key or not self.enabled:
            return
        ttl_value = self._default_ttl if ttl is None else float(ttl)
        if ttl_value <= 0:
            return
        expires_at = self._time() + ttl_value
        with self._lock:
            self._store[key] = (expires_at, deepcopy(value))
            self._store.move_to_end(key)
            # Evict least-recently-used entries once the size limit is exceeded.
            while len(self._store) > self._max_size:
                self._store.popitem(last=False)

    def clear(self) -> None:
        with self._lock:
            self._store.clear()
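
# Example (illustrative): TTL expiry with an injected clock. The fake clock
# below is hypothetical and only shows how `time_func` can be used in tests;
# it is not part of this module.
#
#     clock = iter([0.0, 1.0, 400.0]).__next__
#     cache = LLMResponseCache(max_size=8, default_ttl=180, time_func=clock)
#     cache.set("k", {"answer": 42})   # stored at t=0, expires at t=180
#     cache.get("k")                   # t=1   -> {'answer': 42}
#     cache.get("k")                   # t=400 -> None (expired, entry dropped)

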
def llm_cache() -> LLMResponseCache:
    # Lazily create one process-wide cache instance.
    global _GLOBAL_CACHE
    if _GLOBAL_CACHE is None:
        _GLOBAL_CACHE = LLMResponseCache()
    return _GLOBAL_CACHE
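
# Typical call-site pattern (sketch; `call_provider` is a hypothetical stand-in
# for whatever performs the actual LLM request in the host application):
#
#     cache = llm_cache()
#     key = build_cache_key(provider_key, endpoint, messages, tools, tool_choice)
#     response = cache.get(key) if is_cacheable(endpoint, messages, tools) else None
#     if response is None:
#         response = call_provider(endpoint, messages, tools, tool_choice)
#         if is_cacheable(endpoint, messages, tools):
#             cache.set(key, response)

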
def build_cache_key(
    provider_key: str,
    resolved_endpoint: Mapping[str, Any],
    messages: Sequence[Mapping[str, Any]],
    tools: Optional[Sequence[Mapping[str, Any]]],
    tool_choice: Any,
) -> str:
    # Canonicalise every request-defining field so logically identical calls
    # hash to the same key.
    payload = {
        "provider": provider_key,
        "model": resolved_endpoint.get("model"),
        "base_url": resolved_endpoint.get("base_url"),
        "temperature": resolved_endpoint.get("temperature"),
        "mode": resolved_endpoint.get("mode"),
        "messages": _normalize(messages),
        "tools": _normalize(tools) if tools else None,
        "tool_choice": _normalize(tool_choice),
    }
    raw = json.dumps(payload, ensure_ascii=False, sort_keys=True)
    return hashlib.sha256(raw.encode("utf-8")).hexdigest()
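
# Example (illustrative): because `_normalize` sorts mapping keys and the JSON
# dump uses `sort_keys=True`, input key order does not change the cache key.
#
#     a = build_cache_key("some-provider", {"model": "m", "temperature": 0},
#                         [{"role": "user", "content": "hi"}], None, None)
#     b = build_cache_key("some-provider", {"temperature": 0, "model": "m"},
#                         [{"content": "hi", "role": "user"}], None, None)
#     assert a == b

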
def is_cacheable(
    resolved_endpoint: Mapping[str, Any],
    messages: Sequence[Mapping[str, Any]],
    tools: Optional[Sequence[Mapping[str, Any]]],
) -> bool:
    # Requests that use tools, or that carry no messages, are never cached.
    if tools:
        return False
    if not messages:
        return False
    temperature = resolved_endpoint.get("temperature", 0.0)
    try:
        temperature_value = float(temperature)
    except (TypeError, ValueError):
        # Unparsable temperatures are treated as 0.0.
        temperature_value = 0.0
    return temperature_value <= 0.3
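

if __name__ == "__main__":
    # Minimal smoke test (illustrative only; the message payload and cached
    # response below are made up for demonstration, and the round trip assumes
    # the environment defaults leave the cache enabled).
    endpoint = {"model": "demo-model", "temperature": 0.0, "mode": "chat"}
    messages = [{"role": "user", "content": "ping"}]

    cache = llm_cache()
    key = build_cache_key("demo-provider", endpoint, messages, None, None)

    assert is_cacheable(endpoint, messages, None)
    cache.set(key, {"content": "pong"})
    print(cache.get(key))  # -> {'content': 'pong'} while the entry is fresh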