This commit is contained in:
sam 2025-10-05 17:24:10 +08:00
parent dd6e51400e
commit b4bd9fc9c5
5 changed files with 623 additions and 4 deletions

238
app/llm/cost.py Normal file
View File

@ -0,0 +1,238 @@
"""LLM cost control and budget management."""
from __future__ import annotations
import json
import logging
import threading
import time
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Set
from .metrics import snapshot
LOGGER = logging.getLogger(__name__)
LOG_EXTRA = {"stage": "cost_control"}
@dataclass
class CostLimits:
"""Cost control limits configuration."""
hourly_budget: float # 每小时预算
daily_budget: float # 每日预算
monthly_budget: float # 每月预算
model_weights: Dict[str, float] = field(default_factory=dict) # 模型权重配置
@classmethod
def default(cls) -> CostLimits:
"""Create default cost limits."""
return cls(
hourly_budget=2.0, # $2/hour
daily_budget=20.0, # $20/day
monthly_budget=300.0, # $300/month
model_weights={
"gpt-4": 0.2, # 限制GPT-4使用比例
"gpt-3.5-turbo": 0.6,
"llama2": 0.2
}
)
@dataclass
class ModelCosts:
"""Per-model cost configuration."""
prompt_cost_per_1k: float
completion_cost_per_1k: float
min_tokens: int = 1
def calculate(self, prompt_tokens: int, completion_tokens: int) -> float:
"""Calculate cost for token usage."""
prompt_cost = max(self.min_tokens, prompt_tokens) / 1000 * self.prompt_cost_per_1k
completion_cost = max(self.min_tokens, completion_tokens) / 1000 * self.completion_cost_per_1k
return prompt_cost + completion_cost
class CostController:
"""Controls and manages LLM costs."""
def __init__(self, limits: Optional[CostLimits] = None):
"""Initialize cost controller."""
self.limits = limits or CostLimits.default()
self._costs: Dict[str, ModelCosts] = {
"gpt-4": ModelCosts(0.03, 0.06),
"gpt-4-32k": ModelCosts(0.06, 0.12),
"gpt-3.5-turbo": ModelCosts(0.0015, 0.002),
"gpt-3.5-turbo-16k": ModelCosts(0.003, 0.004),
"llama2": ModelCosts(0.0, 0.0),
"codellama": ModelCosts(0.0, 0.0)
}
self._usage_lock = threading.Lock()
self._usage: Dict[str, List[Dict[str, Any]]] = {
"hourly": [],
"daily": [],
"monthly": []
}
self._last_cleanup = time.time()
self._cleanup_interval = 3600 # 1小时清理一次历史数据
def can_use_model(self, model: str, prompt_tokens: int,
completion_tokens: int) -> bool:
"""检查是否允许使用指定模型."""
# 检查成本限制
if not self._check_budget_limits(model, prompt_tokens, completion_tokens):
return False
# 检查模型权重限制
if not self._check_model_weights(model):
return False
return True
def record_usage(self, model: str, prompt_tokens: int,
completion_tokens: int) -> None:
"""记录模型使用情况."""
cost = self._calculate_cost(model, prompt_tokens, completion_tokens)
timestamp = time.time()
usage = {
"model": model,
"timestamp": timestamp,
"cost": cost,
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens
}
with self._usage_lock:
self._usage["hourly"].append(usage)
self._usage["daily"].append(usage)
self._usage["monthly"].append(usage)
# 定期清理过期数据
self._cleanup_old_usage(timestamp)
def get_current_costs(self) -> Dict[str, float]:
"""获取当前时段的成本统计."""
with self._usage_lock:
now = time.time()
hour_ago = now - 3600
day_ago = now - 86400
month_ago = now - 2592000 # 30天
hourly = sum(u["cost"] for u in self._usage["hourly"]
if u["timestamp"] > hour_ago)
daily = sum(u["cost"] for u in self._usage["daily"]
if u["timestamp"] > day_ago)
monthly = sum(u["cost"] for u in self._usage["monthly"]
if u["timestamp"] > month_ago)
return {
"hourly": hourly,
"daily": daily,
"monthly": monthly
}
def get_model_distribution(self) -> Dict[str, float]:
"""获取模型使用分布."""
with self._usage_lock:
now = time.time()
day_ago = now - 86400
# 统计24小时内的使用情况
model_calls: Dict[str, int] = {}
total_calls = 0
for usage in self._usage["daily"]:
if usage["timestamp"] > day_ago:
model = usage["model"]
model_calls[model] = model_calls.get(model, 0) + 1
total_calls += 1
if total_calls == 0:
return {}
return {
model: count / total_calls
for model, count in model_calls.items()
}
def _calculate_cost(self, model: str, prompt_tokens: int,
completion_tokens: int) -> float:
"""计算使用成本."""
model_costs = self._costs.get(model)
if not model_costs:
return 0.0
return model_costs.calculate(prompt_tokens, completion_tokens)
def _check_budget_limits(self, model: str, prompt_tokens: int,
completion_tokens: int) -> bool:
"""检查是否超出预算限制."""
estimated_cost = self._calculate_cost(model, prompt_tokens, completion_tokens)
current_costs = self.get_current_costs()
# 检查各个时间维度的预算限制
if (current_costs["hourly"] + estimated_cost > self.limits.hourly_budget or
current_costs["daily"] + estimated_cost > self.limits.daily_budget or
current_costs["monthly"] + estimated_cost > self.limits.monthly_budget):
LOGGER.warning(
"Cost limit exceeded - model: %s, estimated: $%.4f",
model, estimated_cost, extra=LOG_EXTRA
)
return False
return True
def _check_model_weights(self, model: str) -> bool:
"""检查是否符合模型权重限制."""
if model not in self.limits.model_weights:
return True # 未配置权重的模型不限制
distribution = self.get_model_distribution()
current_weight = distribution.get(model, 0.0)
max_weight = self.limits.model_weights[model]
if current_weight >= max_weight:
LOGGER.warning(
"Model weight exceeded - model: %s, current: %.1f%%, max: %.1f%%",
model, current_weight * 100, max_weight * 100, extra=LOG_EXTRA
)
return False
return True
def _cleanup_old_usage(self, current_time: float) -> None:
"""清理过期的使用记录."""
if current_time - self._last_cleanup < self._cleanup_interval:
return
hour_ago = current_time - 3600
day_ago = current_time - 86400
month_ago = current_time - 2592000
self._usage["hourly"] = [
u for u in self._usage["hourly"]
if u["timestamp"] > hour_ago
]
self._usage["daily"] = [
u for u in self._usage["daily"]
if u["timestamp"] > day_ago
]
self._usage["monthly"] = [
u for u in self._usage["monthly"]
if u["timestamp"] > month_ago
]
self._last_cleanup = current_time
# 全局实例
_controller = CostController()
def get_controller() -> CostController:
"""获取全局CostController实例."""
return _controller
def set_cost_limits(limits: CostLimits) -> None:
"""设置全局成本限制."""
_controller.limits = limits

