"""Tests for token counter utility."""
import pytest
from src.utils.token_counter import TokenCounter, get_token_counter, count_tokens
def test_token_counter_initialization():
|
|
"""Test token counter initialization."""
|
|
counter = TokenCounter("gpt-4")
|
|
assert counter.model == "gpt-4"
|
|
assert counter._encoding is not None
|
|
|
|
|
|
def test_count_tokens_simple():
|
|
"""Test counting tokens in simple text."""
|
|
counter = TokenCounter()
|
|
|
|
text = "Hello, world!"
|
|
token_count = counter.count_tokens(text)
|
|
|
|
# Should return a reasonable token count
|
|
assert token_count > 0
|
|
assert token_count < 10 # "Hello, world!" should be few tokens
|
|
|
|
|
|
def test_count_tokens_empty():
|
|
"""Test counting tokens in empty string."""
|
|
counter = TokenCounter()
|
|
|
|
assert counter.count_tokens("") == 0
|
|
assert counter.count_tokens(None) == 0
|
|
|
|
|
|
def test_count_tokens_long_text():
|
|
"""Test counting tokens in longer text."""
|
|
counter = TokenCounter()
|
|
|
|
long_text = "This is a longer piece of text. " * 100
|
|
token_count = counter.count_tokens(long_text)
|
|
|
|
# Should be proportional to text length
|
|
assert token_count > 100
|
|
assert token_count < len(long_text) # Tokens should be less than characters
|
|
|
|
|
|
def test_count_message_tokens():
|
|
"""Test counting tokens in message format."""
|
|
counter = TokenCounter()
|
|
|
|
messages = [
|
|
{"role": "system", "content": "You are a helpful assistant."},
|
|
{"role": "user", "content": "Hello!"},
|
|
{"role": "assistant", "content": "Hi there!"}
|
|
]
|
|
|
|
token_count = counter.count_message_tokens(messages)
|
|
|
|
# Should include overhead for message formatting
|
|
assert token_count > 0
|
|
assert token_count > counter.count_tokens("You are a helpful assistant. Hello! Hi there!")
|
|
|
|
|
|
def test_count_message_tokens_empty():
|
|
"""Test counting tokens with empty messages."""
|
|
counter = TokenCounter()
|
|
|
|
assert counter.count_message_tokens([]) == 0
|
|
|
|
|
|
def test_truncate_to_token_limit():
|
|
"""Test truncating messages to fit token limit."""
|
|
counter = TokenCounter()
|
|
|
|
messages = [
|
|
{"role": "system", "content": "System message"},
|
|
{"role": "user", "content": "Message 1"},
|
|
{"role": "assistant", "content": "Response 1"},
|
|
{"role": "user", "content": "Message 2"},
|
|
{"role": "assistant", "content": "Response 2"},
|
|
{"role": "user", "content": "Message 3"},
|
|
{"role": "assistant", "content": "Response 3"},
|
|
]
|
|
|
|
# Set a low token limit
|
|
truncated = counter.truncate_to_token_limit(messages, max_tokens=50, keep_recent=True)
|
|
|
|
# Should keep system message
|
|
assert any(msg["role"] == "system" for msg in truncated)
|
|
|
|
# Should have fewer messages than original
|
|
assert len(truncated) < len(messages)
|
|
|
|
# Total tokens should be under limit
|
|
total_tokens = counter.count_message_tokens(truncated)
|
|
assert total_tokens <= 50
|
|
|
|
|
|
def test_truncate_keeps_system_message():
|
|
"""Test that truncation always keeps system messages."""
|
|
counter = TokenCounter()
|
|
|
|
messages = [
|
|
{"role": "system", "content": "Important system message"},
|
|
{"role": "user", "content": "User message " * 100},
|
|
{"role": "assistant", "content": "Assistant response " * 100},
|
|
]
|
|
|
|
truncated = counter.truncate_to_token_limit(messages, max_tokens=100, keep_recent=True)
|
|
|
|
# System message should always be present
|
|
system_messages = [msg for msg in truncated if msg["role"] == "system"]
|
|
assert len(system_messages) == 1
|
|
assert "Important system message" in system_messages[0]["content"]
|
|
|
|
|
|
def test_get_context_window_usage():
|
|
"""Test getting context window usage statistics."""
|
|
counter = TokenCounter()
|
|
|
|
messages = [
|
|
{"role": "user", "content": "Hello!"},
|
|
{"role": "assistant", "content": "Hi there!"}
|
|
]
|
|
|
|
stats = counter.get_context_window_usage(messages, context_window=8000)
|
|
|
|
assert "current_tokens" in stats
|
|
assert "max_tokens" in stats
|
|
assert "percentage" in stats
|
|
assert "remaining_tokens" in stats
|
|
assert "is_near_limit" in stats
|
|
|
|
assert stats["max_tokens"] == 8000
|
|
assert stats["current_tokens"] > 0
|
|
assert stats["percentage"] >= 0
|
|
assert stats["percentage"] < 100
|
|
assert not stats["is_near_limit"] # Should not be near limit with small messages
|
|
|
|
|
|
def test_get_context_window_usage_near_limit():
|
|
"""Test context window usage when near limit."""
|
|
counter = TokenCounter()
|
|
|
|
# Create message that uses most of the context window
|
|
large_message = "This is a long message. " * 500
|
|
|
|
messages = [
|
|
{"role": "user", "content": large_message},
|
|
]
|
|
|
|
stats = counter.get_context_window_usage(messages, context_window=1000)
|
|
|
|
# Should be near or over limit
|
|
assert stats["percentage"] > 50
|
|
|
|
|
|
def test_global_token_counter():
|
|
"""Test global token counter functions."""
|
|
# Test count_tokens function
|
|
token_count = count_tokens("Hello, world!")
|
|
assert token_count > 0
|
|
|
|
# Test get_token_counter
|
|
counter1 = get_token_counter()
|
|
counter2 = get_token_counter()
|
|
|
|
# Should return same instance
|
|
assert counter1 is counter2
|
|
|
|
# Test with specific model
|
|
counter3 = get_token_counter("gpt-3.5-turbo")
|
|
assert counter3.model == "gpt-3.5-turbo"
|