add rate limiting and response caching to LLM providers
parent 7395c5acab
commit ae1a49f79f

app/llm/cache.py (new file, 123 lines)
@@ -0,0 +1,123 @@
"""In-memory response cache for LLM calls."""
from __future__ import annotations

import hashlib
import json
import os
from collections import OrderedDict
from copy import deepcopy
from threading import Lock
from time import monotonic
from typing import Any, Callable, Mapping, Optional, Sequence

DEFAULT_CACHE_MAX_SIZE = int(os.getenv("LLM_CACHE_MAX_SIZE", "512") or 0)
DEFAULT_CACHE_TTL = float(os.getenv("LLM_CACHE_DEFAULT_TTL", "180") or 0.0)
_GLOBAL_CACHE: "LLMResponseCache" | None = None


def _normalize(obj: Any) -> Any:
    if isinstance(obj, Mapping):
        return {str(key): _normalize(value) for key, value in sorted(obj.items(), key=lambda item: str(item[0]))}
    if isinstance(obj, (list, tuple)):
        return [_normalize(item) for item in obj]
    if isinstance(obj, (str, int, float, bool)) or obj is None:
        return obj
    return str(obj)


class LLMResponseCache:
    """Simple thread-safe LRU cache with TTL support."""

    def __init__(
        self,
        max_size: int = DEFAULT_CACHE_MAX_SIZE,
        default_ttl: float = DEFAULT_CACHE_TTL,
        *,
        time_func: Callable[[], float] = monotonic,
    ) -> None:
        self._max_size = max(0, int(max_size))
        self._default_ttl = max(0.0, float(default_ttl))
        self._time = time_func
        self._lock = Lock()
        self._store: OrderedDict[str, tuple[float, Any]] = OrderedDict()

    @property
    def enabled(self) -> bool:
        return self._max_size > 0 and self._default_ttl > 0

    def get(self, key: str) -> Optional[Any]:
        if not key or not self.enabled:
            return None
        with self._lock:
            entry = self._store.get(key)
            if not entry:
                return None
            expires_at, value = entry
            if expires_at <= self._time():
                self._store.pop(key, None)
                return None
            self._store.move_to_end(key)
            return deepcopy(value)

    def set(self, key: str, value: Any, *, ttl: Optional[float] = None) -> None:
        if not key or not self.enabled:
            return
        ttl_value = self._default_ttl if ttl is None else float(ttl)
        if ttl_value <= 0:
            return
        expires_at = self._time() + ttl_value
        with self._lock:
            self._store[key] = (expires_at, deepcopy(value))
            self._store.move_to_end(key)
            while len(self._store) > self._max_size:
                self._store.popitem(last=False)

    def clear(self) -> None:
        with self._lock:
            self._store.clear()


def llm_cache() -> LLMResponseCache:
    global _GLOBAL_CACHE
    if _GLOBAL_CACHE is None:
        _GLOBAL_CACHE = LLMResponseCache()
    return _GLOBAL_CACHE


def build_cache_key(
    provider_key: str,
    resolved_endpoint: Mapping[str, Any],
    messages: Sequence[Mapping[str, Any]],
    tools: Optional[Sequence[Mapping[str, Any]]],
    tool_choice: Any,
) -> str:
    payload = {
        "provider": provider_key,
        "model": resolved_endpoint.get("model"),
        "base_url": resolved_endpoint.get("base_url"),
        "temperature": resolved_endpoint.get("temperature"),
        "mode": resolved_endpoint.get("mode"),
        "messages": _normalize(messages),
        "tools": _normalize(tools) if tools else None,
        "tool_choice": _normalize(tool_choice),
    }
    raw = json.dumps(payload, ensure_ascii=False, sort_keys=True)
    return hashlib.sha256(raw.encode("utf-8")).hexdigest()


def is_cacheable(
    resolved_endpoint: Mapping[str, Any],
    messages: Sequence[Mapping[str, Any]],
    tools: Optional[Sequence[Mapping[str, Any]]],
) -> bool:
    if tools:
        return False
    if not messages:
        return False
    temperature = resolved_endpoint.get("temperature", 0.0)
    try:
        temperature_value = float(temperature)
    except (TypeError, ValueError):
        temperature_value = 0.0
    return temperature_value <= 0.3
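A minimal usage sketch of the module above (not part of the commit); the endpoint values, messages, and stand-in response payload are illustrative only:

from app.llm.cache import build_cache_key, is_cacheable, llm_cache

# Hypothetical resolved endpoint and chat payload, for illustration only.
resolved = {"model": "gpt-4o-mini", "base_url": "https://api.openai.com", "temperature": 0.2, "mode": "openai"}
messages = [{"role": "user", "content": "ping"}]

cache = llm_cache()
if is_cacheable(resolved, messages, tools=None):
    key = build_cache_key("openai", resolved, messages, None, None)
    payload = cache.get(key)
    if payload is None:
        payload = {"choices": [{"message": {"content": "pong"}}]}  # stand-in for a real provider response
        cache.set(key, payload, ttl=180)  # identical requests within the TTL get a deep copy of this payload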
@@ -12,21 +12,25 @@ import requests
from .context import ContextManager, Message
from .templates import TemplateRegistry
from .cost import configure_cost_limits, get_cost_controller, budget_available
from .cache import build_cache_key, is_cacheable, llm_cache
from .rate_limit import RateLimiter

from app.utils.config import (
    DEFAULT_LLM_BASE_URLS,
    DEFAULT_LLM_MODELS,
    DEFAULT_LLM_TEMPERATURES,
    DEFAULT_LLM_TIMEOUTS,
    DEFAULT_LLM_MODEL_OPTIONS,
    LLMConfig,
    LLMEndpoint,
    get_config,
)
from app.llm.metrics import record_call, record_template_usage
from app.llm.metrics import record_call, record_cache_hit, record_template_usage
from app.utils.logging import get_logger

LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "llm"}
RATE_LIMITER = RateLimiter()


class LLMError(RuntimeError):
    """Raised when LLM provider returns an error response."""
@@ -122,6 +126,17 @@ def resolve_endpoint(endpoint: LLMEndpoint) -> Dict[str, object]:
    timeout = endpoint.timeout
    prompt_template = endpoint.prompt_template

    def _safe_int(value: object, fallback: int) -> int:
        try:
            return int(value)
        except (TypeError, ValueError):
            return fallback

    rate_limit_per_minute = 0
    rate_limit_burst = 0
    cache_enabled = True
    cache_ttl_seconds = 0

    if provider_cfg:
        if not provider_cfg.enabled:
            raise LLMError(f"Provider {provider_key} is disabled")
@@ -134,6 +149,15 @@ def resolve_endpoint(endpoint: LLMEndpoint) -> Dict[str, object]:
            timeout = provider_cfg.default_timeout
        prompt_template = prompt_template or (provider_cfg.prompt_template or None)
        mode = provider_cfg.mode or ("ollama" if provider_key == "ollama" else "openai")
        rate_limit_per_minute = max(0, _safe_int(provider_cfg.rate_limit_per_minute, 0))
        rate_limit_burst = provider_cfg.rate_limit_burst
        rate_limit_burst = _safe_int(rate_limit_burst, rate_limit_per_minute or 0)
        if rate_limit_per_minute > 0:
            rate_limit_burst = max(1, rate_limit_burst or rate_limit_per_minute)
        else:
            rate_limit_burst = max(0, rate_limit_burst)
        cache_enabled = bool(provider_cfg.cache_enabled)
        cache_ttl_seconds = max(0, _safe_int(provider_cfg.cache_ttl_seconds, 0))
    else:
        base_url = base_url or _default_base_url(provider_key)
        model = model or _default_model(provider_key)
@@ -143,6 +167,15 @@ def resolve_endpoint(endpoint: LLMEndpoint) -> Dict[str, object]:
        if timeout is None:
            timeout = DEFAULT_LLM_TIMEOUTS.get(provider_key, 30.0)
        mode = "ollama" if provider_key == "ollama" else "openai"
        defaults = DEFAULT_LLM_MODEL_OPTIONS.get(provider_key, {})
        rate_limit_per_minute = max(0, _safe_int(defaults.get("rate_limit_per_minute"), 0))
        rate_limit_burst = _safe_int(defaults.get("rate_limit_burst"), rate_limit_per_minute or 0)
        if rate_limit_per_minute > 0:
            rate_limit_burst = max(1, rate_limit_burst or rate_limit_per_minute)
        else:
            rate_limit_burst = max(0, rate_limit_burst)
        cache_enabled = bool(defaults.get("cache_enabled", True))
        cache_ttl_seconds = max(0, _safe_int(defaults.get("cache_ttl_seconds"), 0))

    return {
        "provider_key": provider_key,
@@ -153,6 +186,10 @@ def resolve_endpoint(endpoint: LLMEndpoint) -> Dict[str, object]:
        "temperature": max(0.0, min(float(temperature), 2.0)),
        "timeout": max(5.0, float(timeout)),
        "prompt_template": prompt_template,
        "rate_limit_per_minute": rate_limit_per_minute,
        "rate_limit_burst": rate_limit_burst,
        "cache_enabled": cache_enabled,
        "cache_ttl_seconds": cache_ttl_seconds,
    }
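For reference, the mapping returned by resolve_endpoint now carries four additional keys alongside the existing ones; a sketch with illustrative values, not taken from any real configuration:

resolved = {
    "provider_key": "openai",
    "model": "gpt-4o-mini",
    "base_url": "https://api.openai.com",
    "api_key": "<redacted>",         # illustrative placeholder
    "mode": "openai",
    "temperature": 0.2,
    "timeout": 30.0,
    "prompt_template": None,
    "rate_limit_per_minute": 60,     # 0 disables throttling
    "rate_limit_burst": 30,          # clamped to at least 1 whenever throttling is active
    "cache_enabled": True,
    "cache_ttl_seconds": 180,        # 0 disables caching for this provider
}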
@@ -201,6 +238,38 @@ def call_endpoint_with_messages(
    temperature = resolved["temperature"]
    timeout = resolved["timeout"]
    api_key = resolved["api_key"]
    rate_limit_per_minute = max(0, int(resolved.get("rate_limit_per_minute") or 0))
    rate_limit_burst = max(0, int(resolved.get("rate_limit_burst") or 0))
    cache_enabled = bool(resolved.get("cache_enabled", True))
    cache_ttl_seconds = max(0, int(resolved.get("cache_ttl_seconds") or 0))

    if rate_limit_per_minute > 0:
        if rate_limit_burst <= 0:
            rate_limit_burst = rate_limit_per_minute
        wait_time = RATE_LIMITER.acquire(provider_key, rate_limit_per_minute, rate_limit_burst)
        if wait_time > 0:
            LOGGER.debug(
                "LLM request throttled: provider=%s wait=%.3fs",
                provider_key,
                wait_time,
                extra=LOG_EXTRA,
            )
            time.sleep(wait_time)

    cache_store = llm_cache()
    cache_allowed = (
        cache_enabled
        and cache_ttl_seconds > 0
        and cache_store.enabled
        and is_cacheable(resolved, messages, tools)
    )
    cache_key: Optional[str] = None
    if cache_allowed:
        cache_key = build_cache_key(provider_key, resolved, messages, tools, tool_choice)
        cached_payload = cache_store.get(cache_key)
        if cached_payload is not None:
            record_cache_hit(provider_key, model)
            return cached_payload

    cfg = get_config()
    cost_cfg = getattr(cfg, "llm_cost", None)
@@ -261,6 +330,8 @@ def call_endpoint_with_messages(
        # Ollama may return `tool_calls` under message.tool_calls when tools are used.
        # Return the raw response so callers can handle either OpenAI-like responses or
        # Ollama's message structure with `tool_calls`.
        if cache_allowed and cache_key:
            cache_store.set(cache_key, data, ttl=cache_ttl_seconds)
        return data

    if not api_key:
@@ -298,6 +369,8 @@ def call_endpoint_with_messages(
        model,
        extra=LOG_EXTRA,
    )
    if cache_allowed and cache_key:
        cache_store.set(cache_key, data, ttl=cache_ttl_seconds)
    return data
app/llm/metrics.py
@@ -14,6 +14,7 @@ class _Metrics:
    total_calls: int = 0
    total_prompt_tokens: int = 0
    total_completion_tokens: int = 0
    cache_hits: int = 0
    provider_calls: Dict[str, int] = field(default_factory=dict)
    model_calls: Dict[str, int] = field(default_factory=dict)
    decisions: Deque[Dict[str, object]] = field(default_factory=lambda: deque(maxlen=500))
@@ -62,6 +63,20 @@ def record_call(
    _notify_listeners()


def record_cache_hit(provider: str, model: Optional[str] = None) -> None:
    """Record a cache-hit event for observability."""

    normalized_provider = (provider or "unknown").lower()
    normalized_model = (model or "").strip()
    with _LOCK:
        _METRICS.cache_hits += 1
        if normalized_provider:
            # Ensure the provider appears in the breakdown without inflating its call count.
            _METRICS.provider_calls.setdefault(normalized_provider, _METRICS.provider_calls.get(normalized_provider, 0))
        if normalized_model:
            _METRICS.model_calls.setdefault(normalized_model, _METRICS.model_calls.get(normalized_model, 0))
    _notify_listeners()


def snapshot(reset: bool = False) -> Dict[str, object]:
    """Return a snapshot of current metrics. Optionally reset counters."""
@@ -70,6 +85,7 @@ def snapshot(reset: bool = False) -> Dict[str, object]:
        "total_calls": _METRICS.total_calls,
        "total_prompt_tokens": _METRICS.total_prompt_tokens,
        "total_completion_tokens": _METRICS.total_completion_tokens,
        "cache_hits": _METRICS.cache_hits,
        "provider_calls": dict(_METRICS.provider_calls),
        "model_calls": dict(_METRICS.model_calls),
        "decision_action_counts": dict(_METRICS.decision_action_counts),
@@ -86,6 +102,7 @@ def snapshot(reset: bool = False) -> Dict[str, object]:
        _METRICS.total_calls = 0
        _METRICS.total_prompt_tokens = 0
        _METRICS.total_completion_tokens = 0
        _METRICS.cache_hits = 0
        _METRICS.provider_calls.clear()
        _METRICS.model_calls.clear()
        _METRICS.decision_action_counts.clear()
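A small sketch (not part of the commit) of how the new counter can be read back through snapshot(); the hit-rate arithmetic is illustrative:

from app.llm import metrics

metrics.record_cache_hit("openai", "gpt-4o-mini")
stats = metrics.snapshot()
lookups = stats["total_calls"] + stats["cache_hits"]  # requests actually issued plus requests served from cache
hit_rate = stats["cache_hits"] / lookups if lookups else 0.0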
app/llm/rate_limit.py (new file, 48 lines)
@@ -0,0 +1,48 @@
"""Simple token-bucket rate limiter for LLM calls."""
from __future__ import annotations

from threading import Lock
from time import monotonic
from typing import Callable, Dict


class RateLimiter:
    """Token bucket rate limiter that returns required wait time."""

    def __init__(self, monotonic_func: Callable[[], float] | None = None) -> None:
        self._now = monotonic_func or monotonic
        self._lock = Lock()
        self._buckets: Dict[str, dict[str, float]] = {}

    def acquire(self, key: str, rate_per_minute: int, burst: int) -> float:
        """Attempt to consume a token; return wait time if throttled."""

        if rate_per_minute <= 0:
            return 0.0
        capacity = float(max(1, burst if burst > 0 else rate_per_minute))
        rate = float(rate_per_minute)
        now = self._now()
        with self._lock:
            bucket = self._buckets.get(key)
            if bucket is None:
                bucket = {"tokens": capacity, "capacity": capacity, "last": now, "rate": rate}
                self._buckets[key] = bucket
            else:
                bucket["capacity"] = capacity
                bucket["rate"] = rate
            tokens = bucket["tokens"]
            elapsed = max(0.0, now - bucket["last"])
            tokens = min(capacity, tokens + elapsed * rate / 60.0)
            if tokens >= 1.0:
                bucket["tokens"] = tokens - 1.0
                bucket["last"] = now
                return 0.0
            bucket["tokens"] = tokens
            bucket["last"] = now
            deficit = 1.0 - tokens
            wait_time = deficit * 60.0 / rate
            return max(wait_time, 0.0)

    def reset(self) -> None:
        with self._lock:
            self._buckets.clear()
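A usage sketch of the limiter above; the provider name and limits are illustrative. With rate_per_minute=60 the bucket refills at one token per second, so a caller that has drained its burst waits roughly one second per request:

import time

from app.llm.rate_limit import RateLimiter

limiter = RateLimiter()

def throttled_call(provider: str = "openai") -> None:
    # Illustrative limits: 60 requests/minute with a burst of 30.
    wait = limiter.acquire(provider, rate_per_minute=60, burst=30)
    if wait > 0:
        time.sleep(wait)  # mirrors how call_endpoint_with_messages sleeps before issuing the request
    # ... issue the actual provider request here ...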
@@ -230,6 +230,41 @@ def render_llm_settings() -> None:
    enabled_val = st.checkbox("Enabled", value=provider_cfg.enabled, key=enabled_key)
    mode_val = st.selectbox("Mode", options=["openai", "ollama"], index=0 if provider_cfg.mode == "openai" else 1, key=mode_key)
    rate_key = f"provider_rate_{selected_provider}"
    burst_key = f"provider_burst_{selected_provider}"
    cache_enabled_key = f"provider_cache_enabled_{selected_provider}"
    cache_ttl_key = f"provider_cache_ttl_{selected_provider}"
    col_rate, col_burst = st.columns(2)
    with col_rate:
        rate_limit_val = st.number_input(
            "Rate limit (requests/min)",
            min_value=0,
            max_value=5000,
            value=int(provider_cfg.rate_limit_per_minute or 0),
            step=10,
            key=rate_key,
            help="0 disables request throttling; suitable for local or private deployments.",
        )
    with col_burst:
        burst_limit_val = st.number_input(
            "Burst tokens",
            min_value=0,
            max_value=5000,
            value=int(provider_cfg.rate_limit_burst or max(1, provider_cfg.rate_limit_per_minute or 1)),
            step=5,
            key=burst_key,
            help="Maximum number of requests allowed in a burst; recommended to be no lower than the per-minute limit.",
        )
    cache_enabled_val = st.checkbox("Enable response caching", value=provider_cfg.cache_enabled, key=cache_enabled_key)
    cache_ttl_val = st.number_input(
        "Cache TTL (seconds)",
        min_value=0,
        max_value=3600,
        value=int(provider_cfg.cache_ttl_seconds or 0),
        step=30,
        key=cache_ttl_key,
        help="Cache responses to identical requests to reduce cost; 0 disables caching.",
    )
    st.markdown("Available models:")
    if provider_cfg.models:
        st.code("\n".join(provider_cfg.models), language="text")
@@ -267,6 +302,19 @@ def render_llm_settings() -> None:
    provider_cfg.api_key = api_val.strip() or None
    provider_cfg.enabled = enabled_val
    provider_cfg.mode = mode_val
    try:
        provider_cfg.rate_limit_per_minute = max(0, int(rate_limit_val))
    except (TypeError, ValueError):
        provider_cfg.rate_limit_per_minute = 0
    try:
        provider_cfg.rate_limit_burst = max(0, int(burst_limit_val))
    except (TypeError, ValueError):
        provider_cfg.rate_limit_burst = provider_cfg.rate_limit_per_minute or 0
    provider_cfg.cache_enabled = bool(cache_enabled_val)
    try:
        provider_cfg.cache_ttl_seconds = max(0, int(cache_ttl_val))
    except (TypeError, ValueError):
        provider_cfg.cache_ttl_seconds = 0
    providers[selected_provider] = provider_cfg
    cfg.llm_providers = providers
    cfg.sync_runtime_llm()
app/utils/config.py
@@ -113,24 +113,40 @@ DEFAULT_LLM_MODEL_OPTIONS: Dict[str, Dict[str, object]] = {
        "base_url": "http://localhost:11434",
        "temperature": 0.2,
        "timeout": 30.0,
        "rate_limit_per_minute": 120,
        "rate_limit_burst": 40,
        "cache_enabled": True,
        "cache_ttl_seconds": 120,
    },
    "openai": {
        "models": ["gpt-4o-mini", "gpt-4.1-mini", "gpt-3.5-turbo"],
        "base_url": "https://api.openai.com",
        "temperature": 0.2,
        "timeout": 30.0,
        "rate_limit_per_minute": 60,
        "rate_limit_burst": 30,
        "cache_enabled": True,
        "cache_ttl_seconds": 180,
    },
    "deepseek": {
        "models": ["deepseek-chat", "deepseek-coder"],
        "base_url": "https://api.deepseek.com",
        "temperature": 0.2,
        "timeout": 45.0,
        "rate_limit_per_minute": 45,
        "rate_limit_burst": 20,
        "cache_enabled": True,
        "cache_ttl_seconds": 240,
    },
    "wenxin": {
        "models": ["ERNIE-Speed", "ERNIE-Bot"],
        "base_url": "https://aip.baidubce.com",
        "temperature": 0.2,
        "timeout": 60.0,
        "rate_limit_per_minute": 30,
        "rate_limit_burst": 15,
        "cache_enabled": True,
        "cache_ttl_seconds": 300,
    },
}
@@ -173,6 +189,10 @@ class LLMProvider:
    prompt_template: str = ""
    enabled: bool = True
    mode: str = "openai"  # openai or ollama
    rate_limit_per_minute: int = 60
    rate_limit_burst: int = 30
    cache_enabled: bool = True
    cache_ttl_seconds: int = 180

    def to_dict(self) -> Dict[str, object]:
        return {
@@ -186,6 +206,10 @@ class LLMProvider:
            "prompt_template": self.prompt_template,
            "enabled": self.enabled,
            "mode": self.mode,
            "rate_limit_per_minute": self.rate_limit_per_minute,
            "rate_limit_burst": self.rate_limit_burst,
            "cache_enabled": self.cache_enabled,
            "cache_ttl_seconds": self.cache_ttl_seconds,
        }
@@ -291,6 +315,10 @@ def _default_llm_providers() -> Dict[str, LLMProvider]:
            default_temperature=float(meta.get("temperature", DEFAULT_LLM_TEMPERATURES.get(provider, 0.2))),
            default_timeout=float(meta.get("timeout", DEFAULT_LLM_TIMEOUTS.get(provider, 30.0))),
            mode=mode,
            rate_limit_per_minute=int(meta.get("rate_limit_per_minute", 60) or 0),
            rate_limit_burst=int(meta.get("rate_limit_burst", meta.get("rate_limit_per_minute", 60)) or 0),
            cache_enabled=bool(meta.get("cache_enabled", True)),
            cache_ttl_seconds=int(meta.get("cache_ttl_seconds", 180) or 0),
        )
    return providers
@@ -619,6 +647,7 @@ def _load_from_file(cfg: AppConfig) -> None:
    for key, data in providers_payload.items():
        if not isinstance(data, dict):
            continue
        provider_key = str(key).lower()
        models_raw = data.get("models")
        if isinstance(models_raw, str):
            models = [item.strip() for item in models_raw.split(',') if item.strip()]
@@ -626,8 +655,23 @@ def _load_from_file(cfg: AppConfig) -> None:
            models = [str(item).strip() for item in models_raw if str(item).strip()]
        else:
            models = []
        defaults = DEFAULT_LLM_MODEL_OPTIONS.get(provider_key, {})

        def _safe_int(value: object, fallback: int) -> int:
            try:
                return int(value)
            except (TypeError, ValueError):
                return fallback

        rate_limit_per_minute = _safe_int(data.get("rate_limit_per_minute"), int(defaults.get("rate_limit_per_minute", 60) or 0))
        rate_limit_burst = _safe_int(
            data.get("rate_limit_burst"),
            int(defaults.get("rate_limit_burst", defaults.get("rate_limit_per_minute", rate_limit_per_minute)) or rate_limit_per_minute or 0),
        )
        cache_ttl_seconds = _safe_int(
            data.get("cache_ttl_seconds"),
            int(defaults.get("cache_ttl_seconds", 180) or 0),
        )
        provider = LLMProvider(
            key=str(key).lower(),
            key=provider_key,
            title=str(data.get("title") or ""),
            base_url=str(data.get("base_url") or ""),
            api_key=data.get("api_key"),
@@ -637,7 +681,11 @@ def _load_from_file(cfg: AppConfig) -> None:
            default_timeout=float(data.get("default_timeout", 30.0)),
            prompt_template=str(data.get("prompt_template") or ""),
            enabled=bool(data.get("enabled", True)),
            mode=str(data.get("mode") or ("ollama" if str(key).lower() == "ollama" else "openai")),
            mode=str(data.get("mode") or ("ollama" if provider_key == "ollama" else "openai")),
            rate_limit_per_minute=max(0, rate_limit_per_minute),
            rate_limit_burst=max(1, rate_limit_burst) if rate_limit_per_minute > 0 else max(0, rate_limit_burst),
            cache_enabled=bool(data.get("cache_enabled", defaults.get("cache_enabled", True))),
            cache_ttl_seconds=max(0, cache_ttl_seconds),
        )
        providers[provider.key] = provider
    if providers:
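As read by _load_from_file above, a persisted provider entry might now look like the following dict; the values are illustrative, and any missing keys fall back to DEFAULT_LLM_MODEL_OPTIONS:

providers_payload = {
    "deepseek": {
        "title": "DeepSeek",
        "base_url": "https://api.deepseek.com",
        "models": ["deepseek-chat"],
        "enabled": True,
        "mode": "openai",
        "rate_limit_per_minute": 45,
        "rate_limit_burst": 20,
        "cache_enabled": True,
        "cache_ttl_seconds": 240,
    },
}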
@@ -29,10 +29,10 @@
| Work item | Status | Notes |
| --- | --- | --- |
| Provider and function architecture | ✅ | Provider management, function-calling fallback, and retry strategies have converged. |
| Prompt template governance | 🔄 | LLM settings add template version governance and usage monitoring; cost/effectiveness data to follow. |
| Department telemetry visualization | 🔄 | LLM settings add a telemetry panel with paginated viewing and export of department and global telemetry. |
| Multi-round logical game framework | 🔄 | Added moderator briefing, prediction alignment, and conflict review rounds; belief-revision strategy still being refined. |
| LLM stability improvements | ⏳ | Ongoing optimization of rate limiting, fallback, cost control, and caching strategies. |
| Prompt template governance | ✅ | LLM settings provide template version governance, metadata maintenance, and call monitoring; cost/effectiveness analysis still pending. |
| Department telemetry visualization | ✅ | The settings page supports department/global telemetry filtering, export, and dynamic quota adjustment. |
| Multi-round logical game framework | ✅ | Moderator briefing, prediction alignment, risk review, and conflict rounds are all wired together and live. |
| LLM stability improvements | ✅ | Provider-level rate limiting, response caching, and the cost guard work in concert to support graceful degradation. |

## UI and monitoring
tests/test_llm_runtime.py (new file, 49 lines)
@@ -0,0 +1,49 @@
"""Tests for LLM runtime helpers such as rate limiting and caching."""
from __future__ import annotations

import pytest

from app.llm.cache import LLMResponseCache
from app.llm.rate_limit import RateLimiter


def test_rate_limiter_returns_wait_time() -> None:
    """Ensure limiter enforces configured throughput."""

    current = [0.0]

    def fake_time() -> float:
        return current[0]

    limiter = RateLimiter(monotonic_func=fake_time)

    assert limiter.acquire("openai", rate_per_minute=2, burst=1) == pytest.approx(0.0)
    delay = limiter.acquire("openai", rate_per_minute=2, burst=1)
    assert delay == pytest.approx(30.0, rel=1e-3)
    current[0] += 30.0
    assert limiter.acquire("openai", rate_per_minute=2, burst=1) == pytest.approx(0.0)


def test_llm_response_cache_ttl_and_lru() -> None:
    """Validate cache expiration and eviction semantics."""

    current = [0.0]

    def fake_time() -> float:
        return current[0]

    cache = LLMResponseCache(max_size=2, default_ttl=10, time_func=fake_time)

    cache.set("key1", {"value": 1})
    assert cache.get("key1") == {"value": 1}

    current[0] += 11
    assert cache.get("key1") is None

    cache.set("key1", {"value": 1})
    cache.set("key2", {"value": 2})
    assert cache.get("key1") == {"value": 1}
    cache.set("key3", {"value": 3})
    assert cache.get("key2") is None
    assert cache.get("key1") == {"value": 1}
    assert cache.get("key3") == {"value": 3}
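A further test along the same lines could pin down the cache-key and cacheability rules; a sketch, not part of this commit:

from app.llm.cache import build_cache_key, is_cacheable


def test_cache_key_is_stable_and_tools_disable_caching() -> None:
    """build_cache_key should be deterministic; tools and high temperature should bypass the cache."""

    resolved = {"model": "gpt-4o-mini", "base_url": "https://api.openai.com", "temperature": 0.2, "mode": "openai"}
    messages = [{"role": "user", "content": "hello"}]

    # Identical inputs must hash to the same key.
    assert build_cache_key("openai", resolved, messages, None, None) == build_cache_key("openai", resolved, messages, None, None)
    assert is_cacheable(resolved, messages, None)
    assert not is_cacheable(resolved, messages, [{"type": "function"}])
    assert not is_cacheable({"temperature": 0.9}, messages, None)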