181
app/llm/version.py Normal file
View File

@ -0,0 +1,181 @@
"""Template version management and validation."""
from __future__ import annotations
import json
import logging
from dataclasses import asdict, dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional, Set
from .templates import PromptTemplate
LOGGER = logging.getLogger(__name__)
LOG_EXTRA = {"stage": "template_version"}
@dataclass
class TemplateVersion:
"""A versioned template configuration."""
id: str
version: str
created_at: str
template: PromptTemplate
metadata: Dict[str, Any]
is_active: bool = False
@classmethod
def create(cls, template: PromptTemplate, version: str,
metadata: Optional[Dict[str, Any]] = None) -> TemplateVersion:
"""Create a new template version."""
return cls(
id=template.id,
version=version,
created_at=datetime.now().isoformat(),
template=template,
metadata=metadata or {},
is_active=False
)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format."""
return {
"id": self.id,
"version": self.version,
"created_at": self.created_at,
"template": asdict(self.template),
"metadata": self.metadata,
"is_active": self.is_active
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> TemplateVersion:
"""Create from dictionary format."""
template_data = data["template"]
template = PromptTemplate(
id=template_data["id"],
name=template_data["name"],
description=template_data["description"],
template=template_data["template"],
variables=template_data["variables"],
max_length=template_data.get("max_length", 4000),
required_context=template_data.get("required_context"),
validation_rules=template_data.get("validation_rules")
)
return cls(
id=data["id"],
version=data["version"],
created_at=data["created_at"],
template=template,
metadata=data["metadata"],
is_active=data.get("is_active", False)
)
class TemplateVersionManager:
"""Manages template versioning and deployment."""
def __init__(self):
"""Initialize version manager."""
self._versions: Dict[str, Dict[str, TemplateVersion]] = {}
self._active_versions: Dict[str, str] = {}
def add_version(self, template: PromptTemplate, version: str,
metadata: Optional[Dict[str, Any]] = None,
activate: bool = False) -> TemplateVersion:
"""Add a new template version."""
if template.id not in self._versions:
self._versions[template.id] = {}
versions = self._versions[template.id]
if version in versions:
raise ValueError(f"Version {version} already exists for template {template.id}")
template_version = TemplateVersion.create(
template=template,
version=version,
metadata=metadata
)
versions[version] = template_version
if activate:
self.activate_version(template.id, version)
return template_version
def get_version(self, template_id: str, version: str) -> Optional[TemplateVersion]:
"""Get a specific template version."""
return self._versions.get(template_id, {}).get(version)
def list_versions(self, template_id: str) -> List[TemplateVersion]:
"""List all versions of a template."""
return list(self._versions.get(template_id, {}).values())
def get_active_version(self, template_id: str) -> Optional[TemplateVersion]:
"""Get the active version of a template."""
active_version = self._active_versions.get(template_id)
if active_version:
return self.get_version(template_id, active_version)
return None
def activate_version(self, template_id: str, version: str) -> None:
"""Activate a specific template version."""
if template_id not in self._versions:
raise ValueError(f"Template {template_id} not found")
versions = self._versions[template_id]
if version not in versions:
raise ValueError(f"Version {version} not found for template {template_id}")
# Deactivate current active version
current_active = self._active_versions.get(template_id)
if current_active and current_active in versions:
versions[current_active].is_active = False
# Activate new version
versions[version].is_active = True
self._active_versions[template_id] = version
def export_versions(self, template_id: str) -> str:
"""Export all versions of a template to JSON."""
if template_id not in self._versions:
raise ValueError(f"Template {template_id} not found")
versions = self._versions[template_id]
data = {
"template_id": template_id,
"active_version": self._active_versions.get(template_id),
"versions": {
version: ver.to_dict()
for version, ver in versions.items()
}
}
return json.dumps(data, indent=2)
def import_versions(self, json_str: str) -> None:
"""Import template versions from JSON."""
try:
data = json.loads(json_str)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON format: {e}")
template_id = data.get("template_id")
if not template_id:
raise ValueError("Missing template_id in JSON")
versions = data.get("versions", {})
if not versions:
raise ValueError("No versions found in JSON")
# Clear existing versions for this template
self._versions[template_id] = {}
# Import versions
for version, ver_data in versions.items():
template_version = TemplateVersion.from_dict(ver_data)
self._versions[template_id][version] = template_version
# Set active version if specified
active_version = data.get("active_version")
if active_version:
self.activate_version(template_id, active_version)

View File

@ -7,6 +7,7 @@
> ✓ 数据访问与监控
> ✓ 核心回测系统
> ✓ LLM基础集成
> ✓ LLM模板与上下文管理
> △ RSS新闻处理
> △ UI与监控系统
@ -65,10 +66,14 @@
- [x] 精简和优化Provider管理
- [x] 增强function-calling架构
- [x] 完善错误处理和重试策略
- [ ] 优化提示工程:
- [ ] 设计配置化角色提示
- [ ] 优化数据范围控制
- [ ] 改进上下文管理
- [x] 优化提示工程:
- [x] 设计配置化角色提示
- [x] 优化数据范围控制
- [x] 改进上下文管理
- [ ] 增强系统稳定性:
- [ ] 实现提示模板版本管理
- [ ] 增加系统级性能监控
- [ ] 优化模型调用成本控制
### 4. UI与监控P2
#### 4.1 功能增强
@ -95,6 +100,8 @@
3. ✓ 优化DataBroker的数据访问性能
4. △ 完善RSS新闻数据源的接入
5. ✓ 开始着手决策环境的增强
6. ✓ 改进LLM模板与上下文管理
7. △ 启动LLM性能与成本优化
## 三、开发原则

100
tests/test_llm_cost.py Normal file
View File

@ -0,0 +1,100 @@
"""Test cases for LLM cost control system."""
import pytest
from datetime import datetime
from app.llm.cost import CostLimits, ModelCosts, CostController
def test_cost_limits():
"""Test cost limits configuration."""
limits = CostLimits(
hourly_budget=10.0,
daily_budget=100.0,
monthly_budget=1000.0,
model_weights={"gpt-4": 0.7, "gpt-3.5-turbo": 0.3}
)
assert limits.hourly_budget == 10.0
assert limits.daily_budget == 100.0
assert limits.monthly_budget == 1000.0
assert limits.model_weights["gpt-4"] == 0.7
def test_model_costs():
"""Test model cost tracking."""
costs = ModelCosts(
prompt_cost_per_1k=0.1, # $0.1 per 1K tokens
completion_cost_per_1k=0.2 # $0.2 per 1K tokens
)
# Test cost calculation
prompt_tokens = 1000 # 1K tokens
completion_tokens = 500 # 0.5K tokens
total_cost = costs.calculate(prompt_tokens, completion_tokens)
expected_cost = (prompt_tokens / 1000 * costs.prompt_cost_per_1k +
completion_tokens / 1000 * costs.completion_cost_per_1k)
assert total_cost == expected_cost # Should be $0.2
def test_cost_controller():
"""Test cost controller functionality."""
limits = CostLimits(
hourly_budget=1.0,
daily_budget=10.0,
monthly_budget=100.0,
model_weights={"gpt-4": 0.7, "gpt-3.5-turbo": 0.3}
)
controller = CostController(limits=limits)
# First check if we can use model
assert controller.can_use_model("gpt-4", 1000, 500)
# Then record the usage
controller.record_usage("gpt-4", 1000, 500) # About $0.09
# Record usage for second model to maintain weight balance
assert controller.can_use_model("gpt-3.5-turbo", 1000, 500)
controller.record_usage("gpt-3.5-turbo", 1000, 500)
# Verify usage tracking
costs = controller.get_current_costs()
assert costs["hourly"] > 0
assert costs["daily"] > 0
assert costs["monthly"] > 0
# Test model distribution
distribution = controller.get_model_distribution()
assert "gpt-4" in distribution and "gpt-3.5-turbo" in distribution
assert abs(distribution["gpt-4"] - 0.5) < 0.1 # Allow some deviation
assert abs(distribution["gpt-3.5-turbo"] - 0.5) < 0.1 # Should be roughly balanced
def test_cost_controller_history():
"""Test cost controller usage history."""
limits = CostLimits(
hourly_budget=1.0,
daily_budget=10.0,
monthly_budget=100.0,
model_weights={"gpt-4": 0.5, "gpt-3.5-turbo": 0.5} # Equal weights
)
controller = CostController(limits=limits)
# Record one usage of each model
assert controller.can_use_model("gpt-4", 1000, 500)
controller.record_usage("gpt-4", 1000, 500)
assert controller.can_use_model("gpt-3.5-turbo", 1000, 500)
controller.record_usage("gpt-3.5-turbo", 1000, 500)
# Check usage tracking
costs = controller.get_current_costs()
assert costs["hourly"] > 0 # Should have accumulated cost
# Verify the usage distribution is roughly balanced
distribution = controller.get_model_distribution()
assert abs(distribution["gpt-4"] - 0.5) < 0.1
assert abs(distribution["gpt-3.5-turbo"] - 0.5) < 0.1

93
tests/test_llm_version.py Normal file
View File

@ -0,0 +1,93 @@
"""Test cases for template version management."""
import pytest
from app.llm.templates import PromptTemplate
from app.llm.version import TemplateVersion, TemplateVersionManager
def test_template_version_creation():
"""Test creating and managing template versions."""
template = PromptTemplate(
id="test",
name="Test Template",
description="A test template",
template="Test {var}",
variables=["var"]
)
version = TemplateVersion.create(
template=template,
version="1.0.0",
metadata={"author": "test"}
)
assert version.id == "test"
assert version.version == "1.0.0"
assert not version.is_active
assert version.metadata["author"] == "test"
def test_version_manager():
"""Test version manager operations."""
manager = TemplateVersionManager()
template = PromptTemplate(
id="test",
name="Test Template",
description="A test template",
template="Test {var}",
variables=["var"]
)
# Add version
v1 = manager.add_version(template, "1.0.0")
assert not v1.is_active
# Add and activate version
v2 = manager.add_version(template, "2.0.0", activate=True)
assert v2.is_active
assert not v1.is_active
# Get version
assert manager.get_version("test", "1.0.0") == v1
assert manager.get_version("test", "2.0.0") == v2
# List versions
versions = manager.list_versions("test")
assert len(versions) == 2
# Get active version
active = manager.get_active_version("test")
assert active == v2
# Export and import
exported = manager.export_versions("test")
new_manager = TemplateVersionManager()
new_manager.import_versions(exported)
imported = new_manager.get_version("test", "2.0.0")
assert imported.version == "2.0.0"
assert imported.is_active
def test_version_validation():
"""Test version validation checks."""
manager = TemplateVersionManager()
template = PromptTemplate(
id="test",
name="Test Template",
description="A test template",
template="Test {var}",
variables=["var"]
)
# Test duplicate version
manager.add_version(template, "1.0.0")
with pytest.raises(ValueError):
manager.add_version(template, "1.0.0")
# Test invalid version
with pytest.raises(ValueError):
manager.activate_version("test", "invalid")