Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 09:53:45 +00:00)
feat(inference): add tokenization utilities for prompt caching

Implement token-counting utilities to determine prompt cacheability
(≥1024 tokens), with support for OpenAI, Llama, and multimodal content.

- Add count_tokens() function with model-specific tokenizers
- Support OpenAI models (GPT-4, GPT-4o, etc.) via tiktoken
- Support Llama models (3.x, 4.x) via transformers
- Fall back to character-based estimation for unknown models
- Handle multimodal content (text + images)
- LRU-cache tokenizer instances (max 10; cached calls take <1 ms)
- Add comprehensive unit tests (34 tests, >95% coverage)
- Update the tiktoken version constraint to >=0.8.0

This enables a future PR to determine which prompts should be cached
based on a token-count threshold.

Signed-off-by: William Caban <william.caban@gmail.com>
parent 97f535c4f1
commit e61572daf0

4 changed files with 902 additions and 1 deletion
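For orientation, here is a minimal caller-side sketch of the API this commit introduces, using only the function names and the ≥1024-token threshold from the commit message; the should_cache() helper and the example messages are hypothetical, not part of the commit:

    # Hypothetical usage sketch (not part of this commit): decide prompt
    # cacheability from the token count, per the >=1024-token threshold.
    from llama_stack.providers.utils.inference.tokenization import (
        count_tokens,
        get_tokenization_method,
    )

    CACHEABLE_THRESHOLD = 1024  # threshold named in the commit message

    def should_cache(messages: list[dict], model: str) -> bool:
        # exact=True uses a model-specific tokenizer (tiktoken/transformers);
        # exact=False falls back to the cheap character-based estimate.
        return count_tokens(messages, model=model, exact=True) >= CACHEABLE_THRESHOLD

    messages = [{"role": "user", "content": "Summarize this long document."}]
    print(get_tokenization_method("gpt-4o"))  # "exact-tiktoken"
    print(should_cache(messages, "gpt-4o"))   # False for a short prompt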
tests/unit/providers/utils/inference/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""Unit tests for inference utilities."""
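The commit message above also mentions an LRU cache of tokenizer instances (max 10). As a point of reference, here is a minimal sketch of how such a bounded cache is typically built with functools.lru_cache; the loader name is illustrative and the real module's internals may differ:

    # Hypothetical sketch of a bounded tokenizer cache, not the module's code.
    from functools import lru_cache

    import tiktoken

    @lru_cache(maxsize=10)  # at most 10 cached tokenizers, LRU eviction
    def _get_openai_tokenizer(model_name: str):
        # tiktoken resolves a model name to its encoding; cached hits skip
        # this lookup, which is what makes repeat calls sub-millisecond.
        return tiktoken.encoding_for_model(model_name)

A clear_tokenizer_cache() counterpart, as exercised by the tests below, could then simply call _get_openai_tokenizer.cache_clear().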
tests/unit/providers/utils/inference/test_tokenization.py (new file, 446 lines)
@@ -0,0 +1,446 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""Unit tests for tokenization utilities."""

import pytest

from llama_stack.providers.utils.inference.tokenization import (
    TokenizationError,
    clear_tokenizer_cache,
    count_tokens,
    get_tokenization_method,
)


class TestCountTokens:
    """Test suite for count_tokens function."""

    def test_count_tokens_simple_text_openai(self):
        """Test token counting for simple text with OpenAI models."""
        message = {"role": "user", "content": "Hello, world!"}

        # Should work with GPT-4
        token_count = count_tokens(message, model="gpt-4")
        assert isinstance(token_count, int)
        assert token_count > 0
        # "Hello, world!" should be around 3-4 tokens
        assert 2 <= token_count <= 5

    def test_count_tokens_simple_text_gpt4o(self):
        """Test token counting for GPT-4o model."""
        message = {"role": "user", "content": "This is a test message."}

        token_count = count_tokens(message, model="gpt-4o")
        assert isinstance(token_count, int)
        assert token_count > 0

    def test_count_tokens_empty_message(self):
        """Test token counting for empty message."""
        message = {"role": "user", "content": ""}

        token_count = count_tokens(message, model="gpt-4")
        assert token_count == 0

    def test_count_tokens_none_content(self):
        """Test token counting for None content."""
        message = {"role": "user", "content": None}

        token_count = count_tokens(message, model="gpt-4")
        assert token_count == 0

    def test_count_tokens_multiple_messages(self):
        """Test token counting for multiple messages."""
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is the weather?"},
        ]

        token_count = count_tokens(messages, model="gpt-4")
        assert isinstance(token_count, int)
        assert token_count > 0
        # Should be more than a single message
        assert token_count >= 10

    def test_count_tokens_long_text(self):
        """Test token counting for long text."""
        long_text = " ".join(["word"] * 1000)
        message = {"role": "user", "content": long_text}

        token_count = count_tokens(message, model="gpt-4")
        assert isinstance(token_count, int)
        # 1000 words should be close to 1000 tokens
        assert 900 <= token_count <= 1100

    def test_count_tokens_multimodal_text_only(self):
        """Test token counting for multimodal message with text only."""
        message = {
            "role": "user",
            "content": [
                {"type": "text", "text": "What's in this image?"},
            ],
        }

        token_count = count_tokens(message, model="gpt-4o")
        assert isinstance(token_count, int)
        assert token_count > 0

    def test_count_tokens_multimodal_with_image_low_res(self):
        """Test token counting for multimodal message with low-res image."""
        message = {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://example.com/image.jpg",
                        "detail": "low",
                    },
                },
            ],
        }

        token_count = count_tokens(message, model="gpt-4o")
        assert isinstance(token_count, int)
        # Should include text tokens + image tokens (85 for low-res)
        assert token_count >= 85

    def test_count_tokens_multimodal_with_image_high_res(self):
        """Test token counting for multimodal message with high-res image."""
        message = {
            "role": "user",
            "content": [
                {"type": "text", "text": "Analyze this."},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://example.com/image.jpg",
                        "detail": "high",
                    },
                },
            ],
        }

        token_count = count_tokens(message, model="gpt-4o")
        assert isinstance(token_count, int)
        # Should include text tokens + image tokens (170 for high-res)
        assert token_count >= 170

    def test_count_tokens_multimodal_with_image_auto(self):
        """Test token counting for multimodal message with auto detail."""
        message = {
            "role": "user",
            "content": [
                {"type": "text", "text": "What do you see?"},
                {
                    "type": "image_url",
                    "image_url": {"url": "https://example.com/image.jpg"},
                },
            ],
        }

        token_count = count_tokens(message, model="gpt-4o")
        assert isinstance(token_count, int)
        # Should use the average of low and high
        assert token_count >= 100

    def test_count_tokens_multiple_images(self):
        """Test token counting for message with multiple images."""
        message = {
            "role": "user",
            "content": [
                {"type": "text", "text": "Compare these images."},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://example.com/image1.jpg",
                        "detail": "low",
                    },
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://example.com/image2.jpg",
                        "detail": "low",
                    },
                },
            ],
        }

        token_count = count_tokens(message, model="gpt-4o")
        assert isinstance(token_count, int)
        # Should include text + 2 * 85 tokens for images
        assert token_count >= 170

    def test_count_tokens_unknown_model_estimation(self):
        """Test token counting falls back to estimation for unknown models."""
        message = {"role": "user", "content": "Hello, world!"}

        # Unknown model should use character-based estimation
        token_count = count_tokens(message, model="unknown-model-xyz")
        assert isinstance(token_count, int)
        assert token_count > 0
        # "Hello, world!" is 13 chars, should estimate ~3-4 tokens
        assert 2 <= token_count <= 5

    def test_count_tokens_llama_model_fallback(self):
        """Test token counting for Llama models (may fall back to estimation)."""
        message = {"role": "user", "content": "Hello from Llama!"}

        # This may fail if transformers/model not available, should fall back
        token_count = count_tokens(
            message,
            model="meta-llama/Llama-3.1-8B-Instruct",
        )
        assert isinstance(token_count, int)
        assert token_count > 0

    def test_count_tokens_with_exact_false(self):
        """Test token counting with exact=False uses estimation."""
        message = {"role": "user", "content": "This is a test."}

        token_count = count_tokens(message, model="gpt-4", exact=False)
        assert isinstance(token_count, int)
        assert token_count > 0
        # Should use character-based estimation
        # "This is a test." is 15 chars, should estimate ~3-4 tokens
        assert 3 <= token_count <= 5

    def test_count_tokens_malformed_message(self):
        """Test token counting with malformed message."""
        # Not a dict
        token_count = count_tokens("not a message", model="gpt-4")  # type: ignore
        assert token_count == 0

        # Missing content
        token_count = count_tokens({"role": "user"}, model="gpt-4")
        assert token_count == 0

        # Malformed content list
        message = {
            "role": "user",
            "content": [
                "not a dict",  # Invalid item
                {"type": "text", "text": "valid text"},
            ],
        }
        token_count = count_tokens(message, model="gpt-4")
        # Should only count valid items
        assert token_count > 0

    def test_count_tokens_empty_list(self):
        """Test token counting with empty message list."""
        token_count = count_tokens([], model="gpt-4")
        assert token_count == 0

    def test_count_tokens_special_characters(self):
        """Test token counting with special characters."""
        message = {"role": "user", "content": "Hello! @#$%^&*() 🎉"}

        token_count = count_tokens(message, model="gpt-4")
        assert isinstance(token_count, int)
        assert token_count > 0

    def test_count_tokens_very_long_text(self):
        """Test token counting with very long text (>1024 tokens)."""
        # Create text that should be >1024 tokens
        long_text = " ".join(["word"] * 2000)
        message = {"role": "user", "content": long_text}

        token_count = count_tokens(message, model="gpt-4")
        assert isinstance(token_count, int)
        # Should be close to 2000 tokens
        assert token_count >= 1024  # At least the cacheable threshold
        assert 1800 <= token_count <= 2200

    def test_count_tokens_fine_tuned_model(self):
        """Test token counting for a dated OpenAI model snapshot."""
        message = {"role": "user", "content": "Test fine-tuned model."}

        # Dated model snapshots (and fine-tuned variants) should still work
        token_count = count_tokens(message, model="gpt-4-turbo-2024-04-09")
        assert isinstance(token_count, int)
        assert token_count > 0


class TestGetTokenizationMethod:
    """Test suite for get_tokenization_method function."""

    def test_get_tokenization_method_openai(self):
        """Test getting tokenization method for OpenAI models."""
        assert get_tokenization_method("gpt-4") == "exact-tiktoken"
        assert get_tokenization_method("gpt-4o") == "exact-tiktoken"
        assert get_tokenization_method("gpt-3.5-turbo") == "exact-tiktoken"
        assert get_tokenization_method("gpt-4-turbo") == "exact-tiktoken"

    def test_get_tokenization_method_llama(self):
        """Test getting tokenization method for Llama models."""
        assert (
            get_tokenization_method("meta-llama/Llama-3.1-8B-Instruct")
            == "exact-transformers"
        )
        assert (
            get_tokenization_method("meta-llama/Llama-4-Scout-17B-16E-Instruct")
            == "exact-transformers"
        )
        assert (
            get_tokenization_method("meta-llama/Meta-Llama-3-8B")
            == "exact-transformers"
        )

    def test_get_tokenization_method_unknown(self):
        """Test getting tokenization method for unknown models."""
        assert get_tokenization_method("unknown-model") == "estimated"
        assert get_tokenization_method("claude-3") == "estimated"
        assert get_tokenization_method("random-model-xyz") == "estimated"

    def test_get_tokenization_method_fine_tuned(self):
        """Test getting tokenization method for dated model snapshots."""
        # Dated snapshots of OpenAI models should still use tiktoken
        assert (
            get_tokenization_method("gpt-4-turbo-2024-04-09") == "exact-tiktoken"
        )


class TestClearTokenizerCache:
    """Test suite for clear_tokenizer_cache function."""

    def test_clear_tokenizer_cache(self):
        """Test clearing tokenizer cache."""
        # Count tokens to populate cache
        message = {"role": "user", "content": "Test cache clearing."}
        count_tokens(message, model="gpt-4")

        # Clear cache
        clear_tokenizer_cache()

        # Should still work after clearing
        token_count = count_tokens(message, model="gpt-4")
        assert token_count > 0


class TestEdgeCases:
    """Test suite for edge cases and error handling."""

    def test_empty_string_content(self):
        """Test with empty string content."""
        message = {"role": "user", "content": ""}
        token_count = count_tokens(message, model="gpt-4")
        assert token_count == 0

    def test_whitespace_only_content(self):
        """Test with whitespace-only content."""
        message = {"role": "user", "content": " \n\t "}
        token_count = count_tokens(message, model="gpt-4")
        # Should count whitespace tokens
        assert token_count >= 0

    def test_unicode_content(self):
        """Test with unicode content."""
        message = {"role": "user", "content": "Hello 世界 🌍"}
        token_count = count_tokens(message, model="gpt-4")
        assert token_count > 0

    def test_multimodal_empty_text(self):
        """Test multimodal message with empty text."""
        message = {
            "role": "user",
            "content": [
                {"type": "text", "text": ""},
                {
                    "type": "image_url",
                    "image_url": {"url": "https://example.com/image.jpg"},
                },
            ],
        }
        token_count = count_tokens(message, model="gpt-4o")
        # Should only count image tokens
        assert token_count > 0

    def test_multimodal_missing_text_field(self):
        """Test multimodal message with missing text field."""
        message = {
            "role": "user",
            "content": [
                {"type": "text"},  # Missing 'text' field
            ],
        }
        token_count = count_tokens(message, model="gpt-4o")
        # Should handle gracefully
        assert token_count == 0

    def test_multimodal_unknown_type(self):
        """Test multimodal message with unknown content type."""
        message = {
            "role": "user",
            "content": [
                {"type": "unknown", "data": "something"},
                {"type": "text", "text": "Hello"},
            ],
        }
        token_count = count_tokens(message, model="gpt-4o")
        # Should only count known types
        assert token_count > 0

    def test_nested_content_structures(self):
        """Test with nested content structures."""
        message = {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "First part",
                },
                {
                    "type": "text",
                    "text": "Second part",
                },
            ],
        }
        token_count = count_tokens(message, model="gpt-4")
        # Should count all text parts
        assert token_count > 0

    def test_consistency_across_calls(self):
        """Test that token counting is consistent across calls."""
        message = {"role": "user", "content": "Consistency test message."}

        count1 = count_tokens(message, model="gpt-4")
        count2 = count_tokens(message, model="gpt-4")

        assert count1 == count2


class TestPerformance:
    """Test suite for performance characteristics."""

    def test_tokenizer_caching_works(self):
        """Test that tokenizer caching improves performance."""
        message = {"role": "user", "content": "Test caching performance."}

        # First call loads the tokenizer
        count_tokens(message, model="gpt-4")

        # Subsequent calls should use the cached tokenizer
        # (we can't easily measure time in unit tests, but we verify it works)
        for _ in range(5):
            token_count = count_tokens(message, model="gpt-4")
            assert token_count > 0

    def test_cache_size_limit(self):
        """Test that cache size is limited (max 10 tokenizers)."""
        # Load more than 10 different models (using estimation for most)
        models = [f"model-{i}" for i in range(15)]

        message = {"role": "user", "content": "Test"}

        for model in models:
            count_tokens(message, model=model, exact=False)

        # Should still work (the cache evicts the oldest entries)
        token_count = count_tokens(message, model="model-0", exact=False)
        assert token_count > 0
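The assertions above pin down the expected estimation behavior without showing the implementation. Here is a minimal sketch consistent with those assertions, assuming roughly 4 characters per token for the character-based fallback and OpenAI-style image costs of 85 tokens (low detail), 170 (high detail), and their midpoint (auto); the names and constants are inferred from the tests, not taken from the module itself:

    # Illustrative sketch only: helpers and constants inferred from the
    # tests above, not the tokenization module's actual internals.
    LOW_DETAIL_IMAGE_TOKENS = 85    # tests assert >= 85 for detail="low"
    HIGH_DETAIL_IMAGE_TOKENS = 170  # tests assert >= 170 for detail="high"

    def estimate_text_tokens(text: str | None) -> int:
        # Character-based fallback: roughly 4 characters per token, so
        # "Hello, world!" (13 chars) estimates to 3 tokens, within the
        # 2-5 range the tests allow.
        return len(text) // 4 if text else 0

    def estimate_image_tokens(detail: str | None) -> int:
        if detail == "low":
            return LOW_DETAIL_IMAGE_TOKENS
        if detail == "high":
            return HIGH_DETAIL_IMAGE_TOKENS
        # "auto" or unspecified: midpoint of low and high (127), matching
        # the auto-detail test's expectation of >= 100 tokens.
        return (LOW_DETAIL_IMAGE_TOKENS + HIGH_DETAIL_IMAGE_TOKENS) // 2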