sam 2025-10-05 20:13:02 +08:00
parent a553de78b4
commit b7283c859f
9 changed files with 493 additions and 46 deletions

View File

@ -119,6 +119,10 @@ class DepartmentAgent:
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        prompt_body = department_prompt(self.settings, mutable_context)
        template_meta = {}
        raw_templates = mutable_context.raw.get("template_meta") if isinstance(mutable_context.raw, dict) else None
        if isinstance(raw_templates, dict):
            template_meta = dict(raw_templates.get(self.settings.code, {}))
        prompt_checksum = hashlib.sha1(prompt_body.encode("utf-8")).hexdigest()
        prompt_preview = prompt_body[:240]
        messages.append({"role": "user", "content": prompt_body})
@ -330,6 +334,7 @@ class DepartmentAgent:
"instruction": self.settings.prompt, "instruction": self.settings.prompt,
"system": system_prompt, "system": system_prompt,
}, },
"template": template_meta,
"messages_exchanged": len(messages), "messages_exchanged": len(messages),
"supplement_rounds": len(tool_call_records), "supplement_rounds": len(tool_call_records),
} }
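
For reference (not part of this commit), the new "template" block in the agent report carries the per-department metadata that department_prompt writes into mutable_context.raw; a hypothetical payload could look like this sketch:

    # Hypothetical report fragment; department code "MACRO" is assumed.
    example_report_fragment = {
        "template": {
            "template_id": "macro_dept",   # resolved template ID
            "requested_version": None,     # no version pinned in settings
            "applied_version": "1.0.0",    # registry's default version label
        },
        "messages_exchanged": 3,
        "supplement_rounds": 1,
    }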

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import json
from collections import Counter
import time
from dataclasses import asdict
from typing import Any, Dict, Iterable, List, Optional
@ -10,6 +11,7 @@ import requests
from .context import ContextManager, Message
from .templates import TemplateRegistry
from .cost import configure_cost_limits, get_cost_controller, budget_available
from app.utils.config import (
    DEFAULT_LLM_BASE_URLS,
@ -200,6 +202,26 @@ def call_endpoint_with_messages(
    timeout = resolved["timeout"]
    api_key = resolved["api_key"]
    cfg = get_config()
    cost_cfg = getattr(cfg, "llm_cost", None)
    enforce_cost = False
    cost_controller = None
    if cost_cfg and getattr(cost_cfg, "enabled", False):
        try:
            limits = cost_cfg.to_cost_limits()
        except Exception as exc:  # noqa: BLE001
            LOGGER.warning(
                "Failed to parse cost-control config; limits will be ignored: %s",
                exc,
                extra=LOG_EXTRA,
            )
        else:
            configure_cost_limits(limits)
            enforce_cost = True
            if not budget_available():
                raise LLMError("LLM call budget exhausted; please retry later.")
            cost_controller = get_cost_controller()
    LOGGER.info(
        "Dispatching LLM request: provider=%s model=%s base=%s",
        provider_key,
@ -217,18 +239,24 @@ def call_endpoint_with_messages(
"stream": False, "stream": False,
"options": {"temperature": temperature}, "options": {"temperature": temperature},
} }
start_time = time.perf_counter()
response = requests.post( response = requests.post(
f"{base_url.rstrip('/')}/api/chat", f"{base_url.rstrip('/')}/api/chat",
json=payload, json=payload,
timeout=timeout, timeout=timeout,
) )
duration = time.perf_counter() - start_time
if response.status_code != 200: if response.status_code != 200:
raise LLMError(f"Ollama 调用失败: {response.status_code} {response.text}") raise LLMError(f"Ollama 调用失败: {response.status_code} {response.text}")
record_call(provider_key, model) data = response.json()
return response.json() record_call(provider_key, model, duration=duration)
if enforce_cost and cost_controller:
cost_controller.record_usage(model or provider_key, 0, 0)
return data
if not api_key: if not api_key:
raise LLMError(f"缺少 {provider_key} API Key (model={model})") raise LLMError(f"缺少 {provider_key} API Key (model={model})")
start_time = time.perf_counter()
data = _request_openai_chat( data = _request_openai_chat(
base_url=base_url, base_url=base_url,
api_key=api_key, api_key=api_key,
@ -239,10 +267,28 @@ def call_endpoint_with_messages(
        tools=tools,
        tool_choice=tool_choice,
    )
    duration = time.perf_counter() - start_time
    usage = data.get("usage", {}) if isinstance(data, dict) else {}
    prompt_tokens = usage.get("prompt_tokens") or usage.get("prompt_tokens_total")
    completion_tokens = usage.get("completion_tokens") or usage.get("completion_tokens_total")
    record_call(
        provider_key,
        model,
        prompt_tokens,
        completion_tokens,
        duration=duration,
    )
    if enforce_cost and cost_controller:
        prompt_count = int(prompt_tokens or 0)
        completion_count = int(completion_tokens or 0)
        within_limits = cost_controller.record_usage(model or provider_key, prompt_count, completion_count)
        if not within_limits:
            LOGGER.warning(
                "LLM cost budget exceeded: provider=%s model=%s",
                provider_key,
                model,
                extra=LOG_EXTRA,
            )
    return data
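
A caller-side sketch (not in this diff; the module path and keyword are assumptions based on the imports above). When the budget is exhausted, the gate raises before any HTTP request is made, so callers can treat it like any other LLM failure:

    from app.llm.client import LLMError, call_endpoint_with_messages

    def safe_chat(messages):
        try:
            return call_endpoint_with_messages(messages=messages)
        except LLMError as exc:
            # Budget exhaustion and provider failures both surface as LLMError.
            return {"error": str(exc)}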

View File

@ -90,10 +90,16 @@ class CostController:
        return True

    def record_usage(
        self,
        model: str,
        prompt_tokens: int,
        completion_tokens: int,
    ) -> bool:
        """Record model usage and return whether it is still within budget."""
        cost = self._calculate_cost(model, prompt_tokens, completion_tokens)
        within_limits = self._check_budget_limits(model, prompt_tokens, completion_tokens)
        timestamp = time.time()
        usage = {
@ -112,6 +118,16 @@ class CostController:
        # Periodically clean up expired usage data
        self._cleanup_old_usage(timestamp)
        if not within_limits:
            LOGGER.warning(
                "Cost limit exceeded after recording usage - model: %s cost=$%.4f",
                model,
                cost,
                extra=LOG_EXTRA,
            )
        return within_limits

    def get_current_costs(self) -> Dict[str, float]:
        """Get cost statistics for the current time windows."""
        with self._usage_lock:
@ -157,6 +173,16 @@ class CostController:
                for model, count in model_calls.items()
            }

    def is_budget_available(self) -> bool:
        """Return whether the current budget still allows further LLM calls."""
        costs = self.get_current_costs()
        return (
            costs["hourly"] < self.limits.hourly_budget and
            costs["daily"] < self.limits.daily_budget and
            costs["monthly"] < self.limits.monthly_budget
        )

    def _calculate_cost(self, model: str, prompt_tokens: int,
                        completion_tokens: int) -> float:
        """Calculate usage cost."""
@ -226,13 +252,42 @@ class CostController:
        self._last_cleanup = current_time


# Global instance management
_CONTROLLER_LOCK = threading.Lock()
_GLOBAL_CONTROLLER = CostController()
_LAST_LIMITS: Optional[CostLimits] = None


def get_controller() -> CostController:
    """Backwards-compatible accessor for the global CostController."""
    return _GLOBAL_CONTROLLER


def get_cost_controller() -> CostController:
    """Explicitly return the global CostController instance."""
    return _GLOBAL_CONTROLLER


def set_cost_limits(limits: CostLimits) -> None:
    """Set the global cost limits (legacy-compatible interface)."""
    configure_cost_limits(limits)


def configure_cost_limits(limits: CostLimits) -> None:
    """Set the global cost limits, updating only when they change."""
    global _LAST_LIMITS
    with _CONTROLLER_LOCK:
        if _LAST_LIMITS is None or limits != _LAST_LIMITS:
            _GLOBAL_CONTROLLER.limits = limits
            _LAST_LIMITS = limits


def budget_available() -> bool:
    """Return whether the global budget is still available."""
    with _CONTROLLER_LOCK:
        return _GLOBAL_CONTROLLER.is_budget_available()
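
A startup-time configuration sketch (not part of the commit), using the CostLimits keyword budgets that the tests below exercise:

    from app.llm.cost import CostLimits, budget_available, configure_cost_limits, get_cost_controller

    # Install global limits once; configure_cost_limits only swaps the limits
    # object when it differs from the last one applied.
    configure_cost_limits(
        CostLimits(hourly_budget=5.0, daily_budget=50.0, monthly_budget=500.0, model_weights={})
    )

    if budget_available():
        within = get_cost_controller().record_usage("gpt-4", 1000, 500)
        if not within:
            print("budget exceeded; budget_available() will now return False")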

View File

@ -17,6 +17,8 @@ class _Metrics:
    model_calls: Dict[str, int] = field(default_factory=dict)
    decisions: Deque[Dict[str, object]] = field(default_factory=lambda: deque(maxlen=500))
    decision_action_counts: Dict[str, int] = field(default_factory=dict)
    total_latency: float = 0.0
    latency_samples: Deque[float] = field(default_factory=lambda: deque(maxlen=200))


_METRICS = _Metrics()
@ -31,6 +33,8 @@ def record_call(
    model: Optional[str] = None,
    prompt_tokens: Optional[int] = None,
    completion_tokens: Optional[int] = None,
    *,
    duration: Optional[float] = None,
) -> None:
    """Record a single LLM API invocation."""
@ -49,6 +53,10 @@ def record_call(
        _METRICS.total_prompt_tokens += int(prompt_tokens)
    if completion_tokens:
        _METRICS.total_completion_tokens += int(completion_tokens)
    if duration is not None:
        duration_value = max(0.0, float(duration))
        _METRICS.total_latency += duration_value
        _METRICS.latency_samples.append(duration_value)
    _notify_listeners()
@ -64,6 +72,12 @@ def snapshot(reset: bool = False) -> Dict[str, object]:
"model_calls": dict(_METRICS.model_calls), "model_calls": dict(_METRICS.model_calls),
"decision_action_counts": dict(_METRICS.decision_action_counts), "decision_action_counts": dict(_METRICS.decision_action_counts),
"recent_decisions": list(_METRICS.decisions), "recent_decisions": list(_METRICS.decisions),
"average_latency": (
_METRICS.total_latency / _METRICS.total_calls
if _METRICS.total_calls
else 0.0
),
"latency_samples": list(_METRICS.latency_samples),
} }
if reset: if reset:
_METRICS.total_calls = 0 _METRICS.total_calls = 0
@ -73,6 +87,8 @@ def snapshot(reset: bool = False) -> Dict[str, object]:
        _METRICS.model_calls.clear()
        _METRICS.decision_action_counts.clear()
        _METRICS.decisions.clear()
        _METRICS.total_latency = 0.0
        _METRICS.latency_samples.clear()
    return data
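
A usage sketch (module path assumed to be app.llm.metrics, matching the client import style): record_call takes the new keyword-only duration in seconds, and snapshot derives the mean from the running totals:

    from app.llm.metrics import record_call, snapshot

    record_call("openai", "gpt-4", 120, 80, duration=1.42)

    stats = snapshot()
    print(stats["average_latency"])  # total_latency / total_calls
    print(stats["latency_samples"])  # recent durations, bounded deque of 200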

View File

@ -1,10 +1,13 @@
"""Prompt templates for natural language outputs.""" """Prompt templates for natural language outputs."""
from __future__ import annotations from __future__ import annotations
import logging
from typing import Dict, TYPE_CHECKING from typing import Dict, TYPE_CHECKING
from .templates import TemplateRegistry from .templates import TemplateRegistry
LOGGER = logging.getLogger(__name__)
if TYPE_CHECKING: # pragma: no cover if TYPE_CHECKING: # pragma: no cover
from app.utils.config import DepartmentSettings from app.utils.config import DepartmentSettings
from app.agents.departments import DepartmentContext from app.agents.departments import DepartmentContext
@ -35,10 +38,48 @@ def department_prompt(
    role_description = settings.description.strip()
    role_instruction = settings.prompt.strip()

    # Determine template ID and version
    template_id = (getattr(settings, "prompt_template_id", None) or f"{settings.code.lower()}_dept").strip()
    requested_version = getattr(settings, "prompt_template_version", None)
    original_requested_version = requested_version
    template = TemplateRegistry.get(template_id, version=requested_version)
    applied_version = requested_version if template and requested_version else None
    if not template:
        if requested_version:
            LOGGER.warning(
                "Template %s version %s not found, falling back to active version",
                template_id,
                requested_version,
            )
        template = TemplateRegistry.get(template_id)
        applied_version = TemplateRegistry.get_active_version(template_id)
    if not template:
        LOGGER.warning(
            "Template %s unavailable, using department_base fallback",
            template_id,
        )
        template_id = "department_base"
        template = TemplateRegistry.get(template_id)
        requested_version = None
        applied_version = TemplateRegistry.get_active_version(template_id)
    if not template:
        raise ValueError("No prompt template available for department prompts")
    if applied_version is None:
        applied_version = TemplateRegistry.get_active_version(template_id)

    template_meta = {
        "template_id": template_id,
        "requested_version": original_requested_version,
        "applied_version": applied_version,
    }
    raw_container = getattr(context, "raw", None)
    if isinstance(raw_container, dict):
        meta_store = raw_container.setdefault("template_meta", {})
        meta_store[settings.code] = template_meta

    # Prepare template variables
    template_vars = {
@ -54,5 +95,4 @@ def department_prompt(
    }

    # Format the prompt with the resolved template
    return template.format(template_vars)
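
In brief, the resolution order above is: pinned version, then the template's active version, then the department_base fallback. A condensed sketch of that chain (names as in this file; the real code also logs warnings and records the applied version):

    template_id = getattr(settings, "prompt_template_id", None) or f"{settings.code.lower()}_dept"
    requested = getattr(settings, "prompt_template_version", None)
    template = (
        TemplateRegistry.get(template_id, version=requested)  # 1. pinned version, if any
        or TemplateRegistry.get(template_id)                  # 2. active version of the template
        or TemplateRegistry.get("department_base")            # 3. last-resort base template
    )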

View File

@ -4,7 +4,10 @@ from __future__ import annotations
import json
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, TYPE_CHECKING

if TYPE_CHECKING:  # pragma: no cover
    from .version import TemplateVersionManager


@dataclass
@ -69,31 +72,139 @@ class PromptTemplate:
class TemplateRegistry:
    """Global registry for prompt templates with version awareness."""

    _templates: Dict[str, PromptTemplate] = {}
    _version_manager: Optional["TemplateVersionManager"] = None
    _default_version_label: str = "1.0.0"

    @classmethod
    def _manager(cls) -> "TemplateVersionManager":
        if cls._version_manager is None:
            from .version import TemplateVersionManager  # Local import to avoid circular dependency
            cls._version_manager = TemplateVersionManager()
        return cls._version_manager

    @classmethod
    def register(
        cls,
        template: PromptTemplate,
        *,
        version: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        activate: bool = False,
    ) -> None:
        """Register a new template and optionally version it."""
        errors = template.validate()
        if errors:
            raise ValueError(f"Invalid template {template.id}: {'; '.join(errors)}")
        cls._templates[template.id] = template
        manager = cls._manager()
        existing_versions = manager.list_versions(template.id)
        resolved_metadata: Dict[str, Any] = dict(metadata or {})
        if version:
            manager.add_version(
                template,
                version,
                metadata=resolved_metadata or None,
                activate=activate,
            )
        elif not existing_versions:
            if "source" not in resolved_metadata:
                resolved_metadata["source"] = "default"
            manager.add_version(
                template,
                cls._default_version_label,
                metadata=resolved_metadata,
                activate=True,
            )
    @classmethod
    def register_version(
        cls,
        template_id: str,
        *,
        version: str,
        template: Optional[PromptTemplate] = None,
        metadata: Optional[Dict[str, Any]] = None,
        activate: bool = False,
    ) -> None:
        """Register an additional version for an existing template."""
        base_template = template or cls._templates.get(template_id)
        if not base_template:
            raise ValueError(f"Template {template_id} not found for version registration")
        manager = cls._manager()
        manager.add_version(
            base_template,
            version,
            metadata=metadata,
            activate=activate,
        )

    @classmethod
    def activate_version(cls, template_id: str, version: str) -> None:
        """Activate a specific template version."""
        manager = cls._manager()
        manager.activate_version(template_id, version)

    @classmethod
    def get(
        cls,
        template_id: str,
        *,
        version: Optional[str] = None,
    ) -> Optional[PromptTemplate]:
        """Get template by ID and optional version."""
        manager = cls._manager()
        if version:
            stored = manager.get_version(template_id, version)
            if stored:
                return stored.template
        active = manager.get_active_version(template_id)
        if active:
            return active.template
        return cls._templates.get(template_id)

    @classmethod
    def get_active_version(cls, template_id: str) -> Optional[str]:
        """Return the currently active version label for a template."""
        manager = cls._manager()
        active = manager.get_active_version(template_id)
        return active.version if active else None

    @classmethod
    def list(cls) -> List[PromptTemplate]:
        """List all registered templates (active versions preferred)."""
        collected: Dict[str, PromptTemplate] = {}
        manager = cls._manager()
        for template_id, template in cls._templates.items():
            active = manager.get_active_version(template_id)
            collected[template_id] = active.template if active else template
        return list(collected.values())

    @classmethod
    def list_versions(cls, template_id: str) -> List[str]:
        """List available version labels for a template."""
        manager = cls._manager()
        return [ver.version for ver in manager.list_versions(template_id)]

    @classmethod
    def load_from_json(cls, json_str: str) -> None:
        """Load templates from JSON string."""
        try:
            data = json.loads(json_str)
        except json.JSONDecodeError as e:
@ -103,6 +214,13 @@ class TemplateRegistry:
raise ValueError("JSON root must be an object") raise ValueError("JSON root must be an object")
for template_id, cfg in data.items(): for template_id, cfg in data.items():
if not isinstance(cfg, dict):
raise ValueError(f"Template {template_id} configuration must be an object")
version = cfg.get("version")
metadata = cfg.get("metadata")
if metadata is not None and not isinstance(metadata, dict):
raise ValueError(f"Template {template_id} metadata must be an object")
activate = bool(cfg.get("activate", False))
template = PromptTemplate( template = PromptTemplate(
id=template_id, id=template_id,
name=cfg.get("name", template_id), name=cfg.get("name", template_id),
@ -113,12 +231,21 @@ class TemplateRegistry:
                required_context=cfg.get("required_context", []),
                validation_rules=cfg.get("validation_rules", [])
            )
            cls.register(
                template,
                version=version,
                metadata=metadata,
                activate=activate,
            )

    @classmethod
    def clear(cls, *, reload_defaults: bool = False) -> None:
        """Clear all registered templates and optionally reload defaults."""
        cls._templates.clear()
        cls._version_manager = None
        if reload_defaults:
            register_default_templates()


# Default template definitions
@ -234,7 +361,19 @@ def register_default_templates() -> None:
"validation_rules": cfg.get("validation_rules", []) "validation_rules": cfg.get("validation_rules", [])
} }
try: try:
TemplateRegistry.register(PromptTemplate(**template_config)) template = PromptTemplate(**template_config)
version_label = str(
cfg.get("version") or TemplateRegistry._default_version_label
)
metadata_raw = cfg.get("metadata")
metadata = dict(metadata_raw) if isinstance(metadata_raw, dict) else {}
metadata.setdefault("source", "defaults")
TemplateRegistry.register(
template,
version=version_label,
metadata=metadata,
activate=cfg.get("activate", True),
)
except ValueError as e: except ValueError as e:
logging.warning(f"Failed to register template {template_id}: {e}") logging.warning(f"Failed to register template {template_id}: {e}")
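
A registry usage sketch (not part of the diff), using the PromptTemplate fields seen in load_from_json and the module path from the tests below:

    from app.llm.templates import PromptTemplate, TemplateRegistry

    base = PromptTemplate(
        id="greeting",
        name="Greeting",
        description="Says hello",
        template="Hello {name}!",
        variables=["name"],
    )
    TemplateRegistry.register(base, version="1.0.0", activate=True)
    # Add a second version of the same template body and make it active.
    TemplateRegistry.register_version("greeting", version="1.1.0", activate=True)

    assert TemplateRegistry.get_active_version("greeting") == "1.1.0"
    assert TemplateRegistry.get("greeting", version="1.0.0") is not None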

View File

@ -16,6 +16,13 @@ def _default_root() -> Path:
    return Path(__file__).resolve().parents[2] / "app" / "data"


def _safe_float(value: object, fallback: float) -> float:
    try:
        return float(value)
    except (TypeError, ValueError):
        return fallback


@dataclass
class DataPaths:
    """Holds filesystem locations for persistent artifacts."""
@ -198,8 +205,61 @@ class LLMEndpoint:
        self.provider = (self.provider or "ollama").lower()
        if self.temperature is not None:
            self.temperature = float(self.temperature)
        if self.timeout is not None:
            self.timeout = float(self.timeout)


@dataclass
class LLMCostSettings:
    """Configurable budgets and weights for LLM cost control."""

    enabled: bool = False
    hourly_budget: float = 5.0
    daily_budget: float = 50.0
    monthly_budget: float = 500.0
    model_weights: Dict[str, float] = field(default_factory=dict)

    def update_from_dict(self, data: Mapping[str, object]) -> None:
        if "enabled" in data:
            self.enabled = bool(data.get("enabled"))
        if "hourly_budget" in data:
            self.hourly_budget = _safe_float(data.get("hourly_budget"), self.hourly_budget)
        if "daily_budget" in data:
            self.daily_budget = _safe_float(data.get("daily_budget"), self.daily_budget)
        if "monthly_budget" in data:
            self.monthly_budget = _safe_float(data.get("monthly_budget"), self.monthly_budget)
        weights = data.get("model_weights") if isinstance(data, Mapping) else None
        if isinstance(weights, Mapping):
            normalized: Dict[str, float] = {}
            for key, value in weights.items():
                try:
                    normalized[str(key)] = float(value)
                except (TypeError, ValueError):
                    continue
            if normalized:
                self.model_weights = normalized

    def to_cost_limits(self):
        """Convert into runtime `CostLimits` descriptor."""
        from app.llm.cost import CostLimits  # Imported lazily to avoid cycles

        weights: Dict[str, float] = {}
        for key, value in (self.model_weights or {}).items():
            try:
                weights[str(key)] = float(value)
            except (TypeError, ValueError):
                continue
        return CostLimits(
            hourly_budget=float(self.hourly_budget),
            daily_budget=float(self.daily_budget),
            monthly_budget=float(self.monthly_budget),
            model_weights=weights,
        )

    @classmethod
    def from_dict(cls, data: Mapping[str, object]) -> "LLMCostSettings":
        inst = cls()
        inst.update_from_dict(data)
        return inst


@dataclass
@ -241,6 +301,8 @@ class DepartmentSettings:
    data_scope: List[str] = field(default_factory=list)
    prompt: str = ""
    llm: LLMConfig = field(default_factory=LLMConfig)
    prompt_template_id: Optional[str] = None
    prompt_template_version: Optional[str] = None


def _default_departments() -> Dict[str, DepartmentSettings]:
@ -327,6 +389,7 @@ def _default_departments() -> Dict[str, DepartmentSettings]:
            description=item.get("description", ""),
            data_scope=list(item.get("data_scope", [])),
            prompt=item.get("prompt", ""),
            prompt_template_id=f"{item['code']}_dept",
        )
        for item in presets
    }
@ -355,6 +418,7 @@ class AppConfig:
    data_update_interval: int = 7  # Data update interval (days)
    llm_providers: Dict[str, LLMProvider] = field(default_factory=_default_llm_providers)
    llm: LLMConfig = field(default_factory=LLMConfig)
    llm_cost: LLMCostSettings = field(default_factory=LLMCostSettings)
    departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments)
    portfolio: PortfolioSettings = field(default_factory=PortfolioSettings)
@ -463,6 +527,10 @@ def _load_from_file(cfg: AppConfig) -> None:
    )
    cfg.portfolio = updated_portfolio

    cost_payload = payload.get("llm_cost")
    if isinstance(cost_payload, dict):
        cfg.llm_cost.update_from_dict(cost_payload)

    legacy_profiles: Dict[str, Dict[str, object]] = {}
    legacy_routes: Dict[str, Dict[str, object]] = {}
@ -574,6 +642,7 @@ def _load_from_file(cfg: AppConfig) -> None:
    for code, data in departments_payload.items():
        if not isinstance(data, dict):
            continue
        current_setting = cfg.departments.get(code)
        title = data.get("title") or code
        description = data.get("description") or ""
        weight = float(data.get("weight", 1.0))
@ -606,6 +675,33 @@ def _load_from_file(cfg: AppConfig) -> None:
        if isinstance(majority_raw, int) and majority_raw > 0:
            llm_cfg.majority_threshold = majority_raw
        resolved_cfg = llm_cfg

        template_id_raw = data.get("prompt_template_id")
        if isinstance(template_id_raw, str):
            template_id_candidate = template_id_raw.strip()
        elif template_id_raw is not None:
            template_id_candidate = str(template_id_raw).strip()
        else:
            template_id_candidate = ""
        if template_id_candidate:
            template_id = template_id_candidate
        elif current_setting and current_setting.prompt_template_id:
            template_id = current_setting.prompt_template_id
        else:
            template_id = f"{code}_dept"

        template_version_raw = data.get("prompt_template_version")
        if isinstance(template_version_raw, str):
            template_version_candidate = template_version_raw.strip()
        elif template_version_raw is not None:
            template_version_candidate = str(template_version_raw).strip()
        else:
            template_version_candidate = ""
        if template_version_candidate:
            template_version = template_version_candidate
        elif current_setting:
            template_version = current_setting.prompt_template_version
        else:
            template_version = None

        new_departments[code] = DepartmentSettings(
            code=code,
            title=title,
@ -614,6 +710,8 @@ def _load_from_file(cfg: AppConfig) -> None:
            data_scope=data_scope,
            prompt=prompt_text,
            llm=resolved_cfg,
            prompt_template_id=template_id,
            prompt_template_version=template_version,
        )
    if new_departments:
        cfg.departments = new_departments
@ -648,6 +746,13 @@ def save_config(cfg: AppConfig | None = None) -> None:
"primary": _endpoint_to_dict(cfg.llm.primary), "primary": _endpoint_to_dict(cfg.llm.primary),
"ensemble": [_endpoint_to_dict(ep) for ep in cfg.llm.ensemble], "ensemble": [_endpoint_to_dict(ep) for ep in cfg.llm.ensemble],
}, },
"llm_cost": {
"enabled": cfg.llm_cost.enabled,
"hourly_budget": cfg.llm_cost.hourly_budget,
"daily_budget": cfg.llm_cost.daily_budget,
"monthly_budget": cfg.llm_cost.monthly_budget,
"model_weights": cfg.llm_cost.model_weights,
},
"llm_providers": { "llm_providers": {
key: provider.to_dict() key: provider.to_dict()
for key, provider in cfg.llm_providers.items() for key, provider in cfg.llm_providers.items()
@ -659,6 +764,8 @@ def save_config(cfg: AppConfig | None = None) -> None:
"weight": dept.weight, "weight": dept.weight,
"data_scope": list(dept.data_scope), "data_scope": list(dept.data_scope),
"prompt": dept.prompt, "prompt": dept.prompt,
"prompt_template_id": dept.prompt_template_id,
"prompt_template_version": dept.prompt_template_version,
"llm": { "llm": {
"strategy": dept.llm.strategy if dept.llm.strategy in ALLOWED_LLM_STRATEGIES else "single", "strategy": dept.llm.strategy if dept.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
"majority_threshold": dept.llm.majority_threshold, "majority_threshold": dept.llm.majority_threshold,

View File

@ -1,8 +1,14 @@
"""Test cases for LLM cost control system.""" """Test cases for LLM cost control system."""
import pytest import time
from datetime import datetime
from app.llm.cost import CostLimits, ModelCosts, CostController from app.llm.cost import (
CostLimits,
ModelCosts,
CostController,
configure_cost_limits,
get_cost_controller,
budget_available,
)
def test_cost_limits(): def test_cost_limits():
@ -53,11 +59,11 @@ def test_cost_controller():
    assert controller.can_use_model("gpt-4", 1000, 500)

    # Then record the usage
    assert controller.record_usage("gpt-4", 1000, 500) is True  # About $0.09

    # Record usage for second model to maintain weight balance
    assert controller.can_use_model("gpt-3.5-turbo", 1000, 500)
    assert controller.record_usage("gpt-3.5-turbo", 1000, 500)

    # Verify usage tracking
    costs = controller.get_current_costs()
@ -85,10 +91,10 @@ def test_cost_controller_history():
    # Record one usage of each model
    assert controller.can_use_model("gpt-4", 1000, 500)
    assert controller.record_usage("gpt-4", 1000, 500)

    assert controller.can_use_model("gpt-3.5-turbo", 1000, 500)
    assert controller.record_usage("gpt-3.5-turbo", 1000, 500)

    # Check usage tracking
    costs = controller.get_current_costs()
@ -98,3 +104,23 @@ def test_cost_controller_history():
    distribution = controller.get_model_distribution()
    assert abs(distribution["gpt-4"] - 0.5) < 0.1
    assert abs(distribution["gpt-3.5-turbo"] - 0.5) < 0.1


def test_global_controller_budget_toggle():
    """Ensure global controller respects configured limits and budget flag."""
    limits = CostLimits(hourly_budget=0.05, daily_budget=0.1, monthly_budget=1.0, model_weights={})
    configure_cost_limits(limits)
    controller = get_cost_controller()

    assert budget_available() is True
    within = controller.record_usage("gpt-4", 1000, 1000)
    assert within is False  # deliberately exceed tiny budget
    assert budget_available() is False

    # Reset controller state to avoid cross-test contamination
    with controller._usage_lock:  # type: ignore[attr-defined]
        for bucket in controller._usage.values():  # type: ignore[attr-defined]
            bucket.clear()
    controller._last_cleanup = time.time()  # type: ignore[attr-defined]
    configure_cost_limits(CostLimits.default())

View File

@ -69,7 +69,7 @@ def test_prompt_template_format():
    # Valid context
    result = template.format({"name": "World"})
    assert result == "Hello W..."

    # Missing required context
    with pytest.raises(ValueError) as exc:
@ -77,8 +77,15 @@ def test_prompt_template_format():
assert "Missing required context" in str(exc.value) assert "Missing required context" in str(exc.value)
# Missing variable # Missing variable
template_no_required = PromptTemplate(
id="test2",
name="Test Template",
description="A test template",
template="Hello {name}!",
variables=["name"],
)
with pytest.raises(ValueError) as exc: with pytest.raises(ValueError) as exc:
template.format({"wrong": "value"}) template_no_required.format({"wrong": "value"})
assert "Missing template variable" in str(exc.value) assert "Missing template variable" in str(exc.value)
@ -95,7 +102,8 @@ def test_template_registry():
variables=["name"] variables=["name"]
) )
TemplateRegistry.register(template) TemplateRegistry.register(template)
assert TemplateRegistry.get("test") == template assert TemplateRegistry.get("test") is not None
assert TemplateRegistry.get_active_version("test") == "1.0.0"
# Register invalid template # Register invalid template
invalid = PromptTemplate( invalid = PromptTemplate(
@ -121,12 +129,17 @@ def test_template_registry():
"name": "JSON Test", "name": "JSON Test",
"description": "Test template from JSON", "description": "Test template from JSON",
"template": "Hello {name}!", "template": "Hello {name}!",
"variables": ["name"] "variables": ["name"],
"version": "2024.10",
"metadata": {"author": "qa"},
"activate": true
} }
} }
''' '''
TemplateRegistry.load_from_json(json_str) TemplateRegistry.load_from_json(json_str)
assert TemplateRegistry.get("json_test") is not None loaded = TemplateRegistry.get("json_test")
assert loaded is not None
assert TemplateRegistry.get_active_version("json_test") == "2024.10"
# Invalid JSON # Invalid JSON
with pytest.raises(ValueError) as exc: with pytest.raises(ValueError) as exc:
@ -141,7 +154,7 @@ def test_template_registry():
def test_default_templates():
    """Test default template registration."""
    TemplateRegistry.clear(reload_defaults=True)
    from app.llm.templates import DEFAULT_TEMPLATES

    # Verify default templates are loaded