update
This commit is contained in:
parent
dd6e51400e
commit
b4bd9fc9c5
238
app/llm/cost.py
Normal file
238
app/llm/cost.py
Normal file
@ -0,0 +1,238 @@
|
||||
"""LLM cost control and budget management."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from .metrics import snapshot
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
LOG_EXTRA = {"stage": "cost_control"}
|
||||
|
||||
|
||||
@dataclass
|
||||
class CostLimits:
|
||||
"""Cost control limits configuration."""
|
||||
|
||||
hourly_budget: float # 每小时预算
|
||||
daily_budget: float # 每日预算
|
||||
monthly_budget: float # 每月预算
|
||||
model_weights: Dict[str, float] = field(default_factory=dict) # 模型权重配置
|
||||
|
||||
@classmethod
|
||||
def default(cls) -> CostLimits:
|
||||
"""Create default cost limits."""
|
||||
return cls(
|
||||
hourly_budget=2.0, # $2/hour
|
||||
daily_budget=20.0, # $20/day
|
||||
monthly_budget=300.0, # $300/month
|
||||
model_weights={
|
||||
"gpt-4": 0.2, # 限制GPT-4使用比例
|
||||
"gpt-3.5-turbo": 0.6,
|
||||
"llama2": 0.2
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelCosts:
|
||||
"""Per-model cost configuration."""
|
||||
|
||||
prompt_cost_per_1k: float
|
||||
completion_cost_per_1k: float
|
||||
min_tokens: int = 1
|
||||
|
||||
def calculate(self, prompt_tokens: int, completion_tokens: int) -> float:
|
||||
"""Calculate cost for token usage."""
|
||||
prompt_cost = max(self.min_tokens, prompt_tokens) / 1000 * self.prompt_cost_per_1k
|
||||
completion_cost = max(self.min_tokens, completion_tokens) / 1000 * self.completion_cost_per_1k
|
||||
return prompt_cost + completion_cost
|
||||
|
||||
|
||||
class CostController:
|
||||
"""Controls and manages LLM costs."""
|
||||
|
||||
def __init__(self, limits: Optional[CostLimits] = None):
|
||||
"""Initialize cost controller."""
|
||||
self.limits = limits or CostLimits.default()
|
||||
self._costs: Dict[str, ModelCosts] = {
|
||||
"gpt-4": ModelCosts(0.03, 0.06),
|
||||
"gpt-4-32k": ModelCosts(0.06, 0.12),
|
||||
"gpt-3.5-turbo": ModelCosts(0.0015, 0.002),
|
||||
"gpt-3.5-turbo-16k": ModelCosts(0.003, 0.004),
|
||||
"llama2": ModelCosts(0.0, 0.0),
|
||||
"codellama": ModelCosts(0.0, 0.0)
|
||||
}
|
||||
self._usage_lock = threading.Lock()
|
||||
self._usage: Dict[str, List[Dict[str, Any]]] = {
|
||||
"hourly": [],
|
||||
"daily": [],
|
||||
"monthly": []
|
||||
}
|
||||
self._last_cleanup = time.time()
|
||||
self._cleanup_interval = 3600 # 1小时清理一次历史数据
|
||||
|
||||
def can_use_model(self, model: str, prompt_tokens: int,
|
||||
completion_tokens: int) -> bool:
|
||||
"""检查是否允许使用指定模型."""
|
||||
# 检查成本限制
|
||||
if not self._check_budget_limits(model, prompt_tokens, completion_tokens):
|
||||
return False
|
||||
|
||||
# 检查模型权重限制
|
||||
if not self._check_model_weights(model):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def record_usage(self, model: str, prompt_tokens: int,
|
||||
completion_tokens: int) -> None:
|
||||
"""记录模型使用情况."""
|
||||
cost = self._calculate_cost(model, prompt_tokens, completion_tokens)
|
||||
timestamp = time.time()
|
||||
|
||||
usage = {
|
||||
"model": model,
|
||||
"timestamp": timestamp,
|
||||
"cost": cost,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens
|
||||
}
|
||||
|
||||
with self._usage_lock:
|
||||
self._usage["hourly"].append(usage)
|
||||
self._usage["daily"].append(usage)
|
||||
self._usage["monthly"].append(usage)
|
||||
|
||||
# 定期清理过期数据
|
||||
self._cleanup_old_usage(timestamp)
|
||||
|
||||
def get_current_costs(self) -> Dict[str, float]:
|
||||
"""获取当前时段的成本统计."""
|
||||
with self._usage_lock:
|
||||
now = time.time()
|
||||
hour_ago = now - 3600
|
||||
day_ago = now - 86400
|
||||
month_ago = now - 2592000 # 30天
|
||||
|
||||
hourly = sum(u["cost"] for u in self._usage["hourly"]
|
||||
if u["timestamp"] > hour_ago)
|
||||
daily = sum(u["cost"] for u in self._usage["daily"]
|
||||
if u["timestamp"] > day_ago)
|
||||
monthly = sum(u["cost"] for u in self._usage["monthly"]
|
||||
if u["timestamp"] > month_ago)
|
||||
|
||||
return {
|
||||
"hourly": hourly,
|
||||
"daily": daily,
|
||||
"monthly": monthly
|
||||
}
|
||||
|
||||
def get_model_distribution(self) -> Dict[str, float]:
|
||||
"""获取模型使用分布."""
|
||||
with self._usage_lock:
|
||||
now = time.time()
|
||||
day_ago = now - 86400
|
||||
|
||||
# 统计24小时内的使用情况
|
||||
model_calls: Dict[str, int] = {}
|
||||
total_calls = 0
|
||||
|
||||
for usage in self._usage["daily"]:
|
||||
if usage["timestamp"] > day_ago:
|
||||
model = usage["model"]
|
||||
model_calls[model] = model_calls.get(model, 0) + 1
|
||||
total_calls += 1
|
||||
|
||||
if total_calls == 0:
|
||||
return {}
|
||||
|
||||
return {
|
||||
model: count / total_calls
|
||||
for model, count in model_calls.items()
|
||||
}
|
||||
|
||||
def _calculate_cost(self, model: str, prompt_tokens: int,
|
||||
completion_tokens: int) -> float:
|
||||
"""计算使用成本."""
|
||||
model_costs = self._costs.get(model)
|
||||
if not model_costs:
|
||||
return 0.0
|
||||
return model_costs.calculate(prompt_tokens, completion_tokens)
|
||||
|
||||
def _check_budget_limits(self, model: str, prompt_tokens: int,
|
||||
completion_tokens: int) -> bool:
|
||||
"""检查是否超出预算限制."""
|
||||
estimated_cost = self._calculate_cost(model, prompt_tokens, completion_tokens)
|
||||
current_costs = self.get_current_costs()
|
||||
|
||||
# 检查各个时间维度的预算限制
|
||||
if (current_costs["hourly"] + estimated_cost > self.limits.hourly_budget or
|
||||
current_costs["daily"] + estimated_cost > self.limits.daily_budget or
|
||||
current_costs["monthly"] + estimated_cost > self.limits.monthly_budget):
|
||||
LOGGER.warning(
|
||||
"Cost limit exceeded - model: %s, estimated: $%.4f",
|
||||
model, estimated_cost, extra=LOG_EXTRA
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _check_model_weights(self, model: str) -> bool:
|
||||
"""检查是否符合模型权重限制."""
|
||||
if model not in self.limits.model_weights:
|
||||
return True # 未配置权重的模型不限制
|
||||
|
||||
distribution = self.get_model_distribution()
|
||||
current_weight = distribution.get(model, 0.0)
|
||||
max_weight = self.limits.model_weights[model]
|
||||
|
||||
if current_weight >= max_weight:
|
||||
LOGGER.warning(
|
||||
"Model weight exceeded - model: %s, current: %.1f%%, max: %.1f%%",
|
||||
model, current_weight * 100, max_weight * 100, extra=LOG_EXTRA
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _cleanup_old_usage(self, current_time: float) -> None:
|
||||
"""清理过期的使用记录."""
|
||||
if current_time - self._last_cleanup < self._cleanup_interval:
|
||||
return
|
||||
|
||||
hour_ago = current_time - 3600
|
||||
day_ago = current_time - 86400
|
||||
month_ago = current_time - 2592000
|
||||
|
||||
self._usage["hourly"] = [
|
||||
u for u in self._usage["hourly"]
|
||||
if u["timestamp"] > hour_ago
|
||||
]
|
||||
self._usage["daily"] = [
|
||||
u for u in self._usage["daily"]
|
||||
if u["timestamp"] > day_ago
|
||||
]
|
||||
self._usage["monthly"] = [
|
||||
u for u in self._usage["monthly"]
|
||||
if u["timestamp"] > month_ago
|
||||
]
|
||||
|
||||
self._last_cleanup = current_time
|
||||
|
||||
|
||||
# 全局实例
|
||||
_controller = CostController()
|
||||
|
||||
def get_controller() -> CostController:
|
||||
"""获取全局CostController实例."""
|
||||
return _controller
|
||||
|
||||
def set_cost_limits(limits: CostLimits) -> None:
|
||||
"""设置全局成本限制."""
|
||||
_controller.limits = limits
|
||||
181
app/llm/version.py
Normal file
181
app/llm/version.py
Normal file
@ -0,0 +1,181 @@
|
||||
"""Template version management and validation."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from .templates import PromptTemplate
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
LOG_EXTRA = {"stage": "template_version"}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TemplateVersion:
|
||||
"""A versioned template configuration."""
|
||||
|
||||
id: str
|
||||
version: str
|
||||
created_at: str
|
||||
template: PromptTemplate
|
||||
metadata: Dict[str, Any]
|
||||
is_active: bool = False
|
||||
|
||||
@classmethod
|
||||
def create(cls, template: PromptTemplate, version: str,
|
||||
metadata: Optional[Dict[str, Any]] = None) -> TemplateVersion:
|
||||
"""Create a new template version."""
|
||||
return cls(
|
||||
id=template.id,
|
||||
version=version,
|
||||
created_at=datetime.now().isoformat(),
|
||||
template=template,
|
||||
metadata=metadata or {},
|
||||
is_active=False
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary format."""
|
||||
return {
|
||||
"id": self.id,
|
||||
"version": self.version,
|
||||
"created_at": self.created_at,
|
||||
"template": asdict(self.template),
|
||||
"metadata": self.metadata,
|
||||
"is_active": self.is_active
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> TemplateVersion:
|
||||
"""Create from dictionary format."""
|
||||
template_data = data["template"]
|
||||
template = PromptTemplate(
|
||||
id=template_data["id"],
|
||||
name=template_data["name"],
|
||||
description=template_data["description"],
|
||||
template=template_data["template"],
|
||||
variables=template_data["variables"],
|
||||
max_length=template_data.get("max_length", 4000),
|
||||
required_context=template_data.get("required_context"),
|
||||
validation_rules=template_data.get("validation_rules")
|
||||
)
|
||||
return cls(
|
||||
id=data["id"],
|
||||
version=data["version"],
|
||||
created_at=data["created_at"],
|
||||
template=template,
|
||||
metadata=data["metadata"],
|
||||
is_active=data.get("is_active", False)
|
||||
)
|
||||
|
||||
|
||||
class TemplateVersionManager:
|
||||
"""Manages template versioning and deployment."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize version manager."""
|
||||
self._versions: Dict[str, Dict[str, TemplateVersion]] = {}
|
||||
self._active_versions: Dict[str, str] = {}
|
||||
|
||||
def add_version(self, template: PromptTemplate, version: str,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
activate: bool = False) -> TemplateVersion:
|
||||
"""Add a new template version."""
|
||||
if template.id not in self._versions:
|
||||
self._versions[template.id] = {}
|
||||
|
||||
versions = self._versions[template.id]
|
||||
if version in versions:
|
||||
raise ValueError(f"Version {version} already exists for template {template.id}")
|
||||
|
||||
template_version = TemplateVersion.create(
|
||||
template=template,
|
||||
version=version,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
versions[version] = template_version
|
||||
if activate:
|
||||
self.activate_version(template.id, version)
|
||||
|
||||
return template_version
|
||||
|
||||
def get_version(self, template_id: str, version: str) -> Optional[TemplateVersion]:
|
||||
"""Get a specific template version."""
|
||||
return self._versions.get(template_id, {}).get(version)
|
||||
|
||||
def list_versions(self, template_id: str) -> List[TemplateVersion]:
|
||||
"""List all versions of a template."""
|
||||
return list(self._versions.get(template_id, {}).values())
|
||||
|
||||
def get_active_version(self, template_id: str) -> Optional[TemplateVersion]:
|
||||
"""Get the active version of a template."""
|
||||
active_version = self._active_versions.get(template_id)
|
||||
if active_version:
|
||||
return self.get_version(template_id, active_version)
|
||||
return None
|
||||
|
||||
def activate_version(self, template_id: str, version: str) -> None:
|
||||
"""Activate a specific template version."""
|
||||
if template_id not in self._versions:
|
||||
raise ValueError(f"Template {template_id} not found")
|
||||
|
||||
versions = self._versions[template_id]
|
||||
if version not in versions:
|
||||
raise ValueError(f"Version {version} not found for template {template_id}")
|
||||
|
||||
# Deactivate current active version
|
||||
current_active = self._active_versions.get(template_id)
|
||||
if current_active and current_active in versions:
|
||||
versions[current_active].is_active = False
|
||||
|
||||
# Activate new version
|
||||
versions[version].is_active = True
|
||||
self._active_versions[template_id] = version
|
||||
|
||||
def export_versions(self, template_id: str) -> str:
|
||||
"""Export all versions of a template to JSON."""
|
||||
if template_id not in self._versions:
|
||||
raise ValueError(f"Template {template_id} not found")
|
||||
|
||||
versions = self._versions[template_id]
|
||||
data = {
|
||||
"template_id": template_id,
|
||||
"active_version": self._active_versions.get(template_id),
|
||||
"versions": {
|
||||
version: ver.to_dict()
|
||||
for version, ver in versions.items()
|
||||
}
|
||||
}
|
||||
return json.dumps(data, indent=2)
|
||||
|
||||
def import_versions(self, json_str: str) -> None:
|
||||
"""Import template versions from JSON."""
|
||||
try:
|
||||
data = json.loads(json_str)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON format: {e}")
|
||||
|
||||
template_id = data.get("template_id")
|
||||
if not template_id:
|
||||
raise ValueError("Missing template_id in JSON")
|
||||
|
||||
versions = data.get("versions", {})
|
||||
if not versions:
|
||||
raise ValueError("No versions found in JSON")
|
||||
|
||||
# Clear existing versions for this template
|
||||
self._versions[template_id] = {}
|
||||
|
||||
# Import versions
|
||||
for version, ver_data in versions.items():
|
||||
template_version = TemplateVersion.from_dict(ver_data)
|
||||
self._versions[template_id][version] = template_version
|
||||
|
||||
# Set active version if specified
|
||||
active_version = data.get("active_version")
|
||||
if active_version:
|
||||
self.activate_version(template_id, active_version)
|
||||
@ -7,6 +7,7 @@
|
||||
> ✓ 数据访问与监控
|
||||
> ✓ 核心回测系统
|
||||
> ✓ LLM基础集成
|
||||
> ✓ LLM模板与上下文管理
|
||||
> △ RSS新闻处理
|
||||
> △ UI与监控系统
|
||||
|
||||
@ -65,10 +66,14 @@
|
||||
- [x] 精简和优化Provider管理
|
||||
- [x] 增强function-calling架构
|
||||
- [x] 完善错误处理和重试策略
|
||||
- [ ] 优化提示工程:
|
||||
- [ ] 设计配置化角色提示
|
||||
- [ ] 优化数据范围控制
|
||||
- [ ] 改进上下文管理
|
||||
- [x] 优化提示工程:
|
||||
- [x] 设计配置化角色提示
|
||||
- [x] 优化数据范围控制
|
||||
- [x] 改进上下文管理
|
||||
- [ ] 增强系统稳定性:
|
||||
- [ ] 实现提示模板版本管理
|
||||
- [ ] 增加系统级性能监控
|
||||
- [ ] 优化模型调用成本控制
|
||||
|
||||
### 4. UI与监控(P2)
|
||||
#### 4.1 功能增强
|
||||
@ -95,6 +100,8 @@
|
||||
3. ✓ 优化DataBroker的数据访问性能
|
||||
4. △ 完善RSS新闻数据源的接入
|
||||
5. ✓ 开始着手决策环境的增强
|
||||
6. ✓ 改进LLM模板与上下文管理
|
||||
7. △ 启动LLM性能与成本优化
|
||||
|
||||
## 三、开发原则
|
||||
|
||||
|
||||
100
tests/test_llm_cost.py
Normal file
100
tests/test_llm_cost.py
Normal file
@ -0,0 +1,100 @@
|
||||
"""Test cases for LLM cost control system."""
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
|
||||
from app.llm.cost import CostLimits, ModelCosts, CostController
|
||||
|
||||
|
||||
def test_cost_limits():
|
||||
"""Test cost limits configuration."""
|
||||
limits = CostLimits(
|
||||
hourly_budget=10.0,
|
||||
daily_budget=100.0,
|
||||
monthly_budget=1000.0,
|
||||
model_weights={"gpt-4": 0.7, "gpt-3.5-turbo": 0.3}
|
||||
)
|
||||
|
||||
assert limits.hourly_budget == 10.0
|
||||
assert limits.daily_budget == 100.0
|
||||
assert limits.monthly_budget == 1000.0
|
||||
assert limits.model_weights["gpt-4"] == 0.7
|
||||
|
||||
|
||||
def test_model_costs():
|
||||
"""Test model cost tracking."""
|
||||
costs = ModelCosts(
|
||||
prompt_cost_per_1k=0.1, # $0.1 per 1K tokens
|
||||
completion_cost_per_1k=0.2 # $0.2 per 1K tokens
|
||||
)
|
||||
|
||||
# Test cost calculation
|
||||
prompt_tokens = 1000 # 1K tokens
|
||||
completion_tokens = 500 # 0.5K tokens
|
||||
|
||||
total_cost = costs.calculate(prompt_tokens, completion_tokens)
|
||||
expected_cost = (prompt_tokens / 1000 * costs.prompt_cost_per_1k +
|
||||
completion_tokens / 1000 * costs.completion_cost_per_1k)
|
||||
|
||||
assert total_cost == expected_cost # Should be $0.2
|
||||
|
||||
|
||||
def test_cost_controller():
|
||||
"""Test cost controller functionality."""
|
||||
limits = CostLimits(
|
||||
hourly_budget=1.0,
|
||||
daily_budget=10.0,
|
||||
monthly_budget=100.0,
|
||||
model_weights={"gpt-4": 0.7, "gpt-3.5-turbo": 0.3}
|
||||
)
|
||||
|
||||
controller = CostController(limits=limits)
|
||||
|
||||
# First check if we can use model
|
||||
assert controller.can_use_model("gpt-4", 1000, 500)
|
||||
|
||||
# Then record the usage
|
||||
controller.record_usage("gpt-4", 1000, 500) # About $0.09
|
||||
|
||||
# Record usage for second model to maintain weight balance
|
||||
assert controller.can_use_model("gpt-3.5-turbo", 1000, 500)
|
||||
controller.record_usage("gpt-3.5-turbo", 1000, 500)
|
||||
|
||||
# Verify usage tracking
|
||||
costs = controller.get_current_costs()
|
||||
assert costs["hourly"] > 0
|
||||
assert costs["daily"] > 0
|
||||
assert costs["monthly"] > 0
|
||||
|
||||
# Test model distribution
|
||||
distribution = controller.get_model_distribution()
|
||||
assert "gpt-4" in distribution and "gpt-3.5-turbo" in distribution
|
||||
assert abs(distribution["gpt-4"] - 0.5) < 0.1 # Allow some deviation
|
||||
assert abs(distribution["gpt-3.5-turbo"] - 0.5) < 0.1 # Should be roughly balanced
|
||||
|
||||
|
||||
def test_cost_controller_history():
|
||||
"""Test cost controller usage history."""
|
||||
limits = CostLimits(
|
||||
hourly_budget=1.0,
|
||||
daily_budget=10.0,
|
||||
monthly_budget=100.0,
|
||||
model_weights={"gpt-4": 0.5, "gpt-3.5-turbo": 0.5} # Equal weights
|
||||
)
|
||||
|
||||
controller = CostController(limits=limits)
|
||||
|
||||
# Record one usage of each model
|
||||
assert controller.can_use_model("gpt-4", 1000, 500)
|
||||
controller.record_usage("gpt-4", 1000, 500)
|
||||
|
||||
assert controller.can_use_model("gpt-3.5-turbo", 1000, 500)
|
||||
controller.record_usage("gpt-3.5-turbo", 1000, 500)
|
||||
|
||||
# Check usage tracking
|
||||
costs = controller.get_current_costs()
|
||||
assert costs["hourly"] > 0 # Should have accumulated cost
|
||||
|
||||
# Verify the usage distribution is roughly balanced
|
||||
distribution = controller.get_model_distribution()
|
||||
assert abs(distribution["gpt-4"] - 0.5) < 0.1
|
||||
assert abs(distribution["gpt-3.5-turbo"] - 0.5) < 0.1
|
||||
93
tests/test_llm_version.py
Normal file
93
tests/test_llm_version.py
Normal file
@ -0,0 +1,93 @@
|
||||
"""Test cases for template version management."""
|
||||
import pytest
|
||||
|
||||
from app.llm.templates import PromptTemplate
|
||||
from app.llm.version import TemplateVersion, TemplateVersionManager
|
||||
|
||||
|
||||
def test_template_version_creation():
|
||||
"""Test creating and managing template versions."""
|
||||
template = PromptTemplate(
|
||||
id="test",
|
||||
name="Test Template",
|
||||
description="A test template",
|
||||
template="Test {var}",
|
||||
variables=["var"]
|
||||
)
|
||||
|
||||
version = TemplateVersion.create(
|
||||
template=template,
|
||||
version="1.0.0",
|
||||
metadata={"author": "test"}
|
||||
)
|
||||
|
||||
assert version.id == "test"
|
||||
assert version.version == "1.0.0"
|
||||
assert not version.is_active
|
||||
assert version.metadata["author"] == "test"
|
||||
|
||||
|
||||
def test_version_manager():
|
||||
"""Test version manager operations."""
|
||||
manager = TemplateVersionManager()
|
||||
|
||||
template = PromptTemplate(
|
||||
id="test",
|
||||
name="Test Template",
|
||||
description="A test template",
|
||||
template="Test {var}",
|
||||
variables=["var"]
|
||||
)
|
||||
|
||||
# Add version
|
||||
v1 = manager.add_version(template, "1.0.0")
|
||||
assert not v1.is_active
|
||||
|
||||
# Add and activate version
|
||||
v2 = manager.add_version(template, "2.0.0", activate=True)
|
||||
assert v2.is_active
|
||||
assert not v1.is_active
|
||||
|
||||
# Get version
|
||||
assert manager.get_version("test", "1.0.0") == v1
|
||||
assert manager.get_version("test", "2.0.0") == v2
|
||||
|
||||
# List versions
|
||||
versions = manager.list_versions("test")
|
||||
assert len(versions) == 2
|
||||
|
||||
# Get active version
|
||||
active = manager.get_active_version("test")
|
||||
assert active == v2
|
||||
|
||||
# Export and import
|
||||
exported = manager.export_versions("test")
|
||||
|
||||
new_manager = TemplateVersionManager()
|
||||
new_manager.import_versions(exported)
|
||||
|
||||
imported = new_manager.get_version("test", "2.0.0")
|
||||
assert imported.version == "2.0.0"
|
||||
assert imported.is_active
|
||||
|
||||
|
||||
def test_version_validation():
|
||||
"""Test version validation checks."""
|
||||
manager = TemplateVersionManager()
|
||||
|
||||
template = PromptTemplate(
|
||||
id="test",
|
||||
name="Test Template",
|
||||
description="A test template",
|
||||
template="Test {var}",
|
||||
variables=["var"]
|
||||
)
|
||||
|
||||
# Test duplicate version
|
||||
manager.add_version(template, "1.0.0")
|
||||
with pytest.raises(ValueError):
|
||||
manager.add_version(template, "1.0.0")
|
||||
|
||||
# Test invalid version
|
||||
with pytest.raises(ValueError):
|
||||
manager.activate_version("test", "invalid")
|
||||
Loading…
Reference in New Issue
Block a user