LOGGER = logging.getLogger(__name__)
LOG_EXTRA = {"stage": "cost_control"}

# Rolling-window lengths in seconds, keyed by the bucket names used in
# ``CostController._usage``.
_WINDOW_SECONDS = {
    "hourly": 3600,
    "daily": 86400,
    "monthly": 2592000,  # 30 days
}


@dataclass
class CostLimits:
    """Budget ceilings and per-model usage caps for LLM calls."""

    hourly_budget: float  # spend ceiling for the trailing hour
    daily_budget: float  # spend ceiling for the trailing 24 hours
    monthly_budget: float  # spend ceiling for the trailing 30 days
    # Maximum share of calls (0.0-1.0) a model may take over the trailing
    # 24 hours; models that are not listed are not capped.
    model_weights: Dict[str, float] = field(default_factory=dict)

    @classmethod
    def default(cls) -> "CostLimits":
        """Return the built-in limits: $2/hour, $20/day, $300/month."""
        return cls(
            hourly_budget=2.0,
            daily_budget=20.0,
            monthly_budget=300.0,
            model_weights={
                "gpt-4": 0.2,  # keep GPT-4 to a small share of calls
                "gpt-3.5-turbo": 0.6,
                "llama2": 0.2,
            },
        )


@dataclass
class ModelCosts:
    """Per-model pricing, expressed per 1K tokens."""

    prompt_cost_per_1k: float
    completion_cost_per_1k: float
    min_tokens: int = 1  # each side of a call is billed for at least this many tokens

    def calculate(self, prompt_tokens: int, completion_tokens: int) -> float:
        """Return the cost of one call.

        Both token counts are raised to ``min_tokens`` before pricing, so a
        call with zero tokens on one side is still billed for ``min_tokens``.
        """
        billed_prompt = max(self.min_tokens, prompt_tokens)
        billed_completion = max(self.min_tokens, completion_tokens)
        prompt_cost = billed_prompt / 1000 * self.prompt_cost_per_1k
        completion_cost = billed_completion / 1000 * self.completion_cost_per_1k
        return prompt_cost + completion_cost


