149 lines
4.3 KiB
Python
149 lines
4.3 KiB
Python
"""Test cases for LLM context management."""
|
|
import time
|
|
|
|
import pytest
|
|
|
|
from app.llm.context import Context, ContextConfig, ContextManager, DataAccessConfig, Message
|
|
|
|
|
|
def test_data_access_config():
|
|
"""Test data access configuration and validation."""
|
|
config = DataAccessConfig(
|
|
allowed_tables={"daily", "daily_basic"},
|
|
max_history_days=365,
|
|
max_batch_size=1000
|
|
)
|
|
|
|
# Valid request
|
|
errors = config.validate_request("daily", "20251001", "20251005")
|
|
assert not errors
|
|
|
|
# Invalid table
|
|
errors = config.validate_request("invalid", "20251001")
|
|
assert len(errors) == 1
|
|
assert "not allowed" in errors[0]
|
|
|
|
# Invalid date format
|
|
errors = config.validate_request("daily", "invalid")
|
|
assert len(errors) == 1
|
|
assert "Invalid date format" in errors[0]
|
|
|
|
# Date range too long
|
|
errors = config.validate_request("daily", "20251001", "20261001")
|
|
assert len(errors) == 1
|
|
assert "exceeds max" in errors[0]
|
|
|
|
# End date before start
|
|
errors = config.validate_request("daily", "20251005", "20251001")
|
|
assert len(errors) == 1
|
|
assert "before start date" in errors[0]
|
|
|
|
|
|
def test_context_config():
|
|
"""Test context configuration defaults."""
|
|
config = ContextConfig()
|
|
assert config.max_total_tokens > 0
|
|
assert config.max_messages > 0
|
|
assert config.include_system is True
|
|
assert config.include_functions is True
|
|
|
|
|
|
def test_message():
|
|
"""Test message functionality."""
|
|
# Basic message
|
|
msg = Message(role="user", content="Hello")
|
|
assert msg.role == "user"
|
|
assert msg.content == "Hello"
|
|
assert msg.name is None
|
|
assert msg.function_call is None
|
|
assert msg.timestamp <= time.time()
|
|
|
|
# Function message
|
|
func_msg = Message(
|
|
role="function",
|
|
content="Result",
|
|
name="test_func",
|
|
function_call={"name": "test_func", "arguments": "{}"}
|
|
)
|
|
assert func_msg.name == "test_func"
|
|
assert func_msg.function_call is not None
|
|
|
|
# Dict conversion
|
|
msg_dict = msg.to_dict()
|
|
assert msg_dict["role"] == "user"
|
|
assert msg_dict["content"] == "Hello"
|
|
assert "name" not in msg_dict
|
|
assert "function_call" not in msg_dict
|
|
|
|
func_dict = func_msg.to_dict()
|
|
assert func_dict["name"] == "test_func"
|
|
assert func_dict["function_call"] is not None
|
|
|
|
# Token estimation
|
|
assert msg.estimated_tokens > 0
|
|
assert func_msg.estimated_tokens > msg.estimated_tokens
|
|
|
|
|
|
def test_context():
|
|
"""Test context management."""
|
|
config = ContextConfig(max_total_tokens=100, max_messages=3)
|
|
context = Context(config=config)
|
|
|
|
# Add messages
|
|
msg1 = Message(role="system", content="System message")
|
|
msg2 = Message(role="user", content="User message")
|
|
msg3 = Message(role="assistant", content="Assistant message")
|
|
msg4 = Message(role="user", content="Another message")
|
|
|
|
context.add_message(msg1)
|
|
context.add_message(msg2)
|
|
context.add_message(msg3)
|
|
assert len(context.messages) == 3
|
|
|
|
# Test max messages
|
|
context.add_message(msg4)
|
|
assert len(context.messages) == 3
|
|
assert msg4 in context.messages # Newest message kept
|
|
|
|
# Get messages
|
|
all_msgs = context.get_messages()
|
|
assert len(all_msgs) == 3
|
|
|
|
no_system = context.get_messages(include_system=False)
|
|
assert len(no_system) == 2
|
|
assert all(m["role"] != "system" for m in no_system)
|
|
|
|
# Clear context
|
|
context.clear(keep_system=True)
|
|
assert len(context.messages) == 1
|
|
assert context.messages[0].role == "system"
|
|
|
|
context.clear(keep_system=False)
|
|
assert len(context.messages) == 0
|
|
|
|
|
|
def test_context_manager():
|
|
"""Test context manager functionality."""
|
|
ContextManager.clear_all()
|
|
|
|
# Create context
|
|
context = ContextManager.create_context("test")
|
|
assert ContextManager.get_context("test") == context
|
|
|
|
# Duplicate context
|
|
with pytest.raises(ValueError):
|
|
ContextManager.create_context("test")
|
|
|
|
# Custom config
|
|
config = ContextConfig(max_total_tokens=200)
|
|
custom = ContextManager.create_context("custom", config)
|
|
assert custom.config.max_total_tokens == 200
|
|
|
|
# Remove context
|
|
ContextManager.remove_context("test")
|
|
assert ContextManager.get_context("test") is None
|
|
|
|
# Clear all
|
|
ContextManager.clear_all()
|
|
assert ContextManager.get_context("custom") is None
|