add rate limiting and response caching to LLM providers

sam 2025-10-17 10:13:27 +08:00
parent 7395c5acab
commit ae1a49f79f
8 changed files with 413 additions and 7 deletions

123
app/llm/cache.py Normal file
View File

@ -0,0 +1,123 @@
"""In-memory response cache for LLM calls."""
from __future__ import annotations
import hashlib
import json
import os
from collections import OrderedDict
from copy import deepcopy
from threading import Lock
from typing import Any, Callable, Mapping, Optional, Sequence
from time import monotonic
DEFAULT_CACHE_MAX_SIZE = int(os.getenv("LLM_CACHE_MAX_SIZE", "512") or 0)
DEFAULT_CACHE_TTL = float(os.getenv("LLM_CACHE_DEFAULT_TTL", "180") or 0.0)
_GLOBAL_CACHE: LLMResponseCache | None = None
def _normalize(obj: Any) -> Any:
if isinstance(obj, Mapping):
return {str(key): _normalize(value) for key, value in sorted(obj.items(), key=lambda item: str(item[0]))}
if isinstance(obj, (list, tuple)):
return [_normalize(item) for item in obj]
if isinstance(obj, (str, int, float, bool)) or obj is None:
return obj
return str(obj)
class LLMResponseCache:
"""Simple thread-safe LRU cache with TTL support."""
def __init__(
self,
max_size: int = DEFAULT_CACHE_MAX_SIZE,
default_ttl: float = DEFAULT_CACHE_TTL,
*,
time_func: Callable[[], float] = monotonic,
) -> None:
self._max_size = max(0, int(max_size))
self._default_ttl = max(0.0, float(default_ttl))
self._time = time_func
self._lock = Lock()
self._store: OrderedDict[str, tuple[float, Any]] = OrderedDict()
@property
def enabled(self) -> bool:
return self._max_size > 0 and self._default_ttl > 0
def get(self, key: str) -> Optional[Any]:
if not key or not self.enabled:
return None
with self._lock:
entry = self._store.get(key)
if not entry:
return None
expires_at, value = entry
if expires_at <= self._time():
self._store.pop(key, None)
return None
self._store.move_to_end(key)
return deepcopy(value)
def set(self, key: str, value: Any, *, ttl: Optional[float] = None) -> None:
if not key or not self.enabled:
return
ttl_value = self._default_ttl if ttl is None else float(ttl)
if ttl_value <= 0:
return
expires_at = self._time() + ttl_value
with self._lock:
self._store[key] = (expires_at, deepcopy(value))
self._store.move_to_end(key)
while len(self._store) > self._max_size:
self._store.popitem(last=False)
def clear(self) -> None:
with self._lock:
self._store.clear()
def llm_cache() -> LLMResponseCache:
global _GLOBAL_CACHE
if _GLOBAL_CACHE is None:
_GLOBAL_CACHE = LLMResponseCache()
return _GLOBAL_CACHE
def build_cache_key(
provider_key: str,
resolved_endpoint: Mapping[str, Any],
messages: Sequence[Mapping[str, Any]],
tools: Optional[Sequence[Mapping[str, Any]]],
tool_choice: Any,
) -> str:
payload = {
"provider": provider_key,
"model": resolved_endpoint.get("model"),
"base_url": resolved_endpoint.get("base_url"),
"temperature": resolved_endpoint.get("temperature"),
"mode": resolved_endpoint.get("mode"),
"messages": _normalize(messages),
"tools": _normalize(tools) if tools else None,
"tool_choice": _normalize(tool_choice),
}
raw = json.dumps(payload, ensure_ascii=False, sort_keys=True)
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
def is_cacheable(
resolved_endpoint: Mapping[str, Any],
messages: Sequence[Mapping[str, Any]],
tools: Optional[Sequence[Mapping[str, Any]]],
) -> bool:
if tools:
return False
if not messages:
return False
temperature = resolved_endpoint.get("temperature", 0.0)
try:
temperature_value = float(temperature)
except (TypeError, ValueError):
temperature_value = 0.0
return temperature_value <= 0.3
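
For reference, a minimal usage sketch of these helpers; the endpoint dict, message payload, and stand-in response below are illustrative only, not taken from this commit:

from app.llm.cache import build_cache_key, is_cacheable, llm_cache

resolved = {"model": "gpt-4o-mini", "base_url": "https://api.openai.com", "temperature": 0.2, "mode": "openai"}
messages = [{"role": "user", "content": "ping"}]

cache = llm_cache()
if is_cacheable(resolved, messages, tools=None):
    key = build_cache_key("openai", resolved, messages, tools=None, tool_choice=None)
    response = cache.get(key)  # deep-copied hit, or None on miss/expiry
    if response is None:
        response = {"choices": [{"message": {"content": "pong"}}]}  # stand-in for the real provider call
        cache.set(key, response, ttl=120)  # cached for 120 s; LRU entry evicted when the store is full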

View File

@ -12,21 +12,25 @@ import requests
from .context import ContextManager, Message
from .templates import TemplateRegistry
from .cost import configure_cost_limits, get_cost_controller, budget_available
from .cache import build_cache_key, is_cacheable, llm_cache
from .rate_limit import RateLimiter
from app.utils.config import (
DEFAULT_LLM_BASE_URLS,
DEFAULT_LLM_MODELS,
DEFAULT_LLM_TEMPERATURES,
DEFAULT_LLM_TIMEOUTS,
DEFAULT_LLM_MODEL_OPTIONS,
LLMConfig,
LLMEndpoint,
get_config,
)
from app.llm.metrics import record_call, record_cache_hit, record_template_usage
from app.utils.logging import get_logger
LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "llm"}
RATE_LIMITER = RateLimiter()
class LLMError(RuntimeError):
"""Raised when LLM provider returns an error response."""
@ -122,6 +126,17 @@ def resolve_endpoint(endpoint: LLMEndpoint) -> Dict[str, object]:
timeout = endpoint.timeout
prompt_template = endpoint.prompt_template
def _safe_int(value: object, fallback: int) -> int:
try:
return int(value)
except (TypeError, ValueError):
return fallback
rate_limit_per_minute = 0
rate_limit_burst = 0
cache_enabled = True
cache_ttl_seconds = 0
if provider_cfg:
if not provider_cfg.enabled:
raise LLMError(f"Provider {provider_key} 已被禁用")
@ -134,6 +149,15 @@ def resolve_endpoint(endpoint: LLMEndpoint) -> Dict[str, object]:
timeout = provider_cfg.default_timeout
prompt_template = prompt_template or (provider_cfg.prompt_template or None)
mode = provider_cfg.mode or ("ollama" if provider_key == "ollama" else "openai")
rate_limit_per_minute = max(0, _safe_int(provider_cfg.rate_limit_per_minute, 0))
rate_limit_burst = provider_cfg.rate_limit_burst
rate_limit_burst = _safe_int(rate_limit_burst, rate_limit_per_minute or 0)
if rate_limit_per_minute > 0:
rate_limit_burst = max(1, rate_limit_burst or rate_limit_per_minute)
else:
rate_limit_burst = max(0, rate_limit_burst)
cache_enabled = bool(provider_cfg.cache_enabled)
cache_ttl_seconds = max(0, _safe_int(provider_cfg.cache_ttl_seconds, 0))
else:
base_url = base_url or _default_base_url(provider_key)
model = model or _default_model(provider_key)
@ -143,6 +167,15 @@ def resolve_endpoint(endpoint: LLMEndpoint) -> Dict[str, object]:
if timeout is None:
timeout = DEFAULT_LLM_TIMEOUTS.get(provider_key, 30.0)
mode = "ollama" if provider_key == "ollama" else "openai"
defaults = DEFAULT_LLM_MODEL_OPTIONS.get(provider_key, {})
rate_limit_per_minute = max(0, _safe_int(defaults.get("rate_limit_per_minute"), 0))
rate_limit_burst = _safe_int(defaults.get("rate_limit_burst"), rate_limit_per_minute or 0)
if rate_limit_per_minute > 0:
rate_limit_burst = max(1, rate_limit_burst or rate_limit_per_minute)
else:
rate_limit_burst = max(0, rate_limit_burst)
cache_enabled = bool(defaults.get("cache_enabled", True))
cache_ttl_seconds = max(0, _safe_int(defaults.get("cache_ttl_seconds"), 0))
return {
"provider_key": provider_key,
@ -153,6 +186,10 @@ def resolve_endpoint(endpoint: LLMEndpoint) -> Dict[str, object]:
"temperature": max(0.0, min(float(temperature), 2.0)), "temperature": max(0.0, min(float(temperature), 2.0)),
"timeout": max(5.0, float(timeout)), "timeout": max(5.0, float(timeout)),
"prompt_template": prompt_template, "prompt_template": prompt_template,
"rate_limit_per_minute": rate_limit_per_minute,
"rate_limit_burst": rate_limit_burst,
"cache_enabled": cache_enabled,
"cache_ttl_seconds": cache_ttl_seconds,
}
@ -201,6 +238,38 @@ def call_endpoint_with_messages(
temperature = resolved["temperature"]
timeout = resolved["timeout"]
api_key = resolved["api_key"]
rate_limit_per_minute = max(0, int(resolved.get("rate_limit_per_minute") or 0))
rate_limit_burst = max(0, int(resolved.get("rate_limit_burst") or 0))
cache_enabled = bool(resolved.get("cache_enabled", True))
cache_ttl_seconds = max(0, int(resolved.get("cache_ttl_seconds") or 0))
if rate_limit_per_minute > 0:
if rate_limit_burst <= 0:
rate_limit_burst = rate_limit_per_minute
wait_time = RATE_LIMITER.acquire(provider_key, rate_limit_per_minute, rate_limit_burst)
if wait_time > 0:
LOGGER.debug(
"LLM 请求触发限速provider=%s wait=%.3fs",
provider_key,
wait_time,
extra=LOG_EXTRA,
)
time.sleep(wait_time)
cache_store = llm_cache()
cache_allowed = (
cache_enabled
and cache_ttl_seconds > 0
and cache_store.enabled
and is_cacheable(resolved, messages, tools)
)
cache_key: Optional[str] = None
if cache_allowed:
cache_key = build_cache_key(provider_key, resolved, messages, tools, tool_choice)
cached_payload = cache_store.get(cache_key)
if cached_payload is not None:
record_cache_hit(provider_key, model)
return cached_payload
cfg = get_config()
cost_cfg = getattr(cfg, "llm_cost", None)
@ -261,6 +330,8 @@ def call_endpoint_with_messages(
# Ollama may return `tool_calls` under message.tool_calls when tools are used.
# Return the raw response so callers can handle either OpenAI-like responses or
# Ollama's message structure with `tool_calls`.
if cache_allowed and cache_key:
cache_store.set(cache_key, data, ttl=cache_ttl_seconds)
return data
if not api_key:
@ -298,6 +369,8 @@ def call_endpoint_with_messages(
model,
extra=LOG_EXTRA,
)
if cache_allowed and cache_key:
cache_store.set(cache_key, data, ttl=cache_ttl_seconds)
return data

View File

@ -14,6 +14,7 @@ class _Metrics:
total_calls: int = 0
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
cache_hits: int = 0
provider_calls: Dict[str, int] = field(default_factory=dict)
model_calls: Dict[str, int] = field(default_factory=dict)
decisions: Deque[Dict[str, object]] = field(default_factory=lambda: deque(maxlen=500))
@ -62,6 +63,20 @@ def record_call(
_notify_listeners()
def record_cache_hit(provider: str, model: Optional[str] = None) -> None:
"""Record a cache-hit event for observability."""
normalized_provider = (provider or "unknown").lower()
normalized_model = (model or "").strip()
with _LOCK:
_METRICS.cache_hits += 1
if normalized_provider:
_METRICS.provider_calls.setdefault(normalized_provider, 0)
if normalized_model:
_METRICS.model_calls.setdefault(normalized_model, 0)
_notify_listeners()
def snapshot(reset: bool = False) -> Dict[str, object]:
"""Return a snapshot of current metrics. Optionally reset counters."""
@ -70,6 +85,7 @@ def snapshot(reset: bool = False) -> Dict[str, object]:
"total_calls": _METRICS.total_calls, "total_calls": _METRICS.total_calls,
"total_prompt_tokens": _METRICS.total_prompt_tokens, "total_prompt_tokens": _METRICS.total_prompt_tokens,
"total_completion_tokens": _METRICS.total_completion_tokens, "total_completion_tokens": _METRICS.total_completion_tokens,
"cache_hits": _METRICS.cache_hits,
"provider_calls": dict(_METRICS.provider_calls), "provider_calls": dict(_METRICS.provider_calls),
"model_calls": dict(_METRICS.model_calls), "model_calls": dict(_METRICS.model_calls),
"decision_action_counts": dict(_METRICS.decision_action_counts), "decision_action_counts": dict(_METRICS.decision_action_counts),
@ -86,6 +102,7 @@ def snapshot(reset: bool = False) -> Dict[str, object]:
_METRICS.total_calls = 0
_METRICS.total_prompt_tokens = 0
_METRICS.total_completion_tokens = 0
_METRICS.cache_hits = 0
_METRICS.provider_calls.clear()
_METRICS.model_calls.clear()
_METRICS.decision_action_counts.clear()
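
A quick way to observe the new counter; a sketch that relies only on the fields visible in this hunk, with output assuming a fresh process:

from app.llm.metrics import record_cache_hit, snapshot

record_cache_hit("openai", "gpt-4o-mini")
stats = snapshot()
print(stats["cache_hits"])      # 1
print(stats["provider_calls"])  # {'openai': 0} -- provider registered for visibility, but a hit is not counted as a call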

48
app/llm/rate_limit.py Normal file
View File

@ -0,0 +1,48 @@
"""Simple token-bucket rate limiter for LLM calls."""
from __future__ import annotations
from threading import Lock
from time import monotonic
from typing import Callable, Dict
class RateLimiter:
"""Token bucket rate limiter that returns required wait time."""
def __init__(self, monotonic_func: Callable[[], float] | None = None) -> None:
self._now = monotonic_func or monotonic
self._lock = Lock()
self._buckets: Dict[str, dict[str, float]] = {}
def acquire(self, key: str, rate_per_minute: int, burst: int) -> float:
"""Attempt to consume a token; return wait time if throttled."""
if rate_per_minute <= 0:
return 0.0
capacity = float(max(1, burst if burst > 0 else rate_per_minute))
rate = float(rate_per_minute)
now = self._now()
with self._lock:
bucket = self._buckets.get(key)
if bucket is None:
bucket = {"tokens": capacity, "capacity": capacity, "last": now, "rate": rate}
self._buckets[key] = bucket
else:
bucket["capacity"] = capacity
bucket["rate"] = rate
tokens = bucket["tokens"]
elapsed = max(0.0, now - bucket["last"])
tokens = min(capacity, tokens + elapsed * rate / 60.0)
if tokens >= 1.0:
bucket["tokens"] = tokens - 1.0
bucket["last"] = now
return 0.0
bucket["tokens"] = tokens
bucket["last"] = now
deficit = 1.0 - tokens
wait_time = deficit * 60.0 / rate
return max(wait_time, 0.0)
def reset(self) -> None:
with self._lock:
self._buckets.clear()
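
Usage sketch mirroring the client-side pattern shown earlier: when the bucket is empty the returned wait equals deficit * 60 / rate, so at 60 requests/minute one token refills per second (values here are illustrative):

import time
from app.llm.rate_limit import RateLimiter

limiter = RateLimiter()
wait = limiter.acquire("openai", rate_per_minute=60, burst=10)  # burst of 10 tokens, refilled at 1/s
if wait > 0:
    time.sleep(wait)  # same pattern call_endpoint_with_messages uses before issuing the request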

View File

@ -230,6 +230,41 @@ def render_llm_settings() -> None:
enabled_val = st.checkbox("启用", value=provider_cfg.enabled, key=enabled_key)
mode_val = st.selectbox("模式", options=["openai", "ollama"], index=0 if provider_cfg.mode == "openai" else 1, key=mode_key)
rate_key = f"provider_rate_{selected_provider}"
burst_key = f"provider_burst_{selected_provider}"
cache_enabled_key = f"provider_cache_enabled_{selected_provider}"
cache_ttl_key = f"provider_cache_ttl_{selected_provider}"
col_rate, col_burst = st.columns(2)
with col_rate:
rate_limit_val = st.number_input(
"限速 (次/分钟)",
min_value=0,
max_value=5000,
value=int(provider_cfg.rate_limit_per_minute or 0),
step=10,
key=rate_key,
help="0 表示不限制请求频率,适合本地或私有部署。",
)
with col_burst:
burst_limit_val = st.number_input(
"突发令牌数",
min_value=0,
max_value=5000,
value=int(provider_cfg.rate_limit_burst or max(1, provider_cfg.rate_limit_per_minute or 1)),
step=5,
key=burst_key,
help="控制瞬时突发的最大请求数,建议不低于限速值。",
)
cache_enabled_val = st.checkbox("启用响应缓存", value=provider_cfg.cache_enabled, key=cache_enabled_key)
cache_ttl_val = st.number_input(
"缓存有效期(秒)",
min_value=0,
max_value=3600,
value=int(provider_cfg.cache_ttl_seconds or 0),
step=30,
key=cache_ttl_key,
help="缓存相同请求的返回结果以降低成本0 表示禁用。",
)
st.markdown("可用模型:") st.markdown("可用模型:")
if provider_cfg.models: if provider_cfg.models:
st.code("\n".join(provider_cfg.models), language="text") st.code("\n".join(provider_cfg.models), language="text")
@ -267,6 +302,19 @@ def render_llm_settings() -> None:
provider_cfg.api_key = api_val.strip() or None
provider_cfg.enabled = enabled_val
provider_cfg.mode = mode_val
try:
provider_cfg.rate_limit_per_minute = max(0, int(rate_limit_val))
except (TypeError, ValueError):
provider_cfg.rate_limit_per_minute = 0
try:
provider_cfg.rate_limit_burst = max(0, int(burst_limit_val))
except (TypeError, ValueError):
provider_cfg.rate_limit_burst = provider_cfg.rate_limit_per_minute or 0
provider_cfg.cache_enabled = bool(cache_enabled_val)
try:
provider_cfg.cache_ttl_seconds = max(0, int(cache_ttl_val))
except (TypeError, ValueError):
provider_cfg.cache_ttl_seconds = 0
providers[selected_provider] = provider_cfg
cfg.llm_providers = providers
cfg.sync_runtime_llm()

View File

@ -113,24 +113,40 @@ DEFAULT_LLM_MODEL_OPTIONS: Dict[str, Dict[str, object]] = {
"base_url": "http://localhost:11434", "base_url": "http://localhost:11434",
"temperature": 0.2, "temperature": 0.2,
"timeout": 30.0, "timeout": 30.0,
"rate_limit_per_minute": 120,
"rate_limit_burst": 40,
"cache_enabled": True,
"cache_ttl_seconds": 120,
},
"openai": {
"models": ["gpt-4o-mini", "gpt-4.1-mini", "gpt-3.5-turbo"],
"base_url": "https://api.openai.com",
"temperature": 0.2,
"timeout": 30.0,
"rate_limit_per_minute": 60,
"rate_limit_burst": 30,
"cache_enabled": True,
"cache_ttl_seconds": 180,
},
"deepseek": {
"models": ["deepseek-chat", "deepseek-coder"],
"base_url": "https://api.deepseek.com",
"temperature": 0.2,
"timeout": 45.0,
"rate_limit_per_minute": 45,
"rate_limit_burst": 20,
"cache_enabled": True,
"cache_ttl_seconds": 240,
},
"wenxin": {
"models": ["ERNIE-Speed", "ERNIE-Bot"],
"base_url": "https://aip.baidubce.com",
"temperature": 0.2,
"timeout": 60.0,
"rate_limit_per_minute": 30,
"rate_limit_burst": 15,
"cache_enabled": True,
"cache_ttl_seconds": 300,
},
}
@ -173,6 +189,10 @@ class LLMProvider:
prompt_template: str = "" prompt_template: str = ""
enabled: bool = True enabled: bool = True
mode: str = "openai" # openai 或 ollama mode: str = "openai" # openai 或 ollama
rate_limit_per_minute: int = 60
rate_limit_burst: int = 30
cache_enabled: bool = True
cache_ttl_seconds: int = 180
def to_dict(self) -> Dict[str, object]:
return {
@ -186,6 +206,10 @@ class LLMProvider:
"prompt_template": self.prompt_template, "prompt_template": self.prompt_template,
"enabled": self.enabled, "enabled": self.enabled,
"mode": self.mode, "mode": self.mode,
"rate_limit_per_minute": self.rate_limit_per_minute,
"rate_limit_burst": self.rate_limit_burst,
"cache_enabled": self.cache_enabled,
"cache_ttl_seconds": self.cache_ttl_seconds,
}
@ -291,6 +315,10 @@ def _default_llm_providers() -> Dict[str, LLMProvider]:
default_temperature=float(meta.get("temperature", DEFAULT_LLM_TEMPERATURES.get(provider, 0.2))),
default_timeout=float(meta.get("timeout", DEFAULT_LLM_TIMEOUTS.get(provider, 30.0))),
mode=mode,
rate_limit_per_minute=int(meta.get("rate_limit_per_minute", 60) or 0),
rate_limit_burst=int(meta.get("rate_limit_burst", meta.get("rate_limit_per_minute", 60)) or 0),
cache_enabled=bool(meta.get("cache_enabled", True)),
cache_ttl_seconds=int(meta.get("cache_ttl_seconds", 180) or 0),
)
return providers
@ -619,6 +647,7 @@ def _load_from_file(cfg: AppConfig) -> None:
for key, data in providers_payload.items():
if not isinstance(data, dict):
continue
provider_key = str(key).lower()
models_raw = data.get("models")
if isinstance(models_raw, str):
models = [item.strip() for item in models_raw.split(',') if item.strip()]
@ -626,8 +655,23 @@ def _load_from_file(cfg: AppConfig) -> None:
models = [str(item).strip() for item in models_raw if str(item).strip()]
else:
models = []
defaults = DEFAULT_LLM_MODEL_OPTIONS.get(provider_key, {})
def _safe_int(value: object, fallback: int) -> int:
try:
return int(value)
except (TypeError, ValueError):
return fallback
rate_limit_per_minute = _safe_int(data.get("rate_limit_per_minute"), int(defaults.get("rate_limit_per_minute", 60) or 0))
rate_limit_burst = _safe_int(
data.get("rate_limit_burst"),
int(defaults.get("rate_limit_burst", defaults.get("rate_limit_per_minute", rate_limit_per_minute)) or rate_limit_per_minute or 0),
)
cache_ttl_seconds = _safe_int(
data.get("cache_ttl_seconds"),
int(defaults.get("cache_ttl_seconds", 180) or 0),
)
provider = LLMProvider(
key=provider_key,
title=str(data.get("title") or ""),
base_url=str(data.get("base_url") or ""),
api_key=data.get("api_key"),
@ -637,7 +681,11 @@ def _load_from_file(cfg: AppConfig) -> None:
default_timeout=float(data.get("default_timeout", 30.0)),
prompt_template=str(data.get("prompt_template") or ""),
enabled=bool(data.get("enabled", True)),
mode=str(data.get("mode") or ("ollama" if provider_key == "ollama" else "openai")),
rate_limit_per_minute=max(0, rate_limit_per_minute),
rate_limit_burst=max(1, rate_limit_burst) if rate_limit_per_minute > 0 else max(0, rate_limit_burst),
cache_enabled=bool(data.get("cache_enabled", defaults.get("cache_enabled", True))),
cache_ttl_seconds=max(0, cache_ttl_seconds),
)
providers[provider.key] = provider
if providers:
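
For context, a sketch of the per-provider payload this loader accepts; the key names come from the code above, while the concrete values and the on-disk file format are assumptions:

providers_payload = {
    "deepseek": {
        "title": "DeepSeek",
        "base_url": "https://api.deepseek.com",
        "models": "deepseek-chat, deepseek-coder",  # a comma-separated string or a list both parse
        "enabled": True,
        "rate_limit_per_minute": 45,   # 0 disables throttling entirely
        "rate_limit_burst": 20,        # clamped to >= 1 whenever a rate limit is set
        "cache_enabled": True,
        "cache_ttl_seconds": 240,      # 0 disables response caching for this provider
    },
}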

View File

@ -29,10 +29,10 @@
| Work item | Status | Notes |
| --- | --- | --- |
| Provider & function architecture | ✅ | Provider management, function-calling fallback, and retry strategies have been consolidated. |
| Prompt template governance | ✅ | LLM settings now offer template version governance, metadata maintenance, and invocation monitoring; cost/effectiveness analysis is still to be added. |
| Department telemetry visualization | ✅ | The settings page supports filtering, exporting, and dynamic quota tuning for department and global telemetry. |
| Multi-round logical game framework | ✅ | Moderator briefings, prediction alignment, risk review, and conflict rounds are all wired together and live. |
| LLM stability improvements | ✅ | Provider-level rate limiting, response caching, and cost guards now work in concert to support smooth degradation. |
## UI and Monitoring

49
tests/test_llm_runtime.py Normal file
View File

@ -0,0 +1,49 @@
"""Tests for LLM runtime helpers such as rate limiting and caching."""
from __future__ import annotations
import pytest
from app.llm.cache import LLMResponseCache
from app.llm.rate_limit import RateLimiter
def test_rate_limiter_returns_wait_time() -> None:
"""Ensure limiter enforces configured throughput."""
current = [0.0]
def fake_time() -> float:
return current[0]
limiter = RateLimiter(monotonic_func=fake_time)
assert limiter.acquire("openai", rate_per_minute=2, burst=1) == pytest.approx(0.0)
delay = limiter.acquire("openai", rate_per_minute=2, burst=1)
assert delay == pytest.approx(30.0, rel=1e-3)
current[0] += 30.0
assert limiter.acquire("openai", rate_per_minute=2, burst=1) == pytest.approx(0.0)
def test_llm_response_cache_ttl_and_lru() -> None:
"""Validate cache expiration and eviction semantics."""
current = [0.0]
def fake_time() -> float:
return current[0]
cache = LLMResponseCache(max_size=2, default_ttl=10, time_func=fake_time)
cache.set("key1", {"value": 1})
assert cache.get("key1") == {"value": 1}
current[0] += 11
assert cache.get("key1") is None
cache.set("key1", {"value": 1})
cache.set("key2", {"value": 2})
assert cache.get("key1") == {"value": 1}
cache.set("key3", {"value": 3})
assert cache.get("key2") is None
assert cache.get("key1") == {"value": 1}
assert cache.get("key3") == {"value": 3}