From b7283c859f048b437b95ee5cb67357eae0565efc Mon Sep 17 00:00:00 2001 From: sam Date: Sun, 5 Oct 2025 20:13:02 +0800 Subject: [PATCH] update --- app/agents/departments.py | 5 ++ app/llm/client.py | 52 +++++++++++- app/llm/cost.py | 75 ++++++++++++++--- app/llm/metrics.py | 16 ++++ app/llm/prompts.py | 50 +++++++++-- app/llm/templates.py | 163 +++++++++++++++++++++++++++++++++--- app/utils/config.py | 111 +++++++++++++++++++++++- tests/test_llm_cost.py | 42 ++++++++-- tests/test_llm_templates.py | 25 ++++-- 9 files changed, 493 insertions(+), 46 deletions(-) diff --git a/app/agents/departments.py b/app/agents/departments.py index 3d2cf78..22e671b 100644 --- a/app/agents/departments.py +++ b/app/agents/departments.py @@ -119,6 +119,10 @@ class DepartmentAgent: if system_prompt: messages.append({"role": "system", "content": system_prompt}) prompt_body = department_prompt(self.settings, mutable_context) + template_meta = {} + raw_templates = mutable_context.raw.get("template_meta") if isinstance(mutable_context.raw, dict) else None + if isinstance(raw_templates, dict): + template_meta = dict(raw_templates.get(self.settings.code, {})) prompt_checksum = hashlib.sha1(prompt_body.encode("utf-8")).hexdigest() prompt_preview = prompt_body[:240] messages.append({"role": "user", "content": prompt_body}) @@ -330,6 +334,7 @@ class DepartmentAgent: "instruction": self.settings.prompt, "system": system_prompt, }, + "template": template_meta, "messages_exchanged": len(messages), "supplement_rounds": len(tool_call_records), } diff --git a/app/llm/client.py b/app/llm/client.py index 5f2c0da..9f43ad7 100644 --- a/app/llm/client.py +++ b/app/llm/client.py @@ -3,6 +3,7 @@ from __future__ import annotations import json from collections import Counter +import time from dataclasses import asdict from typing import Any, Dict, Iterable, List, Optional @@ -10,6 +11,7 @@ import requests from .context import ContextManager, Message from .templates import TemplateRegistry +from .cost import configure_cost_limits, get_cost_controller, budget_available from app.utils.config import ( DEFAULT_LLM_BASE_URLS, @@ -200,6 +202,26 @@ def call_endpoint_with_messages( timeout = resolved["timeout"] api_key = resolved["api_key"] + cfg = get_config() + cost_cfg = getattr(cfg, "llm_cost", None) + enforce_cost = False + cost_controller = None + if cost_cfg and getattr(cost_cfg, "enabled", False): + try: + limits = cost_cfg.to_cost_limits() + except Exception as exc: # noqa: BLE001 + LOGGER.warning( + "成本控制配置解析失败,将忽略限制: %s", + exc, + extra=LOG_EXTRA, + ) + else: + configure_cost_limits(limits) + enforce_cost = True + if not budget_available(): + raise LLMError("LLM 调用预算已耗尽,请稍后重试。") + cost_controller = get_cost_controller() + LOGGER.info( "触发 LLM 请求:provider=%s model=%s base=%s", provider_key, @@ -217,18 +239,24 @@ def call_endpoint_with_messages( "stream": False, "options": {"temperature": temperature}, } + start_time = time.perf_counter() response = requests.post( f"{base_url.rstrip('/')}/api/chat", json=payload, timeout=timeout, ) + duration = time.perf_counter() - start_time if response.status_code != 200: raise LLMError(f"Ollama 调用失败: {response.status_code} {response.text}") - record_call(provider_key, model) - return response.json() + data = response.json() + record_call(provider_key, model, duration=duration) + if enforce_cost and cost_controller: + cost_controller.record_usage(model or provider_key, 0, 0) + return data if not api_key: raise LLMError(f"缺少 {provider_key} API Key (model={model})") + start_time = time.perf_counter() data = _request_openai_chat( base_url=base_url, api_key=api_key, @@ -239,10 +267,28 @@ def call_endpoint_with_messages( tools=tools, tool_choice=tool_choice, ) + duration = time.perf_counter() - start_time usage = data.get("usage", {}) if isinstance(data, dict) else {} prompt_tokens = usage.get("prompt_tokens") or usage.get("prompt_tokens_total") completion_tokens = usage.get("completion_tokens") or usage.get("completion_tokens_total") - record_call(provider_key, model, prompt_tokens, completion_tokens) + record_call( + provider_key, + model, + prompt_tokens, + completion_tokens, + duration=duration, + ) + if enforce_cost and cost_controller: + prompt_count = int(prompt_tokens or 0) + completion_count = int(completion_tokens or 0) + within_limits = cost_controller.record_usage(model or provider_key, prompt_count, completion_count) + if not within_limits: + LOGGER.warning( + "LLM 成本预算已超限:provider=%s model=%s", + provider_key, + model, + extra=LOG_EXTRA, + ) return data diff --git a/app/llm/cost.py b/app/llm/cost.py index 4de4000..99a19aa 100644 --- a/app/llm/cost.py +++ b/app/llm/cost.py @@ -90,10 +90,16 @@ class CostController: return True - def record_usage(self, model: str, prompt_tokens: int, - completion_tokens: int) -> None: - """记录模型使用情况.""" + def record_usage( + self, + model: str, + prompt_tokens: int, + completion_tokens: int, + ) -> bool: + """记录模型使用情况,并返回是否仍在预算范围内.""" + cost = self._calculate_cost(model, prompt_tokens, completion_tokens) + within_limits = self._check_budget_limits(model, prompt_tokens, completion_tokens) timestamp = time.time() usage = { @@ -108,10 +114,20 @@ class CostController: self._usage["hourly"].append(usage) self._usage["daily"].append(usage) self._usage["monthly"].append(usage) - + # 定期清理过期数据 self._cleanup_old_usage(timestamp) + if not within_limits: + LOGGER.warning( + "Cost limit exceeded after recording usage - model: %s cost=$%.4f", + model, + cost, + extra=LOG_EXTRA, + ) + + return within_limits + def get_current_costs(self) -> Dict[str, float]: """获取当前时段的成本统计.""" with self._usage_lock: @@ -157,6 +173,16 @@ class CostController: for model, count in model_calls.items() } + def is_budget_available(self) -> bool: + """判断当前预算是否允许继续调用LLM.""" + + costs = self.get_current_costs() + return ( + costs["hourly"] < self.limits.hourly_budget and + costs["daily"] < self.limits.daily_budget and + costs["monthly"] < self.limits.monthly_budget + ) + def _calculate_cost(self, model: str, prompt_tokens: int, completion_tokens: int) -> float: """计算使用成本.""" @@ -226,13 +252,42 @@ class CostController: self._last_cleanup = current_time -# 全局实例 -_controller = CostController() +# 全局实例管理 +_CONTROLLER_LOCK = threading.Lock() +_GLOBAL_CONTROLLER = CostController() +_LAST_LIMITS: Optional[CostLimits] = None + def get_controller() -> CostController: - """获取全局CostController实例.""" - return _controller + """向后兼容的全局 CostController 访问方法.""" + + return _GLOBAL_CONTROLLER + + +def get_cost_controller() -> CostController: + """显式返回全局 CostController 实例.""" + + return _GLOBAL_CONTROLLER + def set_cost_limits(limits: CostLimits) -> None: - """设置全局成本限制.""" - _controller.limits = limits + """设置全局成本限制(兼容旧接口)。""" + + configure_cost_limits(limits) + + +def configure_cost_limits(limits: CostLimits) -> None: + """设置全局成本限制,如果变更才更新。""" + + global _LAST_LIMITS + with _CONTROLLER_LOCK: + if _LAST_LIMITS is None or limits != _LAST_LIMITS: + _GLOBAL_CONTROLLER.limits = limits + _LAST_LIMITS = limits + + +def budget_available() -> bool: + """判断全局预算是否仍可用。""" + + with _CONTROLLER_LOCK: + return _GLOBAL_CONTROLLER.is_budget_available() diff --git a/app/llm/metrics.py b/app/llm/metrics.py index b4efab4..8f22e4e 100644 --- a/app/llm/metrics.py +++ b/app/llm/metrics.py @@ -17,6 +17,8 @@ class _Metrics: model_calls: Dict[str, int] = field(default_factory=dict) decisions: Deque[Dict[str, object]] = field(default_factory=lambda: deque(maxlen=500)) decision_action_counts: Dict[str, int] = field(default_factory=dict) + total_latency: float = 0.0 + latency_samples: Deque[float] = field(default_factory=lambda: deque(maxlen=200)) _METRICS = _Metrics() @@ -31,6 +33,8 @@ def record_call( model: Optional[str] = None, prompt_tokens: Optional[int] = None, completion_tokens: Optional[int] = None, + *, + duration: Optional[float] = None, ) -> None: """Record a single LLM API invocation.""" @@ -49,6 +53,10 @@ def record_call( _METRICS.total_prompt_tokens += int(prompt_tokens) if completion_tokens: _METRICS.total_completion_tokens += int(completion_tokens) + if duration is not None: + duration_value = max(0.0, float(duration)) + _METRICS.total_latency += duration_value + _METRICS.latency_samples.append(duration_value) _notify_listeners() @@ -64,6 +72,12 @@ def snapshot(reset: bool = False) -> Dict[str, object]: "model_calls": dict(_METRICS.model_calls), "decision_action_counts": dict(_METRICS.decision_action_counts), "recent_decisions": list(_METRICS.decisions), + "average_latency": ( + _METRICS.total_latency / _METRICS.total_calls + if _METRICS.total_calls + else 0.0 + ), + "latency_samples": list(_METRICS.latency_samples), } if reset: _METRICS.total_calls = 0 @@ -73,6 +87,8 @@ def snapshot(reset: bool = False) -> Dict[str, object]: _METRICS.model_calls.clear() _METRICS.decision_action_counts.clear() _METRICS.decisions.clear() + _METRICS.total_latency = 0.0 + _METRICS.latency_samples.clear() return data diff --git a/app/llm/prompts.py b/app/llm/prompts.py index 8b7c4f9..0bddbe9 100644 --- a/app/llm/prompts.py +++ b/app/llm/prompts.py @@ -1,10 +1,13 @@ """Prompt templates for natural language outputs.""" from __future__ import annotations +import logging from typing import Dict, TYPE_CHECKING from .templates import TemplateRegistry +LOGGER = logging.getLogger(__name__) + if TYPE_CHECKING: # pragma: no cover from app.utils.config import DepartmentSettings from app.agents.departments import DepartmentContext @@ -35,11 +38,49 @@ def department_prompt( role_description = settings.description.strip() role_instruction = settings.prompt.strip() - # Determine template ID based on department settings - template_id = f"{settings.code.lower()}_dept" - if not TemplateRegistry.get(template_id): + # Determine template ID and version + template_id = (getattr(settings, "prompt_template_id", None) or f"{settings.code.lower()}_dept").strip() + requested_version = getattr(settings, "prompt_template_version", None) + original_requested_version = requested_version + template = TemplateRegistry.get(template_id, version=requested_version) + applied_version = requested_version if template and requested_version else None + + if not template: + if requested_version: + LOGGER.warning( + "Template %s version %s not found, falling back to active version", + template_id, + requested_version, + ) + template = TemplateRegistry.get(template_id) + applied_version = TemplateRegistry.get_active_version(template_id) + + if not template: + LOGGER.warning( + "Template %s unavailable, using department_base fallback", + template_id, + ) template_id = "department_base" - + template = TemplateRegistry.get(template_id) + requested_version = None + applied_version = TemplateRegistry.get_active_version(template_id) + + if not template: + raise ValueError("No prompt template available for department prompts") + + if applied_version is None: + applied_version = TemplateRegistry.get_active_version(template_id) + template_meta = { + "template_id": template_id, + "requested_version": original_requested_version, + "applied_version": applied_version, + } + + raw_container = getattr(context, "raw", None) + if isinstance(raw_container, dict): + meta_store = raw_container.setdefault("template_meta", {}) + meta_store[settings.code] = template_meta + # Prepare template variables template_vars = { "title": settings.title, @@ -54,5 +95,4 @@ def department_prompt( } # Get template and format prompt - template = TemplateRegistry.get(template_id) return template.format(template_vars) diff --git a/app/llm/templates.py b/app/llm/templates.py index b4bb536..b008e97 100644 --- a/app/llm/templates.py +++ b/app/llm/templates.py @@ -4,7 +4,10 @@ from __future__ import annotations import json import logging from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, TYPE_CHECKING + +if TYPE_CHECKING: # pragma: no cover + from .version import TemplateVersionManager @dataclass @@ -69,31 +72,139 @@ class PromptTemplate: class TemplateRegistry: - """Global registry for prompt templates.""" + """Global registry for prompt templates with version awareness.""" _templates: Dict[str, PromptTemplate] = {} + _version_manager: Optional["TemplateVersionManager"] = None + _default_version_label: str = "1.0.0" @classmethod - def register(cls, template: PromptTemplate) -> None: - """Register a new template.""" + def _manager(cls) -> "TemplateVersionManager": + if cls._version_manager is None: + from .version import TemplateVersionManager # Local import to avoid circular dependency + + cls._version_manager = TemplateVersionManager() + return cls._version_manager + + @classmethod + def register( + cls, + template: PromptTemplate, + *, + version: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + activate: bool = False, + ) -> None: + """Register a new template and optionally version it.""" + errors = template.validate() if errors: raise ValueError(f"Invalid template {template.id}: {'; '.join(errors)}") + cls._templates[template.id] = template + manager = cls._manager() + existing_versions = manager.list_versions(template.id) + resolved_metadata: Dict[str, Any] = dict(metadata or {}) + if version: + manager.add_version( + template, + version, + metadata=resolved_metadata or None, + activate=activate, + ) + elif not existing_versions: + if "source" not in resolved_metadata: + resolved_metadata["source"] = "default" + manager.add_version( + template, + cls._default_version_label, + metadata=resolved_metadata, + activate=True, + ) + @classmethod - def get(cls, template_id: str) -> Optional[PromptTemplate]: - """Get template by ID.""" + def register_version( + cls, + template_id: str, + *, + version: str, + template: Optional[PromptTemplate] = None, + metadata: Optional[Dict[str, Any]] = None, + activate: bool = False, + ) -> None: + """Register an additional version for an existing template.""" + + base_template = template or cls._templates.get(template_id) + if not base_template: + raise ValueError(f"Template {template_id} not found for version registration") + + manager = cls._manager() + manager.add_version( + base_template, + version, + metadata=metadata, + activate=activate, + ) + + @classmethod + def activate_version(cls, template_id: str, version: str) -> None: + """Activate a specific template version.""" + + manager = cls._manager() + manager.activate_version(template_id, version) + + @classmethod + def get( + cls, + template_id: str, + *, + version: Optional[str] = None, + ) -> Optional[PromptTemplate]: + """Get template by ID and optional version.""" + + manager = cls._manager() + if version: + stored = manager.get_version(template_id, version) + if stored: + return stored.template + + active = manager.get_active_version(template_id) + if active: + return active.template + return cls._templates.get(template_id) + @classmethod + def get_active_version(cls, template_id: str) -> Optional[str]: + """Return the currently active version label for a template.""" + + manager = cls._manager() + active = manager.get_active_version(template_id) + return active.version if active else None + @classmethod def list(cls) -> List[PromptTemplate]: - """List all registered templates.""" - return list(cls._templates.values()) + """List all registered templates (active versions preferred).""" + + collected: Dict[str, PromptTemplate] = {} + manager = cls._manager() + for template_id, template in cls._templates.items(): + active = manager.get_active_version(template_id) + collected[template_id] = active.template if active else template + return list(collected.values()) + + @classmethod + def list_versions(cls, template_id: str) -> List[str]: + """List available version labels for a template.""" + + manager = cls._manager() + return [ver.version for ver in manager.list_versions(template_id)] @classmethod def load_from_json(cls, json_str: str) -> None: """Load templates from JSON string.""" + try: data = json.loads(json_str) except json.JSONDecodeError as e: @@ -103,6 +214,13 @@ class TemplateRegistry: raise ValueError("JSON root must be an object") for template_id, cfg in data.items(): + if not isinstance(cfg, dict): + raise ValueError(f"Template {template_id} configuration must be an object") + version = cfg.get("version") + metadata = cfg.get("metadata") + if metadata is not None and not isinstance(metadata, dict): + raise ValueError(f"Template {template_id} metadata must be an object") + activate = bool(cfg.get("activate", False)) template = PromptTemplate( id=template_id, name=cfg.get("name", template_id), @@ -113,12 +231,21 @@ class TemplateRegistry: required_context=cfg.get("required_context", []), validation_rules=cfg.get("validation_rules", []) ) - cls.register(template) + cls.register( + template, + version=version, + metadata=metadata, + activate=activate, + ) @classmethod - def clear(cls) -> None: - """Clear all registered templates.""" + def clear(cls, *, reload_defaults: bool = False) -> None: + """Clear all registered templates and optionally reload defaults.""" + cls._templates.clear() + cls._version_manager = None + if reload_defaults: + register_default_templates() # Default template definitions @@ -234,7 +361,19 @@ def register_default_templates() -> None: "validation_rules": cfg.get("validation_rules", []) } try: - TemplateRegistry.register(PromptTemplate(**template_config)) + template = PromptTemplate(**template_config) + version_label = str( + cfg.get("version") or TemplateRegistry._default_version_label + ) + metadata_raw = cfg.get("metadata") + metadata = dict(metadata_raw) if isinstance(metadata_raw, dict) else {} + metadata.setdefault("source", "defaults") + TemplateRegistry.register( + template, + version=version_label, + metadata=metadata, + activate=cfg.get("activate", True), + ) except ValueError as e: logging.warning(f"Failed to register template {template_id}: {e}") diff --git a/app/utils/config.py b/app/utils/config.py index 06e7c83..e7cc265 100644 --- a/app/utils/config.py +++ b/app/utils/config.py @@ -16,6 +16,13 @@ def _default_root() -> Path: return Path(__file__).resolve().parents[2] / "app" / "data" +def _safe_float(value: object, fallback: float) -> float: + try: + return float(value) + except (TypeError, ValueError): + return fallback + + @dataclass class DataPaths: """Holds filesystem locations for persistent artifacts.""" @@ -198,8 +205,61 @@ class LLMEndpoint: self.provider = (self.provider or "ollama").lower() if self.temperature is not None: self.temperature = float(self.temperature) - if self.timeout is not None: - self.timeout = float(self.timeout) + + +@dataclass +class LLMCostSettings: + """Configurable budgets and weights for LLM cost control.""" + + enabled: bool = False + hourly_budget: float = 5.0 + daily_budget: float = 50.0 + monthly_budget: float = 500.0 + model_weights: Dict[str, float] = field(default_factory=dict) + + def update_from_dict(self, data: Mapping[str, object]) -> None: + if "enabled" in data: + self.enabled = bool(data.get("enabled")) + if "hourly_budget" in data: + self.hourly_budget = _safe_float(data.get("hourly_budget"), self.hourly_budget) + if "daily_budget" in data: + self.daily_budget = _safe_float(data.get("daily_budget"), self.daily_budget) + if "monthly_budget" in data: + self.monthly_budget = _safe_float(data.get("monthly_budget"), self.monthly_budget) + weights = data.get("model_weights") if isinstance(data, Mapping) else None + if isinstance(weights, Mapping): + normalized: Dict[str, float] = {} + for key, value in weights.items(): + try: + normalized[str(key)] = float(value) + except (TypeError, ValueError): + continue + if normalized: + self.model_weights = normalized + + def to_cost_limits(self): + """Convert into runtime `CostLimits` descriptor.""" + + from app.llm.cost import CostLimits # Imported lazily to avoid cycles + + weights: Dict[str, float] = {} + for key, value in (self.model_weights or {}).items(): + try: + weights[str(key)] = float(value) + except (TypeError, ValueError): + continue + return CostLimits( + hourly_budget=float(self.hourly_budget), + daily_budget=float(self.daily_budget), + monthly_budget=float(self.monthly_budget), + model_weights=weights, + ) + + @classmethod + def from_dict(cls, data: Mapping[str, object]) -> "LLMCostSettings": + inst = cls() + inst.update_from_dict(data) + return inst @dataclass @@ -241,6 +301,8 @@ class DepartmentSettings: data_scope: List[str] = field(default_factory=list) prompt: str = "" llm: LLMConfig = field(default_factory=LLMConfig) + prompt_template_id: Optional[str] = None + prompt_template_version: Optional[str] = None def _default_departments() -> Dict[str, DepartmentSettings]: @@ -327,6 +389,7 @@ def _default_departments() -> Dict[str, DepartmentSettings]: description=item.get("description", ""), data_scope=list(item.get("data_scope", [])), prompt=item.get("prompt", ""), + prompt_template_id=f"{item['code']}_dept", ) for item in presets } @@ -355,6 +418,7 @@ class AppConfig: data_update_interval: int = 7 # 数据更新间隔(天) llm_providers: Dict[str, LLMProvider] = field(default_factory=_default_llm_providers) llm: LLMConfig = field(default_factory=LLMConfig) + llm_cost: LLMCostSettings = field(default_factory=LLMCostSettings) departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments) portfolio: PortfolioSettings = field(default_factory=PortfolioSettings) @@ -463,6 +527,10 @@ def _load_from_file(cfg: AppConfig) -> None: ) cfg.portfolio = updated_portfolio + cost_payload = payload.get("llm_cost") + if isinstance(cost_payload, dict): + cfg.llm_cost.update_from_dict(cost_payload) + legacy_profiles: Dict[str, Dict[str, object]] = {} legacy_routes: Dict[str, Dict[str, object]] = {} @@ -574,6 +642,7 @@ def _load_from_file(cfg: AppConfig) -> None: for code, data in departments_payload.items(): if not isinstance(data, dict): continue + current_setting = cfg.departments.get(code) title = data.get("title") or code description = data.get("description") or "" weight = float(data.get("weight", 1.0)) @@ -606,6 +675,33 @@ def _load_from_file(cfg: AppConfig) -> None: if isinstance(majority_raw, int) and majority_raw > 0: llm_cfg.majority_threshold = majority_raw resolved_cfg = llm_cfg + template_id_raw = data.get("prompt_template_id") + if isinstance(template_id_raw, str): + template_id_candidate = template_id_raw.strip() + elif template_id_raw is not None: + template_id_candidate = str(template_id_raw).strip() + else: + template_id_candidate = "" + if template_id_candidate: + template_id = template_id_candidate + elif current_setting and current_setting.prompt_template_id: + template_id = current_setting.prompt_template_id + else: + template_id = f"{code}_dept" + + template_version_raw = data.get("prompt_template_version") + if isinstance(template_version_raw, str): + template_version_candidate = template_version_raw.strip() + elif template_version_raw is not None: + template_version_candidate = str(template_version_raw).strip() + else: + template_version_candidate = "" + if template_version_candidate: + template_version = template_version_candidate + elif current_setting: + template_version = current_setting.prompt_template_version + else: + template_version = None new_departments[code] = DepartmentSettings( code=code, title=title, @@ -614,6 +710,8 @@ def _load_from_file(cfg: AppConfig) -> None: data_scope=data_scope, prompt=prompt_text, llm=resolved_cfg, + prompt_template_id=template_id, + prompt_template_version=template_version, ) if new_departments: cfg.departments = new_departments @@ -648,6 +746,13 @@ def save_config(cfg: AppConfig | None = None) -> None: "primary": _endpoint_to_dict(cfg.llm.primary), "ensemble": [_endpoint_to_dict(ep) for ep in cfg.llm.ensemble], }, + "llm_cost": { + "enabled": cfg.llm_cost.enabled, + "hourly_budget": cfg.llm_cost.hourly_budget, + "daily_budget": cfg.llm_cost.daily_budget, + "monthly_budget": cfg.llm_cost.monthly_budget, + "model_weights": cfg.llm_cost.model_weights, + }, "llm_providers": { key: provider.to_dict() for key, provider in cfg.llm_providers.items() @@ -659,6 +764,8 @@ def save_config(cfg: AppConfig | None = None) -> None: "weight": dept.weight, "data_scope": list(dept.data_scope), "prompt": dept.prompt, + "prompt_template_id": dept.prompt_template_id, + "prompt_template_version": dept.prompt_template_version, "llm": { "strategy": dept.llm.strategy if dept.llm.strategy in ALLOWED_LLM_STRATEGIES else "single", "majority_threshold": dept.llm.majority_threshold, diff --git a/tests/test_llm_cost.py b/tests/test_llm_cost.py index ced1e63..0a0cd9f 100644 --- a/tests/test_llm_cost.py +++ b/tests/test_llm_cost.py @@ -1,8 +1,14 @@ """Test cases for LLM cost control system.""" -import pytest -from datetime import datetime +import time -from app.llm.cost import CostLimits, ModelCosts, CostController +from app.llm.cost import ( + CostLimits, + ModelCosts, + CostController, + configure_cost_limits, + get_cost_controller, + budget_available, +) def test_cost_limits(): @@ -53,11 +59,11 @@ def test_cost_controller(): assert controller.can_use_model("gpt-4", 1000, 500) # Then record the usage - controller.record_usage("gpt-4", 1000, 500) # About $0.09 + assert controller.record_usage("gpt-4", 1000, 500) is True # About $0.09 # Record usage for second model to maintain weight balance assert controller.can_use_model("gpt-3.5-turbo", 1000, 500) - controller.record_usage("gpt-3.5-turbo", 1000, 500) + assert controller.record_usage("gpt-3.5-turbo", 1000, 500) # Verify usage tracking costs = controller.get_current_costs() @@ -85,10 +91,10 @@ def test_cost_controller_history(): # Record one usage of each model assert controller.can_use_model("gpt-4", 1000, 500) - controller.record_usage("gpt-4", 1000, 500) - + assert controller.record_usage("gpt-4", 1000, 500) + assert controller.can_use_model("gpt-3.5-turbo", 1000, 500) - controller.record_usage("gpt-3.5-turbo", 1000, 500) + assert controller.record_usage("gpt-3.5-turbo", 1000, 500) # Check usage tracking costs = controller.get_current_costs() @@ -98,3 +104,23 @@ def test_cost_controller_history(): distribution = controller.get_model_distribution() assert abs(distribution["gpt-4"] - 0.5) < 0.1 assert abs(distribution["gpt-3.5-turbo"] - 0.5) < 0.1 + + +def test_global_controller_budget_toggle(): + """Ensure global controller respects configured limits and budget flag.""" + limits = CostLimits(hourly_budget=0.05, daily_budget=0.1, monthly_budget=1.0, model_weights={}) + configure_cost_limits(limits) + controller = get_cost_controller() + + assert budget_available() is True + within = controller.record_usage("gpt-4", 1000, 1000) + assert within is False # deliberately exceed tiny budget + assert budget_available() is False + + # Reset controller state to avoid cross-test contamination + with controller._usage_lock: # type: ignore[attr-defined] + for bucket in controller._usage.values(): # type: ignore[attr-defined] + bucket.clear() + controller._last_cleanup = time.time() # type: ignore[attr-defined] + + configure_cost_limits(CostLimits.default()) diff --git a/tests/test_llm_templates.py b/tests/test_llm_templates.py index c7949cf..a9a9042 100644 --- a/tests/test_llm_templates.py +++ b/tests/test_llm_templates.py @@ -69,7 +69,7 @@ def test_prompt_template_format(): # Valid context result = template.format({"name": "World"}) - assert result == "Hello Wor..." + assert result == "Hello W..." # Missing required context with pytest.raises(ValueError) as exc: @@ -77,8 +77,15 @@ def test_prompt_template_format(): assert "Missing required context" in str(exc.value) # Missing variable + template_no_required = PromptTemplate( + id="test2", + name="Test Template", + description="A test template", + template="Hello {name}!", + variables=["name"], + ) with pytest.raises(ValueError) as exc: - template.format({"wrong": "value"}) + template_no_required.format({"wrong": "value"}) assert "Missing template variable" in str(exc.value) @@ -95,7 +102,8 @@ def test_template_registry(): variables=["name"] ) TemplateRegistry.register(template) - assert TemplateRegistry.get("test") == template + assert TemplateRegistry.get("test") is not None + assert TemplateRegistry.get_active_version("test") == "1.0.0" # Register invalid template invalid = PromptTemplate( @@ -121,12 +129,17 @@ def test_template_registry(): "name": "JSON Test", "description": "Test template from JSON", "template": "Hello {name}!", - "variables": ["name"] + "variables": ["name"], + "version": "2024.10", + "metadata": {"author": "qa"}, + "activate": true } } ''' TemplateRegistry.load_from_json(json_str) - assert TemplateRegistry.get("json_test") is not None + loaded = TemplateRegistry.get("json_test") + assert loaded is not None + assert TemplateRegistry.get_active_version("json_test") == "2024.10" # Invalid JSON with pytest.raises(ValueError) as exc: @@ -141,7 +154,7 @@ def test_template_registry(): def test_default_templates(): """Test default template registration.""" - TemplateRegistry.clear() + TemplateRegistry.clear(reload_defaults=True) from app.llm.templates import DEFAULT_TEMPLATES # Verify default templates are loaded