class CostController:
    """Tracks LLM spend over rolling windows and gates calls against limits."""

    def __init__(self, limits: Optional[CostLimits] = None):
        """Create a controller.

        Args:
            limits: Budgets and model caps; ``CostLimits.default()`` when omitted.
        """
        self.limits = limits or CostLimits.default()
        # Price table: ModelCosts(prompt, completion) per 1K tokens.
        self._costs: Dict[str, ModelCosts] = {
            "gpt-4": ModelCosts(0.03, 0.06),
            "gpt-4-32k": ModelCosts(0.06, 0.12),
            "gpt-3.5-turbo": ModelCosts(0.0015, 0.002),
            "gpt-3.5-turbo-16k": ModelCosts(0.003, 0.004),
            "llama2": ModelCosts(0.0, 0.0),
            "codellama": ModelCosts(0.0, 0.0),
        }
        self._usage_lock = threading.Lock()
        # Every usage record is appended to all three buckets; each bucket is
        # pruned to its own window by _cleanup_old_usage.
        self._usage: Dict[str, List[Dict[str, Any]]] = {
            name: [] for name in _WINDOW_SECONDS
        }
        self._last_cleanup = time.time()
        self._cleanup_interval = 3600  # prune stale records at most once per hour

    def can_use_model(self, model: str, prompt_tokens: int,
                      completion_tokens: int) -> bool:
        """Return True when the call fits every budget and the model's usage cap."""
        if not self._check_budget_limits(model, prompt_tokens, completion_tokens):
            return False
        return self._check_model_weights(model)

    def record_usage(self, model: str, prompt_tokens: int,
                     completion_tokens: int) -> None:
        """Record one completed call and its computed cost."""
        timestamp = time.time()
        usage = {
            "model": model,
            "timestamp": timestamp,
            "cost": self._calculate_cost(model, prompt_tokens, completion_tokens),
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
        }

        with self._usage_lock:
            for name in _WINDOW_SECONDS:
                self._usage[name].append(usage)
            # Pruning rebinds the bucket lists, so it must run under the lock.
            self._cleanup_old_usage(timestamp)

    def get_current_costs(self) -> Dict[str, float]:
        """Return total spend in the trailing hour, day and 30 days.

        Returns:
            A dict with keys ``"hourly"``, ``"daily"`` and ``"monthly"``.
        """
        with self._usage_lock:
            now = time.time()
            return {
                name: sum(
                    record["cost"]
                    for record in self._usage[name]
                    if record["timestamp"] > now - window
                )
                for name, window in _WINDOW_SECONDS.items()
            }

    def get_model_distribution(self) -> Dict[str, float]:
        """Return each model's share of calls over the trailing 24 hours.

        Returns:
            Mapping of model name to a fraction in (0, 1]; empty when no call
            was recorded in the window.
        """
        with self._usage_lock:
            day_ago = time.time() - _WINDOW_SECONDS["daily"]
            model_calls: Dict[str, int] = {}
            for record in self._usage["daily"]:
                if record["timestamp"] > day_ago:
                    name = record["model"]
                    model_calls[name] = model_calls.get(name, 0) + 1

        total_calls = sum(model_calls.values())
        if total_calls == 0:
            return {}
        return {name: count / total_calls for name, count in model_calls.items()}

    def _calculate_cost(self, model: str, prompt_tokens: int,
                        completion_tokens: int) -> float:
        """Return the cost of a call; models without a price entry cost 0.0."""
        model_costs = self._costs.get(model)
        if model_costs is None:
            # Unpriced models are treated as free, which also means they never
            # count against a budget. Leave a trace so this is discoverable.
            LOGGER.debug(
                "No pricing configured for model %s; treating cost as 0",
                model, extra=LOG_EXTRA
            )
            return 0.0
        return model_costs.calculate(prompt_tokens, completion_tokens)

    def _check_budget_limits(self, model: str, prompt_tokens: int,
                             completion_tokens: int) -> bool:
        """Return False when the estimated cost would push any window over budget."""
        estimated_cost = self._calculate_cost(model, prompt_tokens, completion_tokens)
        current_costs = self.get_current_costs()
        budgets = {
            "hourly": self.limits.hourly_budget,
            "daily": self.limits.daily_budget,
            "monthly": self.limits.monthly_budget,
        }

        for name, budget in budgets.items():
            if current_costs[name] + estimated_cost > budget:
                LOGGER.warning(
                    "Cost limit exceeded - model: %s, estimated: $%.4f",
                    model, estimated_cost, extra=LOG_EXTRA
                )
                return False
        return True

    def _check_model_weights(self, model: str) -> bool:
        """Return False when the model's 24h call share is above its cap.

        Models without a configured weight are never restricted. A weight of
        zero (or less) always refuses the model.
        """
        if model not in self.limits.model_weights:
            return True

        max_weight = self.limits.model_weights[model]
        current_weight = self.get_model_distribution().get(model, 0.0)

        # Strict comparison: a model sitting exactly at its cap stays usable.
        # With ">=" a distribution that exactly matched the configured weights
        # refused every weighted model at once, and a weight of 1.0 refused
        # the model after its first call.
        if max_weight <= 0 or current_weight > max_weight:
            LOGGER.warning(
                "Model weight exceeded - model: %s, current: %.1f%%, max: %.1f%%",
                model, current_weight * 100, max_weight * 100, extra=LOG_EXTRA
            )
            return False
        return True

    def _cleanup_old_usage(self, current_time: float) -> None:
        """Drop records older than each bucket's window.

        Runs at most once per ``_cleanup_interval`` seconds. The caller must
        hold ``_usage_lock``.
        """
        if current_time - self._last_cleanup < self._cleanup_interval:
            return

        for name, window in _WINDOW_SECONDS.items():
            cutoff = current_time - window
            self._usage[name] = [
                record for record in self._usage[name]
                if record["timestamp"] > cutoff
            ]
        self._last_cleanup = current_time


# Process-wide controller instance.
_controller = CostController()


