llm-quant/tests/test_llm_context.py
2025-10-05 16:28:53 +08:00

149 lines
4.3 KiB
Python

"""Test cases for LLM context management."""
import time
import pytest
from app.llm.context import Context, ContextConfig, ContextManager, DataAccessConfig, Message
def test_data_access_config():
"""Test data access configuration and validation."""
config = DataAccessConfig(
allowed_tables={"daily", "daily_basic"},
max_history_days=365,
max_batch_size=1000
)
# Valid request
errors = config.validate_request("daily", "20251001", "20251005")
assert not errors
# Invalid table
errors = config.validate_request("invalid", "20251001")
assert len(errors) == 1
assert "not allowed" in errors[0]
# Invalid date format
errors = config.validate_request("daily", "invalid")
assert len(errors) == 1
assert "Invalid date format" in errors[0]
# Date range too long
errors = config.validate_request("daily", "20251001", "20261001")
assert len(errors) == 1
assert "exceeds max" in errors[0]
# End date before start
errors = config.validate_request("daily", "20251005", "20251001")
assert len(errors) == 1
assert "before start date" in errors[0]
def test_context_config():
"""Test context configuration defaults."""
config = ContextConfig()
assert config.max_total_tokens > 0
assert config.max_messages > 0
assert config.include_system is True
assert config.include_functions is True
def test_message():
"""Test message functionality."""
# Basic message
msg = Message(role="user", content="Hello")
assert msg.role == "user"
assert msg.content == "Hello"
assert msg.name is None
assert msg.function_call is None
assert msg.timestamp <= time.time()
# Function message
func_msg = Message(
role="function",
content="Result",
name="test_func",
function_call={"name": "test_func", "arguments": "{}"}
)
assert func_msg.name == "test_func"
assert func_msg.function_call is not None
# Dict conversion
msg_dict = msg.to_dict()
assert msg_dict["role"] == "user"
assert msg_dict["content"] == "Hello"
assert "name" not in msg_dict
assert "function_call" not in msg_dict
func_dict = func_msg.to_dict()
assert func_dict["name"] == "test_func"
assert func_dict["function_call"] is not None
# Token estimation
assert msg.estimated_tokens > 0
assert func_msg.estimated_tokens > msg.estimated_tokens
def test_context():
"""Test context management."""
config = ContextConfig(max_total_tokens=100, max_messages=3)
context = Context(config=config)
# Add messages
msg1 = Message(role="system", content="System message")
msg2 = Message(role="user", content="User message")
msg3 = Message(role="assistant", content="Assistant message")
msg4 = Message(role="user", content="Another message")
context.add_message(msg1)
context.add_message(msg2)
context.add_message(msg3)
assert len(context.messages) == 3
# Test max messages
context.add_message(msg4)
assert len(context.messages) == 3
assert msg4 in context.messages # Newest message kept
# Get messages
all_msgs = context.get_messages()
assert len(all_msgs) == 3
no_system = context.get_messages(include_system=False)
assert len(no_system) == 2
assert all(m["role"] != "system" for m in no_system)
# Clear context
context.clear(keep_system=True)
assert len(context.messages) == 1
assert context.messages[0].role == "system"
context.clear(keep_system=False)
assert len(context.messages) == 0
def test_context_manager():
"""Test context manager functionality."""
ContextManager.clear_all()
# Create context
context = ContextManager.create_context("test")
assert ContextManager.get_context("test") == context
# Duplicate context
with pytest.raises(ValueError):
ContextManager.create_context("test")
# Custom config
config = ContextConfig(max_total_tokens=200)
custom = ContextManager.create_context("custom", config)
assert custom.config.max_total_tokens == 200
# Remove context
ContextManager.remove_context("test")
assert ContextManager.get_context("test") is None
# Clear all
ContextManager.clear_all()
assert ContextManager.get_context("custom") is None