commit b7283c859f
parent a553de78b4

    update
@@ -119,6 +119,10 @@ class DepartmentAgent:
         if system_prompt:
             messages.append({"role": "system", "content": system_prompt})
         prompt_body = department_prompt(self.settings, mutable_context)
+        template_meta = {}
+        raw_templates = mutable_context.raw.get("template_meta") if isinstance(mutable_context.raw, dict) else None
+        if isinstance(raw_templates, dict):
+            template_meta = dict(raw_templates.get(self.settings.code, {}))
         prompt_checksum = hashlib.sha1(prompt_body.encode("utf-8")).hexdigest()
         prompt_preview = prompt_body[:240]
         messages.append({"role": "user", "content": prompt_body})
@@ -330,6 +334,7 @@ class DepartmentAgent:
                 "instruction": self.settings.prompt,
                 "system": system_prompt,
             },
+            "template": template_meta,
             "messages_exchanged": len(messages),
             "supplement_rounds": len(tool_call_records),
         }
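
Note (not part of the diff): the four added lines read back per-department template metadata that the department_prompt change later in this commit stores under context.raw["template_meta"], keyed by department code. A minimal sketch of the assumed shape, with a hypothetical "RESEARCH" code:

    # Illustrative only; the department code and version labels are placeholders.
    raw = {
        "template_meta": {
            "RESEARCH": {
                "template_id": "research_dept",
                "requested_version": None,
                "applied_version": "1.0.0",
            }
        }
    }
    template_meta = dict(raw["template_meta"].get("RESEARCH", {}))
    print(template_meta["applied_version"])  # 1.0.0
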
@@ -3,6 +3,7 @@ from __future__ import annotations

 import json
 from collections import Counter
+import time
 from dataclasses import asdict
 from typing import Any, Dict, Iterable, List, Optional

@@ -10,6 +11,7 @@ import requests

 from .context import ContextManager, Message
 from .templates import TemplateRegistry
+from .cost import configure_cost_limits, get_cost_controller, budget_available

 from app.utils.config import (
     DEFAULT_LLM_BASE_URLS,
@@ -200,6 +202,26 @@ def call_endpoint_with_messages(
     timeout = resolved["timeout"]
     api_key = resolved["api_key"]

+    cfg = get_config()
+    cost_cfg = getattr(cfg, "llm_cost", None)
+    enforce_cost = False
+    cost_controller = None
+    if cost_cfg and getattr(cost_cfg, "enabled", False):
+        try:
+            limits = cost_cfg.to_cost_limits()
+        except Exception as exc:  # noqa: BLE001
+            LOGGER.warning(
+                "成本控制配置解析失败,将忽略限制: %s",
+                exc,
+                extra=LOG_EXTRA,
+            )
+        else:
+            configure_cost_limits(limits)
+            enforce_cost = True
+            if not budget_available():
+                raise LLMError("LLM 调用预算已耗尽,请稍后重试。")
+            cost_controller = get_cost_controller()
+
     LOGGER.info(
         "触发 LLM 请求:provider=%s model=%s base=%s",
         provider_key,
@@ -217,18 +239,24 @@ def call_endpoint_with_messages(
             "stream": False,
             "options": {"temperature": temperature},
         }
+        start_time = time.perf_counter()
         response = requests.post(
             f"{base_url.rstrip('/')}/api/chat",
             json=payload,
             timeout=timeout,
         )
+        duration = time.perf_counter() - start_time
         if response.status_code != 200:
             raise LLMError(f"Ollama 调用失败: {response.status_code} {response.text}")
-        record_call(provider_key, model)
-        return response.json()
+        data = response.json()
+        record_call(provider_key, model, duration=duration)
+        if enforce_cost and cost_controller:
+            cost_controller.record_usage(model or provider_key, 0, 0)
+        return data

     if not api_key:
         raise LLMError(f"缺少 {provider_key} API Key (model={model})")
+    start_time = time.perf_counter()
     data = _request_openai_chat(
         base_url=base_url,
         api_key=api_key,
@@ -239,10 +267,28 @@ def call_endpoint_with_messages(
         tools=tools,
         tool_choice=tool_choice,
     )
+    duration = time.perf_counter() - start_time
     usage = data.get("usage", {}) if isinstance(data, dict) else {}
     prompt_tokens = usage.get("prompt_tokens") or usage.get("prompt_tokens_total")
     completion_tokens = usage.get("completion_tokens") or usage.get("completion_tokens_total")
-    record_call(provider_key, model, prompt_tokens, completion_tokens)
+    record_call(
+        provider_key,
+        model,
+        prompt_tokens,
+        completion_tokens,
+        duration=duration,
+    )
+    if enforce_cost and cost_controller:
+        prompt_count = int(prompt_tokens or 0)
+        completion_count = int(completion_tokens or 0)
+        within_limits = cost_controller.record_usage(model or provider_key, prompt_count, completion_count)
+        if not within_limits:
+            LOGGER.warning(
+                "LLM 成本预算已超限:provider=%s model=%s",
+                provider_key,
+                model,
+                extra=LOG_EXTRA,
+            )
     return data
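
Note (not part of the diff): both provider branches now wrap the request in a perf_counter pair and pass the elapsed time to record_call. A self-contained sketch of that pattern, with a placeholder call standing in for requests.post / _request_openai_chat:

    import time

    def fake_provider_call() -> dict:
        time.sleep(0.01)  # stand-in for the real HTTP request
        return {"usage": {"prompt_tokens": 12, "completion_tokens": 34}}

    start_time = time.perf_counter()
    data = fake_provider_call()
    duration = time.perf_counter() - start_time

    usage = data.get("usage", {})
    # record_call(provider_key, model, usage.get("prompt_tokens"), usage.get("completion_tokens"), duration=duration)
    print(round(duration, 3), usage.get("prompt_tokens"), usage.get("completion_tokens"))
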
@@ -90,10 +90,16 @@ class CostController:

         return True

-    def record_usage(self, model: str, prompt_tokens: int,
-                     completion_tokens: int) -> None:
-        """记录模型使用情况."""
+    def record_usage(
+        self,
+        model: str,
+        prompt_tokens: int,
+        completion_tokens: int,
+    ) -> bool:
+        """记录模型使用情况,并返回是否仍在预算范围内."""
+
         cost = self._calculate_cost(model, prompt_tokens, completion_tokens)
+        within_limits = self._check_budget_limits(model, prompt_tokens, completion_tokens)
         timestamp = time.time()

         usage = {
@@ -112,6 +118,16 @@ class CostController:
         # 定期清理过期数据
         self._cleanup_old_usage(timestamp)
+
+        if not within_limits:
+            LOGGER.warning(
+                "Cost limit exceeded after recording usage - model: %s cost=$%.4f",
+                model,
+                cost,
+                extra=LOG_EXTRA,
+            )
+
+        return within_limits

     def get_current_costs(self) -> Dict[str, float]:
         """获取当前时段的成本统计."""
         with self._usage_lock:
@@ -157,6 +173,16 @@ class CostController:
             for model, count in model_calls.items()
         }

+    def is_budget_available(self) -> bool:
+        """判断当前预算是否允许继续调用LLM."""
+
+        costs = self.get_current_costs()
+        return (
+            costs["hourly"] < self.limits.hourly_budget and
+            costs["daily"] < self.limits.daily_budget and
+            costs["monthly"] < self.limits.monthly_budget
+        )
+
     def _calculate_cost(self, model: str, prompt_tokens: int,
                         completion_tokens: int) -> float:
         """计算使用成本."""
@@ -226,13 +252,42 @@ class CostController:
         self._last_cleanup = current_time


-# 全局实例
-_controller = CostController()
+# 全局实例管理
+_CONTROLLER_LOCK = threading.Lock()
+_GLOBAL_CONTROLLER = CostController()
+_LAST_LIMITS: Optional[CostLimits] = None


 def get_controller() -> CostController:
-    """获取全局CostController实例."""
-    return _controller
+    """向后兼容的全局 CostController 访问方法."""
+
+    return _GLOBAL_CONTROLLER
+
+
+def get_cost_controller() -> CostController:
+    """显式返回全局 CostController 实例."""
+
+    return _GLOBAL_CONTROLLER


 def set_cost_limits(limits: CostLimits) -> None:
-    """设置全局成本限制."""
-    _controller.limits = limits
+    """设置全局成本限制(兼容旧接口)。"""
+
+    configure_cost_limits(limits)
+
+
+def configure_cost_limits(limits: CostLimits) -> None:
+    """设置全局成本限制,如果变更才更新。"""
+
+    global _LAST_LIMITS
+    with _CONTROLLER_LOCK:
+        if _LAST_LIMITS is None or limits != _LAST_LIMITS:
+            _GLOBAL_CONTROLLER.limits = limits
+            _LAST_LIMITS = limits
+
+
+def budget_available() -> bool:
+    """判断全局预算是否仍可用。"""
+
+    with _CONTROLLER_LOCK:
+        return _GLOBAL_CONTROLLER.is_budget_available()
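
Note (not part of the diff): the intended call pattern for the new module-level helpers, assuming the module is importable as app.llm.cost (as in the tests further down) and that CostLimits accepts the keyword arguments used there:

    from app.llm.cost import CostLimits, configure_cost_limits, get_cost_controller, budget_available

    configure_cost_limits(CostLimits(hourly_budget=5.0, daily_budget=50.0,
                                     monthly_budget=500.0, model_weights={}))
    if budget_available():
        controller = get_cost_controller()
        still_within = controller.record_usage("gpt-4", 1000, 500)
        print(still_within)  # False once any rolling budget is exhausted
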
@@ -17,6 +17,8 @@ class _Metrics:
     model_calls: Dict[str, int] = field(default_factory=dict)
     decisions: Deque[Dict[str, object]] = field(default_factory=lambda: deque(maxlen=500))
     decision_action_counts: Dict[str, int] = field(default_factory=dict)
+    total_latency: float = 0.0
+    latency_samples: Deque[float] = field(default_factory=lambda: deque(maxlen=200))


 _METRICS = _Metrics()
@@ -31,6 +33,8 @@ def record_call(
     model: Optional[str] = None,
     prompt_tokens: Optional[int] = None,
     completion_tokens: Optional[int] = None,
+    *,
+    duration: Optional[float] = None,
 ) -> None:
     """Record a single LLM API invocation."""

@@ -49,6 +53,10 @@ def record_call(
         _METRICS.total_prompt_tokens += int(prompt_tokens)
     if completion_tokens:
         _METRICS.total_completion_tokens += int(completion_tokens)
+    if duration is not None:
+        duration_value = max(0.0, float(duration))
+        _METRICS.total_latency += duration_value
+        _METRICS.latency_samples.append(duration_value)
     _notify_listeners()

@@ -64,6 +72,12 @@ def snapshot(reset: bool = False) -> Dict[str, object]:
         "model_calls": dict(_METRICS.model_calls),
         "decision_action_counts": dict(_METRICS.decision_action_counts),
         "recent_decisions": list(_METRICS.decisions),
+        "average_latency": (
+            _METRICS.total_latency / _METRICS.total_calls
+            if _METRICS.total_calls
+            else 0.0
+        ),
+        "latency_samples": list(_METRICS.latency_samples),
     }
     if reset:
         _METRICS.total_calls = 0
@@ -73,6 +87,8 @@ def snapshot(reset: bool = False) -> Dict[str, object]:
         _METRICS.model_calls.clear()
         _METRICS.decision_action_counts.clear()
         _METRICS.decisions.clear()
+        _METRICS.total_latency = 0.0
+        _METRICS.latency_samples.clear()
     return data
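
Note (not part of the diff): the latency bookkeeping above reduces to a running total plus a bounded sample window, and snapshot() reports average_latency = total_latency / total_calls. A self-contained sketch of that arithmetic with placeholder durations:

    from collections import deque

    total_calls, total_latency = 0, 0.0
    latency_samples = deque(maxlen=200)

    for duration in (0.8, 1.2, 0.5):  # placeholder request durations in seconds
        total_calls += 1
        value = max(0.0, float(duration))
        total_latency += value
        latency_samples.append(value)

    average_latency = total_latency / total_calls if total_calls else 0.0
    print(round(average_latency, 3))  # 0.833
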
@@ -1,10 +1,13 @@
 """Prompt templates for natural language outputs."""
 from __future__ import annotations

+import logging
 from typing import Dict, TYPE_CHECKING

 from .templates import TemplateRegistry

+LOGGER = logging.getLogger(__name__)
+
 if TYPE_CHECKING:  # pragma: no cover
     from app.utils.config import DepartmentSettings
     from app.agents.departments import DepartmentContext
@@ -35,10 +38,48 @@ def department_prompt(
     role_description = settings.description.strip()
     role_instruction = settings.prompt.strip()

-    # Determine template ID based on department settings
-    template_id = f"{settings.code.lower()}_dept"
-    if not TemplateRegistry.get(template_id):
+    # Determine template ID and version
+    template_id = (getattr(settings, "prompt_template_id", None) or f"{settings.code.lower()}_dept").strip()
+    requested_version = getattr(settings, "prompt_template_version", None)
+    original_requested_version = requested_version
+    template = TemplateRegistry.get(template_id, version=requested_version)
+    applied_version = requested_version if template and requested_version else None
+
+    if not template:
+        if requested_version:
+            LOGGER.warning(
+                "Template %s version %s not found, falling back to active version",
+                template_id,
+                requested_version,
+            )
+            template = TemplateRegistry.get(template_id)
+            applied_version = TemplateRegistry.get_active_version(template_id)
+
+    if not template:
+        LOGGER.warning(
+            "Template %s unavailable, using department_base fallback",
+            template_id,
+        )
         template_id = "department_base"
+        template = TemplateRegistry.get(template_id)
+        requested_version = None
+        applied_version = TemplateRegistry.get_active_version(template_id)
+
+    if not template:
+        raise ValueError("No prompt template available for department prompts")
+
+    if applied_version is None:
+        applied_version = TemplateRegistry.get_active_version(template_id)
+    template_meta = {
+        "template_id": template_id,
+        "requested_version": original_requested_version,
+        "applied_version": applied_version,
+    }
+
+    raw_container = getattr(context, "raw", None)
+    if isinstance(raw_container, dict):
+        meta_store = raw_container.setdefault("template_meta", {})
+        meta_store[settings.code] = template_meta

     # Prepare template variables
     template_vars = {
@@ -54,5 +95,4 @@ def department_prompt(
     }

     # Get template and format prompt
-    template = TemplateRegistry.get(template_id)
     return template.format(template_vars)
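
Note (not part of the diff): the selection order introduced above is requested version, then the template's active version, then the department_base fallback, and finally a hard error. A pure-Python sketch of that order (names here are illustrative, not the module's API):

    def resolve(template_id, requested_version, lookup):
        """lookup(template_id, version) returns a template or None; version=None means active."""
        template = lookup(template_id, requested_version)
        if template is None and requested_version is not None:
            template = lookup(template_id, None)  # fall back to the active version
        if template is None:
            template_id, template = "department_base", lookup("department_base", None)
        if template is None:
            raise ValueError("No prompt template available for department prompts")
        return template_id, template

    print(resolve("research_dept", "9.9.9", lambda tid, ver: "tmpl" if ver is None else None))
    # ('research_dept', 'tmpl')
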
@@ -4,7 +4,10 @@ from __future__ import annotations
 import json
 import logging
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:  # pragma: no cover
+    from .version import TemplateVersionManager


 @dataclass
@@ -69,31 +72,139 @@ class PromptTemplate:


 class TemplateRegistry:
-    """Global registry for prompt templates."""
+    """Global registry for prompt templates with version awareness."""

     _templates: Dict[str, PromptTemplate] = {}
+    _version_manager: Optional["TemplateVersionManager"] = None
+    _default_version_label: str = "1.0.0"

     @classmethod
-    def register(cls, template: PromptTemplate) -> None:
-        """Register a new template."""
+    def _manager(cls) -> "TemplateVersionManager":
+        if cls._version_manager is None:
+            from .version import TemplateVersionManager  # Local import to avoid circular dependency
+
+            cls._version_manager = TemplateVersionManager()
+        return cls._version_manager
+
+    @classmethod
+    def register(
+        cls,
+        template: PromptTemplate,
+        *,
+        version: Optional[str] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        activate: bool = False,
+    ) -> None:
+        """Register a new template and optionally version it."""
+
         errors = template.validate()
         if errors:
             raise ValueError(f"Invalid template {template.id}: {'; '.join(errors)}")

         cls._templates[template.id] = template
+
+        manager = cls._manager()
+        existing_versions = manager.list_versions(template.id)
+        resolved_metadata: Dict[str, Any] = dict(metadata or {})
+        if version:
+            manager.add_version(
+                template,
+                version,
+                metadata=resolved_metadata or None,
+                activate=activate,
+            )
+        elif not existing_versions:
+            if "source" not in resolved_metadata:
+                resolved_metadata["source"] = "default"
+            manager.add_version(
+                template,
+                cls._default_version_label,
+                metadata=resolved_metadata,
+                activate=True,
+            )

     @classmethod
-    def get(cls, template_id: str) -> Optional[PromptTemplate]:
-        """Get template by ID."""
+    def register_version(
+        cls,
+        template_id: str,
+        *,
+        version: str,
+        template: Optional[PromptTemplate] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        activate: bool = False,
+    ) -> None:
+        """Register an additional version for an existing template."""
+
+        base_template = template or cls._templates.get(template_id)
+        if not base_template:
+            raise ValueError(f"Template {template_id} not found for version registration")
+
+        manager = cls._manager()
+        manager.add_version(
+            base_template,
+            version,
+            metadata=metadata,
+            activate=activate,
+        )
+
+    @classmethod
+    def activate_version(cls, template_id: str, version: str) -> None:
+        """Activate a specific template version."""
+
+        manager = cls._manager()
+        manager.activate_version(template_id, version)
+
+    @classmethod
+    def get(
+        cls,
+        template_id: str,
+        *,
+        version: Optional[str] = None,
+    ) -> Optional[PromptTemplate]:
+        """Get template by ID and optional version."""
+
+        manager = cls._manager()
+        if version:
+            stored = manager.get_version(template_id, version)
+            if stored:
+                return stored.template
+
+        active = manager.get_active_version(template_id)
+        if active:
+            return active.template
+
         return cls._templates.get(template_id)
+
+    @classmethod
+    def get_active_version(cls, template_id: str) -> Optional[str]:
+        """Return the currently active version label for a template."""
+
+        manager = cls._manager()
+        active = manager.get_active_version(template_id)
+        return active.version if active else None

     @classmethod
     def list(cls) -> List[PromptTemplate]:
-        """List all registered templates."""
-        return list(cls._templates.values())
+        """List all registered templates (active versions preferred)."""
+
+        collected: Dict[str, PromptTemplate] = {}
+        manager = cls._manager()
+        for template_id, template in cls._templates.items():
+            active = manager.get_active_version(template_id)
+            collected[template_id] = active.template if active else template
+        return list(collected.values())
+
+    @classmethod
+    def list_versions(cls, template_id: str) -> List[str]:
+        """List available version labels for a template."""
+
+        manager = cls._manager()
+        return [ver.version for ver in manager.list_versions(template_id)]

     @classmethod
     def load_from_json(cls, json_str: str) -> None:
         """Load templates from JSON string."""

         try:
             data = json.loads(json_str)
         except json.JSONDecodeError as e:
@@ -103,6 +214,13 @@ class TemplateRegistry:
             raise ValueError("JSON root must be an object")

         for template_id, cfg in data.items():
+            if not isinstance(cfg, dict):
+                raise ValueError(f"Template {template_id} configuration must be an object")
+            version = cfg.get("version")
+            metadata = cfg.get("metadata")
+            if metadata is not None and not isinstance(metadata, dict):
+                raise ValueError(f"Template {template_id} metadata must be an object")
+            activate = bool(cfg.get("activate", False))
             template = PromptTemplate(
                 id=template_id,
                 name=cfg.get("name", template_id),
@@ -113,12 +231,21 @@ class TemplateRegistry:
                 required_context=cfg.get("required_context", []),
                 validation_rules=cfg.get("validation_rules", [])
             )
-            cls.register(template)
+            cls.register(
+                template,
+                version=version,
+                metadata=metadata,
+                activate=activate,
+            )

     @classmethod
-    def clear(cls) -> None:
-        """Clear all registered templates."""
+    def clear(cls, *, reload_defaults: bool = False) -> None:
+        """Clear all registered templates and optionally reload defaults."""
+
         cls._templates.clear()
+        cls._version_manager = None
+        if reload_defaults:
+            register_default_templates()


 # Default template definitions
@@ -234,7 +361,19 @@ def register_default_templates() -> None:
             "validation_rules": cfg.get("validation_rules", [])
         }
         try:
-            TemplateRegistry.register(PromptTemplate(**template_config))
+            template = PromptTemplate(**template_config)
+            version_label = str(
+                cfg.get("version") or TemplateRegistry._default_version_label
+            )
+            metadata_raw = cfg.get("metadata")
+            metadata = dict(metadata_raw) if isinstance(metadata_raw, dict) else {}
+            metadata.setdefault("source", "defaults")
+            TemplateRegistry.register(
+                template,
+                version=version_label,
+                metadata=metadata,
+                activate=cfg.get("activate", True),
+            )
         except ValueError as e:
             logging.warning(f"Failed to register template {template_id}: {e}")
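
Note (not part of the diff): the expected registry flow after this change, assuming PromptTemplate and TemplateRegistry are importable from app.llm.templates as the tests below suggest:

    from app.llm.templates import PromptTemplate, TemplateRegistry

    base = PromptTemplate(
        id="research_dept",
        name="Research",
        description="Research department prompt",
        template="Hello {name}!",
        variables=["name"],
    )
    TemplateRegistry.register(base)  # first registration is versioned as "1.0.0" and activated
    TemplateRegistry.register_version("research_dept", version="2024.10", template=base, activate=True)

    print(TemplateRegistry.get_active_version("research_dept"))  # 2024.10
    print(TemplateRegistry.get("research_dept", version="1.0.0") is not None)  # True
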
@@ -16,6 +16,13 @@ def _default_root() -> Path:
     return Path(__file__).resolve().parents[2] / "app" / "data"


+def _safe_float(value: object, fallback: float) -> float:
+    try:
+        return float(value)
+    except (TypeError, ValueError):
+        return fallback
+
+
 @dataclass
 class DataPaths:
     """Holds filesystem locations for persistent artifacts."""
@@ -198,8 +205,61 @@ class LLMEndpoint:
         self.provider = (self.provider or "ollama").lower()
         if self.temperature is not None:
             self.temperature = float(self.temperature)
-        if self.timeout is not None:
-            self.timeout = float(self.timeout)
+
+
+@dataclass
+class LLMCostSettings:
+    """Configurable budgets and weights for LLM cost control."""
+
+    enabled: bool = False
+    hourly_budget: float = 5.0
+    daily_budget: float = 50.0
+    monthly_budget: float = 500.0
+    model_weights: Dict[str, float] = field(default_factory=dict)
+
+    def update_from_dict(self, data: Mapping[str, object]) -> None:
+        if "enabled" in data:
+            self.enabled = bool(data.get("enabled"))
+        if "hourly_budget" in data:
+            self.hourly_budget = _safe_float(data.get("hourly_budget"), self.hourly_budget)
+        if "daily_budget" in data:
+            self.daily_budget = _safe_float(data.get("daily_budget"), self.daily_budget)
+        if "monthly_budget" in data:
+            self.monthly_budget = _safe_float(data.get("monthly_budget"), self.monthly_budget)
+        weights = data.get("model_weights") if isinstance(data, Mapping) else None
+        if isinstance(weights, Mapping):
+            normalized: Dict[str, float] = {}
+            for key, value in weights.items():
+                try:
+                    normalized[str(key)] = float(value)
+                except (TypeError, ValueError):
+                    continue
+            if normalized:
+                self.model_weights = normalized
+
+    def to_cost_limits(self):
+        """Convert into runtime `CostLimits` descriptor."""
+
+        from app.llm.cost import CostLimits  # Imported lazily to avoid cycles
+
+        weights: Dict[str, float] = {}
+        for key, value in (self.model_weights or {}).items():
+            try:
+                weights[str(key)] = float(value)
+            except (TypeError, ValueError):
+                continue
+        return CostLimits(
+            hourly_budget=float(self.hourly_budget),
+            daily_budget=float(self.daily_budget),
+            monthly_budget=float(self.monthly_budget),
+            model_weights=weights,
+        )
+
+    @classmethod
+    def from_dict(cls, data: Mapping[str, object]) -> "LLMCostSettings":
+        inst = cls()
+        inst.update_from_dict(data)
+        return inst


 @dataclass
@@ -241,6 +301,8 @@ class DepartmentSettings:
     data_scope: List[str] = field(default_factory=list)
     prompt: str = ""
     llm: LLMConfig = field(default_factory=LLMConfig)
+    prompt_template_id: Optional[str] = None
+    prompt_template_version: Optional[str] = None


 def _default_departments() -> Dict[str, DepartmentSettings]:
@@ -327,6 +389,7 @@ def _default_departments() -> Dict[str, DepartmentSettings]:
             description=item.get("description", ""),
             data_scope=list(item.get("data_scope", [])),
             prompt=item.get("prompt", ""),
+            prompt_template_id=f"{item['code']}_dept",
         )
         for item in presets
     }
@@ -355,6 +418,7 @@ class AppConfig:
     data_update_interval: int = 7  # 数据更新间隔(天)
     llm_providers: Dict[str, LLMProvider] = field(default_factory=_default_llm_providers)
     llm: LLMConfig = field(default_factory=LLMConfig)
+    llm_cost: LLMCostSettings = field(default_factory=LLMCostSettings)
     departments: Dict[str, DepartmentSettings] = field(default_factory=_default_departments)
     portfolio: PortfolioSettings = field(default_factory=PortfolioSettings)

@@ -463,6 +527,10 @@ def _load_from_file(cfg: AppConfig) -> None:
         )
         cfg.portfolio = updated_portfolio

+    cost_payload = payload.get("llm_cost")
+    if isinstance(cost_payload, dict):
+        cfg.llm_cost.update_from_dict(cost_payload)
+
     legacy_profiles: Dict[str, Dict[str, object]] = {}
     legacy_routes: Dict[str, Dict[str, object]] = {}

@@ -574,6 +642,7 @@ def _load_from_file(cfg: AppConfig) -> None:
     for code, data in departments_payload.items():
         if not isinstance(data, dict):
             continue
+        current_setting = cfg.departments.get(code)
         title = data.get("title") or code
         description = data.get("description") or ""
         weight = float(data.get("weight", 1.0))
@@ -606,6 +675,33 @@ def _load_from_file(cfg: AppConfig) -> None:
             if isinstance(majority_raw, int) and majority_raw > 0:
                 llm_cfg.majority_threshold = majority_raw
             resolved_cfg = llm_cfg
+        template_id_raw = data.get("prompt_template_id")
+        if isinstance(template_id_raw, str):
+            template_id_candidate = template_id_raw.strip()
+        elif template_id_raw is not None:
+            template_id_candidate = str(template_id_raw).strip()
+        else:
+            template_id_candidate = ""
+        if template_id_candidate:
+            template_id = template_id_candidate
+        elif current_setting and current_setting.prompt_template_id:
+            template_id = current_setting.prompt_template_id
+        else:
+            template_id = f"{code}_dept"
+
+        template_version_raw = data.get("prompt_template_version")
+        if isinstance(template_version_raw, str):
+            template_version_candidate = template_version_raw.strip()
+        elif template_version_raw is not None:
+            template_version_candidate = str(template_version_raw).strip()
+        else:
+            template_version_candidate = ""
+        if template_version_candidate:
+            template_version = template_version_candidate
+        elif current_setting:
+            template_version = current_setting.prompt_template_version
+        else:
+            template_version = None
         new_departments[code] = DepartmentSettings(
             code=code,
             title=title,
@@ -614,6 +710,8 @@ def _load_from_file(cfg: AppConfig) -> None:
             data_scope=data_scope,
             prompt=prompt_text,
             llm=resolved_cfg,
+            prompt_template_id=template_id,
+            prompt_template_version=template_version,
         )
     if new_departments:
         cfg.departments = new_departments
@@ -648,6 +746,13 @@ def save_config(cfg: AppConfig | None = None) -> None:
             "primary": _endpoint_to_dict(cfg.llm.primary),
             "ensemble": [_endpoint_to_dict(ep) for ep in cfg.llm.ensemble],
         },
+        "llm_cost": {
+            "enabled": cfg.llm_cost.enabled,
+            "hourly_budget": cfg.llm_cost.hourly_budget,
+            "daily_budget": cfg.llm_cost.daily_budget,
+            "monthly_budget": cfg.llm_cost.monthly_budget,
+            "model_weights": cfg.llm_cost.model_weights,
+        },
         "llm_providers": {
             key: provider.to_dict()
             for key, provider in cfg.llm_providers.items()
@@ -659,6 +764,8 @@ def save_config(cfg: AppConfig | None = None) -> None:
             "weight": dept.weight,
             "data_scope": list(dept.data_scope),
             "prompt": dept.prompt,
+            "prompt_template_id": dept.prompt_template_id,
+            "prompt_template_version": dept.prompt_template_version,
             "llm": {
                 "strategy": dept.llm.strategy if dept.llm.strategy in ALLOWED_LLM_STRATEGIES else "single",
                 "majority_threshold": dept.llm.majority_threshold,
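
Note (not part of the diff): an illustrative config payload accepted by _load_from_file and emitted by save_config after this change; the llm_cost keys come from the hunks above, while the department code and the exact nesting of the departments section are assumptions:

    payload = {
        "llm_cost": {
            "enabled": True,
            "hourly_budget": 5.0,
            "daily_budget": 50.0,
            "monthly_budget": 500.0,
            "model_weights": {"gpt-4": 1.0, "gpt-3.5-turbo": 0.5},
        },
        "departments": {
            "research": {  # hypothetical department code
                "prompt_template_id": "research_dept",
                "prompt_template_version": "2024.10",
            }
        },
    }
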
@@ -1,8 +1,14 @@
 """Test cases for LLM cost control system."""
-import pytest
-from datetime import datetime
+import time

-from app.llm.cost import CostLimits, ModelCosts, CostController
+from app.llm.cost import (
+    CostLimits,
+    ModelCosts,
+    CostController,
+    configure_cost_limits,
+    get_cost_controller,
+    budget_available,
+)


 def test_cost_limits():
@@ -53,11 +59,11 @@ def test_cost_controller():
     assert controller.can_use_model("gpt-4", 1000, 500)

     # Then record the usage
-    controller.record_usage("gpt-4", 1000, 500)  # About $0.09
+    assert controller.record_usage("gpt-4", 1000, 500) is True  # About $0.09

     # Record usage for second model to maintain weight balance
     assert controller.can_use_model("gpt-3.5-turbo", 1000, 500)
-    controller.record_usage("gpt-3.5-turbo", 1000, 500)
+    assert controller.record_usage("gpt-3.5-turbo", 1000, 500)

     # Verify usage tracking
     costs = controller.get_current_costs()
@@ -85,10 +91,10 @@ def test_cost_controller_history():

     # Record one usage of each model
     assert controller.can_use_model("gpt-4", 1000, 500)
-    controller.record_usage("gpt-4", 1000, 500)
+    assert controller.record_usage("gpt-4", 1000, 500)

     assert controller.can_use_model("gpt-3.5-turbo", 1000, 500)
-    controller.record_usage("gpt-3.5-turbo", 1000, 500)
+    assert controller.record_usage("gpt-3.5-turbo", 1000, 500)

     # Check usage tracking
     costs = controller.get_current_costs()
@@ -98,3 +104,23 @@ def test_cost_controller_history():
     distribution = controller.get_model_distribution()
     assert abs(distribution["gpt-4"] - 0.5) < 0.1
     assert abs(distribution["gpt-3.5-turbo"] - 0.5) < 0.1
+
+
+def test_global_controller_budget_toggle():
+    """Ensure global controller respects configured limits and budget flag."""
+    limits = CostLimits(hourly_budget=0.05, daily_budget=0.1, monthly_budget=1.0, model_weights={})
+    configure_cost_limits(limits)
+    controller = get_cost_controller()
+
+    assert budget_available() is True
+    within = controller.record_usage("gpt-4", 1000, 1000)
+    assert within is False  # deliberately exceed tiny budget
+    assert budget_available() is False
+
+    # Reset controller state to avoid cross-test contamination
+    with controller._usage_lock:  # type: ignore[attr-defined]
+        for bucket in controller._usage.values():  # type: ignore[attr-defined]
+            bucket.clear()
+        controller._last_cleanup = time.time()  # type: ignore[attr-defined]
+
+    configure_cost_limits(CostLimits.default())
@@ -69,7 +69,7 @@ def test_prompt_template_format():

     # Valid context
     result = template.format({"name": "World"})
-    assert result == "Hello Wor..."
+    assert result == "Hello W..."

     # Missing required context
     with pytest.raises(ValueError) as exc:
@@ -77,8 +77,15 @@ def test_prompt_template_format():
     assert "Missing required context" in str(exc.value)

     # Missing variable
+    template_no_required = PromptTemplate(
+        id="test2",
+        name="Test Template",
+        description="A test template",
+        template="Hello {name}!",
+        variables=["name"],
+    )
     with pytest.raises(ValueError) as exc:
-        template.format({"wrong": "value"})
+        template_no_required.format({"wrong": "value"})
     assert "Missing template variable" in str(exc.value)


@@ -95,7 +102,8 @@ def test_template_registry():
         variables=["name"]
     )
     TemplateRegistry.register(template)
-    assert TemplateRegistry.get("test") == template
+    assert TemplateRegistry.get("test") is not None
+    assert TemplateRegistry.get_active_version("test") == "1.0.0"

     # Register invalid template
     invalid = PromptTemplate(
@@ -121,12 +129,17 @@ def test_template_registry():
             "name": "JSON Test",
             "description": "Test template from JSON",
             "template": "Hello {name}!",
-            "variables": ["name"]
+            "variables": ["name"],
+            "version": "2024.10",
+            "metadata": {"author": "qa"},
+            "activate": true
         }
     }
     '''
     TemplateRegistry.load_from_json(json_str)
-    assert TemplateRegistry.get("json_test") is not None
+    loaded = TemplateRegistry.get("json_test")
+    assert loaded is not None
+    assert TemplateRegistry.get_active_version("json_test") == "2024.10"

     # Invalid JSON
     with pytest.raises(ValueError) as exc:
@@ -141,7 +154,7 @@ def test_template_registry():

 def test_default_templates():
     """Test default template registration."""
-    TemplateRegistry.clear()
+    TemplateRegistry.clear(reload_defaults=True)
     from app.llm.templates import DEFAULT_TEMPLATES

     # Verify default templates are loaded