def get_controller() -> CostController:
    """Return the process-wide CostController instance."""
    return _controller


def set_cost_limits(limits: CostLimits) -> None:
    """Replace the limits used by the process-wide controller."""
    _controller.limits = limits
@dataclass
class TemplateVersion:
    """One versioned snapshot of a prompt template."""

    id: str  # copied from the template's id
    version: str
    created_at: str  # ISO-8601 timestamp string
    template: "PromptTemplate"
    metadata: Dict[str, Any]
    is_active: bool = False

    @classmethod
    def create(cls, template: "PromptTemplate", version: str,
               metadata: Optional[Dict[str, Any]] = None) -> "TemplateVersion":
        """Create a new, inactive version of ``template``.

        Args:
            template: Template to snapshot; its ``id`` becomes the version id.
            version: Version label.
            metadata: Optional free-form metadata; copied so later changes to
                the caller's dict do not leak into the stored version.
        """
        return cls(
            id=template.id,
            version=version,
            # NOTE(review): naive local time; confirm whether UTC is expected.
            created_at=datetime.now().isoformat(),
            template=template,
            metadata=dict(metadata) if metadata else {},
            is_active=False,
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict representation; the template goes through ``asdict``."""
        return {
            "id": self.id,
            "version": self.version,
            "created_at": self.created_at,
            "template": asdict(self.template),
            "metadata": self.metadata,
            "is_active": self.is_active,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "TemplateVersion":
        """Rebuild a version from the structure produced by ``to_dict``.

        Raises:
            KeyError: If a required key is missing from ``data`` or from its
                ``"template"`` entry.
        """
        template_data = data["template"]
        template = PromptTemplate(
            id=template_data["id"],
            name=template_data["name"],
            description=template_data["description"],
            template=template_data["template"],
            variables=template_data["variables"],
            max_length=template_data.get("max_length", 4000),
            required_context=template_data.get("required_context"),
            validation_rules=template_data.get("validation_rules"),
        )
        return cls(
            id=data["id"],
            version=data["version"],
            created_at=data["created_at"],
            template=template,
            # Metadata is optional on create(), so tolerate its absence here.
            metadata=data.get("metadata") or {},
            is_active=data.get("is_active", False),
        )


class TemplateVersionManager:
    """Keeps every version of each template and tracks which one is active."""

    def __init__(self):
        """Create an empty manager."""
        # template id -> version label -> TemplateVersion
        self._versions: Dict[str, Dict[str, TemplateVersion]] = {}
        # template id -> label of the active version
        self._active_versions: Dict[str, str] = {}

    def add_version(self, template: "PromptTemplate", version: str,
                    metadata: Optional[Dict[str, Any]] = None,
                    activate: bool = False) -> TemplateVersion:
        """Register a new version of ``template``.

        Args:
            template: Template to store.
            version: Version label, unique per template id.
            metadata: Optional free-form metadata.
            activate: When True the new version becomes the active one.

        Raises:
            ValueError: If ``version`` already exists for this template.
        """
        versions = self._versions.setdefault(template.id, {})
        if version in versions:
            raise ValueError(f"Version {version} already exists for template {template.id}")

        template_version = TemplateVersion.create(
            template=template,
            version=version,
            metadata=metadata,
        )
        versions[version] = template_version
        if activate:
            self.activate_version(template.id, version)
        return template_version

    def get_version(self, template_id: str, version: str) -> Optional[TemplateVersion]:
        """Return the requested version, or None when it is unknown."""
        return self._versions.get(template_id, {}).get(version)

    def list_versions(self, template_id: str) -> List[TemplateVersion]:
        """Return all versions of a template in insertion order."""
        return list(self._versions.get(template_id, {}).values())

    def get_active_version(self, template_id: str) -> Optional[TemplateVersion]:
        """Return the active version of a template, or None when none is active."""
        active_version = self._active_versions.get(template_id)
        if not active_version:
            return None
        return self.get_version(template_id, active_version)

    def activate_version(self, template_id: str, version: str) -> None:
        """Make ``version`` the active version of ``template_id``.

        Raises:
            ValueError: If the template or the version is unknown.
        """
        if template_id not in self._versions:
            raise ValueError(f"Template {template_id} not found")

        versions = self._versions[template_id]
        if version not in versions:
            raise ValueError(f"Version {version} not found for template {template_id}")

        # Clear the flag on the previously active version, if it still exists.
        current_active = self._active_versions.get(template_id)
        if current_active and current_active in versions:
            versions[current_active].is_active = False

        versions[version].is_active = True
        self._active_versions[template_id] = version

    def export_versions(self, template_id: str) -> str:
        """Serialise every version of a template to a JSON string.

        Raises:
            ValueError: If the template is unknown.
        """
        if template_id not in self._versions:
            raise ValueError(f"Template {template_id} not found")

        data = {
            "template_id": template_id,
            "active_version": self._active_versions.get(template_id),
            "versions": {
                label: stored.to_dict()
                for label, stored in self._versions[template_id].items()
            },
        }
        return json.dumps(data, indent=2)

    def import_versions(self, json_str: str) -> None:
        """Replace a template's versions with those from ``export_versions`` output.

        The payload is validated and parsed in full before any state changes,
        so a failed import leaves the existing versions and active pointer
        untouched.

        Raises:
            ValueError: If the JSON is invalid or not an object, if
                ``template_id`` or ``versions`` is missing or empty, or if
                ``active_version`` names a version absent from the payload.
            KeyError: If a version entry lacks a required key.
        """
        try:
            data = json.loads(json_str)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON format: {e}") from e

        if not isinstance(data, dict):
            raise ValueError("Invalid JSON format: expected a JSON object")

        template_id = data.get("template_id")
        if not template_id:
            raise ValueError("Missing template_id in JSON")

        versions = data.get("versions", {})
        if not versions:
            raise ValueError("No versions found in JSON")
        if not isinstance(versions, dict):
            raise ValueError("Invalid JSON format: versions must be an object")

        active_version = data.get("active_version")
        if active_version and active_version not in versions:
            raise ValueError(f"Version {active_version} not found for template {template_id}")

        # Parse into a scratch dict; an exception here must not touch state.
        imported = {
            label: TemplateVersion.from_dict(ver_data)
            for label, ver_data in versions.items()
        }
        # The payload's active_version is authoritative for the per-version flags.
        for label, stored in imported.items():
            stored.is_active = bool(active_version) and label == active_version

        self._versions[template_id] = imported
        if active_version:
            self._active_versions[template_id] = active_version
        else:
            # Drop a pointer left over from the versions that were replaced.
            self._active_versions.pop(template_id, None)
def test_model_costs():
    """ModelCosts.calculate prices prompt and completion tokens per 1K."""
    costs = ModelCosts(
        prompt_cost_per_1k=0.1,  # $0.1 per 1K tokens
        completion_cost_per_1k=0.2,  # $0.2 per 1K tokens
    )

    prompt_tokens = 1000  # 1K tokens
    completion_tokens = 500  # 0.5K tokens

    total_cost = costs.calculate(prompt_tokens, completion_tokens)
    expected_cost = (prompt_tokens / 1000 * costs.prompt_cost_per_1k +
                     completion_tokens / 1000 * costs.completion_cost_per_1k)

    # Compare with a tolerance rather than relying on identical float rounding.
    assert abs(total_cost - expected_cost) < 1e-9
    assert abs(total_cost - 0.2) < 1e-9  # $0.1 prompt + $0.1 completion


def test_cost_controller():
    """A controller admits calls, records them, and reports cost and distribution."""
    limits = CostLimits(
        hourly_budget=1.0,
        daily_budget=10.0,
        monthly_budget=100.0,
        model_weights={"gpt-4": 0.7, "gpt-3.5-turbo": 0.3},
    )

    controller = CostController(limits=limits)

    # Nothing recorded yet, so the first call is admitted.
    assert controller.can_use_model("gpt-4", 1000, 500)

    # 1000 prompt tokens at $0.03/1K plus 500 completion tokens at $0.06/1K = $0.06
    controller.record_usage("gpt-4", 1000, 500)

    # Record usage for the second model to maintain weight balance.
    assert controller.can_use_model("gpt-3.5-turbo", 1000, 500)
    controller.record_usage("gpt-3.5-turbo", 1000, 500)  # $0.0025

    # Both calls fall inside every rolling window.
    costs = controller.get_current_costs()
    assert costs["hourly"] > 0
    assert costs["daily"] > 0
    assert costs["monthly"] > 0
    assert abs(costs["hourly"] - 0.0625) < 1e-9

    # One call per model gives an even split.
    distribution = controller.get_model_distribution()
    assert "gpt-4" in distribution and "gpt-3.5-turbo" in distribution
    assert abs(distribution["gpt-4"] - 0.5) < 0.1
    assert abs(distribution["gpt-3.5-turbo"] - 0.5) < 0.1


def test_cost_controller_history():
    """Recorded usage accumulates cost and yields a balanced distribution."""
    limits = CostLimits(
        hourly_budget=1.0,
        daily_budget=10.0,
        monthly_budget=100.0,
        model_weights={"gpt-4": 0.5, "gpt-3.5-turbo": 0.5},  # equal weights
    )

    controller = CostController(limits=limits)

    # Record one usage of each model.
    assert controller.can_use_model("gpt-4", 1000, 500)
    controller.record_usage("gpt-4", 1000, 500)

    assert controller.can_use_model("gpt-3.5-turbo", 1000, 500)
    controller.record_usage("gpt-3.5-turbo", 1000, 500)

    costs = controller.get_current_costs()
    assert costs["hourly"] > 0  # cost has accumulated

    distribution = controller.get_model_distribution()
    assert abs(distribution["gpt-4"] - 0.5) < 0.1
    assert abs(distribution["gpt-3.5-turbo"] - 0.5) < 0.1


def test_template_version_creation():
    """TemplateVersion.create copies the template id and starts inactive."""
    template = PromptTemplate(
        id="test",
        name="Test Template",
        description="A test template",
        template="Test {var}",
        variables=["var"],
    )

    version = TemplateVersion.create(
        template=template,
        version="1.0.0",
        metadata={"author": "test"},
    )

    assert version.id == "test"
    assert version.version == "1.0.0"
    assert not version.is_active
    assert version.metadata["author"] == "test"


def test_version_manager():
    """Add, activate, look up, list, export and re-import template versions."""
    manager = TemplateVersionManager()

    template = PromptTemplate(
        id="test",
        name="Test Template",
        description="A test template",
        template="Test {var}",
        variables=["var"],
    )

    # A plain add leaves the version inactive.
    v1 = manager.add_version(template, "1.0.0")
    assert not v1.is_active

    # Adding with activate=True makes it the only active version.
    v2 = manager.add_version(template, "2.0.0", activate=True)
    assert v2.is_active
    assert not v1.is_active

    assert manager.get_version("test", "1.0.0") == v1
    assert manager.get_version("test", "2.0.0") == v2

    versions = manager.list_versions("test")
    assert len(versions) == 2

    active = manager.get_active_version("test")
    assert active == v2

    # Round-trip through JSON into a fresh manager.
    exported = manager.export_versions("test")

    new_manager = TemplateVersionManager()
    new_manager.import_versions(exported)

    imported = new_manager.get_version("test", "2.0.0")
    assert imported is not None
    assert imported.version == "2.0.0"
    assert imported.is_active


def test_version_validation():
    """Duplicate versions and unknown versions are rejected with ValueError."""
    manager = TemplateVersionManager()

    template = PromptTemplate(
        id="test",
        name="Test Template",
        description="A test template",
        template="Test {var}",
        variables=["var"],
    )

    # Adding the same version twice is an error.
    manager.add_version(template, "1.0.0")
    with pytest.raises(ValueError):
        manager.add_version(template, "1.0.0")

    # Activating a version that was never added is an error.
    with pytest.raises(ValueError):
        manager.activate_version("test", "invalid")