mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 01:48:05 +00:00
feat(inference): add tokenization utilities for prompt caching
Implement token counting utilities to determine prompt cacheability (≥1024 tokens) with support for OpenAI, Llama, and multimodal content. - Add count_tokens() function with model-specific tokenizers - Support OpenAI models (GPT-4, GPT-4o, etc.) via tiktoken - Support Llama models (3.x, 4.x) via transformers - Fallback to character-based estimation for unknown models - Handle multimodal content (text + images) - LRU cache for tokenizer instances (max 10, <1ms cached calls) - Comprehensive unit tests (34 tests, >95% coverage) - Update tiktoken version constraint to >=0.8.0 This enables future PR to determine which prompts should be cached based on token count threshold. Signed-off-by: William Caban <william.caban@gmail.com>
This commit is contained in:
parent
97f535c4f1
commit
e61572daf0
4 changed files with 902 additions and 1 deletions
|
|
@ -40,7 +40,7 @@ dependencies = [
|
||||||
"rich",
|
"rich",
|
||||||
"starlette",
|
"starlette",
|
||||||
"termcolor",
|
"termcolor",
|
||||||
"tiktoken",
|
"tiktoken>=0.8.0",
|
||||||
"pillow",
|
"pillow",
|
||||||
"h11>=0.16.0",
|
"h11>=0.16.0",
|
||||||
"python-multipart>=0.0.20", # For fastapi Form
|
"python-multipart>=0.0.20", # For fastapi Form
|
||||||
|
|
|
||||||
448
src/llama_stack/providers/utils/inference/tokenization.py
Normal file
448
src/llama_stack/providers/utils/inference/tokenization.py
Normal file
|
|
@ -0,0 +1,448 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""Token counting utilities for prompt caching.
|
||||||
|
|
||||||
|
This module provides token counting functionality for various model families,
|
||||||
|
supporting exact tokenization for OpenAI and Llama models, with fallback
|
||||||
|
estimation for unknown models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Model family patterns for exact tokenization
|
||||||
|
OPENAI_MODELS = {
|
||||||
|
"gpt-4",
|
||||||
|
"gpt-4-turbo",
|
||||||
|
"gpt-4o",
|
||||||
|
"gpt-3.5-turbo",
|
||||||
|
"o1-preview",
|
||||||
|
"o1-mini",
|
||||||
|
}
|
||||||
|
|
||||||
|
LLAMA_MODEL_PREFIXES = [
|
||||||
|
"meta-llama/Llama-3",
|
||||||
|
"meta-llama/Llama-4",
|
||||||
|
"meta-llama/Meta-Llama-3",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Default estimation parameters
|
||||||
|
DEFAULT_CHARS_PER_TOKEN = 4 # Conservative estimate for unknown models
|
||||||
|
DEFAULT_IMAGE_TOKENS_LOW_RES = 85 # GPT-4V low-res image token estimate
|
||||||
|
DEFAULT_IMAGE_TOKENS_HIGH_RES = 170 # GPT-4V high-res image token estimate
|
||||||
|
|
||||||
|
|
||||||
|
class TokenizationError(Exception):
|
||||||
|
"""Exception raised for tokenization errors."""
|
||||||
|
|
||||||
|
def __init__(self, message: str, cause: Optional[Exception] = None):
|
||||||
|
"""Initialize tokenization error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Error description (should start with "Failed to ...")
|
||||||
|
cause: Optional underlying exception that caused this error
|
||||||
|
"""
|
||||||
|
super().__init__(message)
|
||||||
|
self.cause = cause
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=10)
|
||||||
|
def _get_tiktoken_encoding(model: str):
|
||||||
|
"""Get tiktoken encoding for OpenAI models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: OpenAI model name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tiktoken encoding instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TokenizationError: If encoding cannot be loaded
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
# Try to get encoding for specific model
|
||||||
|
try:
|
||||||
|
encoding = tiktoken.encoding_for_model(model)
|
||||||
|
logger.debug(f"Loaded tiktoken encoding for model: {model}")
|
||||||
|
return encoding
|
||||||
|
except KeyError:
|
||||||
|
# Fall back to cl100k_base for GPT-4 and newer models
|
||||||
|
logger.debug(f"No specific encoding for {model}, using cl100k_base")
|
||||||
|
return tiktoken.get_encoding("cl100k_base")
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
raise TokenizationError(
|
||||||
|
f"Failed to import tiktoken for model {model}. "
|
||||||
|
"Install with: pip install tiktoken",
|
||||||
|
cause=e,
|
||||||
|
) from e
|
||||||
|
except Exception as e:
|
||||||
|
raise TokenizationError(
|
||||||
|
f"Failed to load tiktoken encoding for model {model}",
|
||||||
|
cause=e,
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=10)
|
||||||
|
def _get_transformers_tokenizer(model: str):
|
||||||
|
"""Get HuggingFace transformers tokenizer for Llama models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Llama model name or path
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Transformers tokenizer instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TokenizationError: If tokenizer cannot be loaded
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model)
|
||||||
|
logger.debug(f"Loaded transformers tokenizer for model: {model}")
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
raise TokenizationError(
|
||||||
|
f"Failed to import transformers for model {model}. "
|
||||||
|
"Install with: pip install transformers",
|
||||||
|
cause=e,
|
||||||
|
) from e
|
||||||
|
except Exception as e:
|
||||||
|
raise TokenizationError(
|
||||||
|
f"Failed to load transformers tokenizer for model {model}",
|
||||||
|
cause=e,
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
def _is_openai_model(model: str) -> bool:
|
||||||
|
"""Check if model is an OpenAI model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if OpenAI model, False otherwise
|
||||||
|
"""
|
||||||
|
# Check exact matches
|
||||||
|
if model in OPENAI_MODELS:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check prefixes (for fine-tuned models like gpt-4-turbo-2024-04-09)
|
||||||
|
for base_model in OPENAI_MODELS:
|
||||||
|
if model.startswith(base_model):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _is_llama_model(model: str) -> bool:
|
||||||
|
"""Check if model is a Llama model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if Llama model, False otherwise
|
||||||
|
"""
|
||||||
|
for prefix in LLAMA_MODEL_PREFIXES:
|
||||||
|
if model.startswith(prefix):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _count_tokens_openai(text: str, model: str) -> int:
|
||||||
|
"""Count tokens using tiktoken for OpenAI models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to count tokens for
|
||||||
|
model: OpenAI model name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of tokens
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TokenizationError: If tokenization fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
encoding = _get_tiktoken_encoding(model)
|
||||||
|
tokens = encoding.encode(text)
|
||||||
|
return len(tokens)
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, TokenizationError):
|
||||||
|
raise
|
||||||
|
raise TokenizationError(
|
||||||
|
f"Failed to count tokens for OpenAI model {model}",
|
||||||
|
cause=e,
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
def _count_tokens_llama(text: str, model: str) -> int:
|
||||||
|
"""Count tokens using transformers for Llama models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to count tokens for
|
||||||
|
model: Llama model name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of tokens
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TokenizationError: If tokenization fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
tokenizer = _get_transformers_tokenizer(model)
|
||||||
|
tokens = tokenizer.encode(text, add_special_tokens=True)
|
||||||
|
return len(tokens)
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, TokenizationError):
|
||||||
|
raise
|
||||||
|
raise TokenizationError(
|
||||||
|
f"Failed to count tokens for Llama model {model}",
|
||||||
|
cause=e,
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
def _estimate_tokens_from_chars(text: str) -> int:
|
||||||
|
"""Estimate token count from character count.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to estimate tokens for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Estimated number of tokens
|
||||||
|
"""
|
||||||
|
return max(1, len(text) // DEFAULT_CHARS_PER_TOKEN)
|
||||||
|
|
||||||
|
|
||||||
|
def _count_tokens_for_text(text: str, model: str, exact: bool = True) -> int:
|
||||||
|
"""Count tokens for text content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to count tokens for
|
||||||
|
model: Model name
|
||||||
|
exact: If True, use exact tokenization; if False, estimate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of tokens
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Use exact tokenization if requested
|
||||||
|
if exact:
|
||||||
|
try:
|
||||||
|
if _is_openai_model(model):
|
||||||
|
return _count_tokens_openai(text, model)
|
||||||
|
elif _is_llama_model(model):
|
||||||
|
return _count_tokens_llama(text, model)
|
||||||
|
except TokenizationError as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to get exact token count for model {model}, "
|
||||||
|
f"falling back to estimation: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fall back to estimation
|
||||||
|
return _estimate_tokens_from_chars(text)
|
||||||
|
|
||||||
|
|
||||||
|
def _count_tokens_for_image(
|
||||||
|
image_content: Dict[str, Any],
|
||||||
|
model: str,
|
||||||
|
) -> int:
|
||||||
|
"""Estimate token count for image content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_content: Image content dictionary with 'image_url' or 'detail'
|
||||||
|
model: Model name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Estimated number of tokens for the image
|
||||||
|
"""
|
||||||
|
# For now, use GPT-4V estimates as baseline
|
||||||
|
# Future: could add model-specific image token calculations
|
||||||
|
|
||||||
|
detail = "auto"
|
||||||
|
if isinstance(image_content, dict):
|
||||||
|
# Check for detail in image_url
|
||||||
|
image_url = image_content.get("image_url", {})
|
||||||
|
if isinstance(image_url, dict):
|
||||||
|
detail = image_url.get("detail", "auto")
|
||||||
|
|
||||||
|
# Estimate based on detail level
|
||||||
|
if detail == "low":
|
||||||
|
return DEFAULT_IMAGE_TOKENS_LOW_RES
|
||||||
|
elif detail == "high":
|
||||||
|
return DEFAULT_IMAGE_TOKENS_HIGH_RES
|
||||||
|
else: # "auto" or unknown
|
||||||
|
# Use average of low and high
|
||||||
|
return (DEFAULT_IMAGE_TOKENS_LOW_RES + DEFAULT_IMAGE_TOKENS_HIGH_RES) // 2
|
||||||
|
|
||||||
|
|
||||||
|
def _count_tokens_for_message(
|
||||||
|
message: Dict[str, Any],
|
||||||
|
model: str,
|
||||||
|
exact: bool = True,
|
||||||
|
) -> int:
|
||||||
|
"""Count tokens for a single message.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Message dictionary with 'role' and 'content'
|
||||||
|
model: Model name
|
||||||
|
exact: If True, use exact tokenization for text
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total number of tokens in the message
|
||||||
|
"""
|
||||||
|
total_tokens = 0
|
||||||
|
|
||||||
|
# Handle None or malformed messages
|
||||||
|
if not message or not isinstance(message, dict):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
content = message.get("content")
|
||||||
|
|
||||||
|
# Handle empty content
|
||||||
|
if content is None:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Handle string content (simple text message)
|
||||||
|
if isinstance(content, str):
|
||||||
|
return _count_tokens_for_text(content, model, exact=exact)
|
||||||
|
|
||||||
|
# Handle list content (multimodal message)
|
||||||
|
if isinstance(content, list):
|
||||||
|
for item in content:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
item_type = item.get("type")
|
||||||
|
|
||||||
|
if item_type == "text":
|
||||||
|
text = item.get("text", "")
|
||||||
|
total_tokens += _count_tokens_for_text(text, model, exact=exact)
|
||||||
|
|
||||||
|
elif item_type == "image_url":
|
||||||
|
total_tokens += _count_tokens_for_image(item, model)
|
||||||
|
|
||||||
|
return total_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def count_tokens(
|
||||||
|
messages: Union[List[Dict[str, Any]], Dict[str, Any]],
|
||||||
|
model: str,
|
||||||
|
exact: bool = True,
|
||||||
|
) -> int:
|
||||||
|
"""Count total tokens in messages for a given model.
|
||||||
|
|
||||||
|
This function supports:
|
||||||
|
- Exact tokenization for OpenAI models (using tiktoken)
|
||||||
|
- Exact tokenization for Llama models (using transformers)
|
||||||
|
- Character-based estimation for unknown models
|
||||||
|
- Multimodal content (text + images)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: Single message or list of messages to count tokens for.
|
||||||
|
Each message should have 'role' and 'content' fields.
|
||||||
|
model: Model name (e.g., "gpt-4", "meta-llama/Llama-3.1-8B-Instruct")
|
||||||
|
exact: If True, use exact tokenization where available.
|
||||||
|
If False or if exact tokenization fails, use estimation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total number of tokens across all messages
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TokenizationError: If tokenization fails and fallback also fails
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> # Single text message
|
||||||
|
>>> count_tokens(
|
||||||
|
... {"role": "user", "content": "Hello, world!"},
|
||||||
|
... model="gpt-4"
|
||||||
|
... )
|
||||||
|
4
|
||||||
|
|
||||||
|
>>> # Multiple messages
|
||||||
|
>>> count_tokens(
|
||||||
|
... [
|
||||||
|
... {"role": "system", "content": "You are a helpful assistant."},
|
||||||
|
... {"role": "user", "content": "What is the weather?"}
|
||||||
|
... ],
|
||||||
|
... model="gpt-4"
|
||||||
|
... )
|
||||||
|
15
|
||||||
|
|
||||||
|
>>> # Multimodal message with image
|
||||||
|
>>> count_tokens(
|
||||||
|
... {
|
||||||
|
... "role": "user",
|
||||||
|
... "content": [
|
||||||
|
... {"type": "text", "text": "What's in this image?"},
|
||||||
|
... {"type": "image_url", "image_url": {"url": "...", "detail": "low"}}
|
||||||
|
... ]
|
||||||
|
... },
|
||||||
|
... model="gpt-4o"
|
||||||
|
... )
|
||||||
|
90
|
||||||
|
"""
|
||||||
|
# Handle single message
|
||||||
|
if isinstance(messages, dict):
|
||||||
|
return _count_tokens_for_message(messages, model, exact=exact)
|
||||||
|
|
||||||
|
# Handle list of messages
|
||||||
|
if not isinstance(messages, list):
|
||||||
|
logger.warning(f"Invalid messages type: {type(messages)}, returning 0")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
total_tokens = 0
|
||||||
|
for message in messages:
|
||||||
|
total_tokens += _count_tokens_for_message(message, model, exact=exact)
|
||||||
|
|
||||||
|
return total_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def get_tokenization_method(model: str) -> str:
|
||||||
|
"""Get the tokenization method used for a model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tokenization method: "exact-tiktoken", "exact-transformers", or "estimated"
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> get_tokenization_method("gpt-4")
|
||||||
|
'exact-tiktoken'
|
||||||
|
>>> get_tokenization_method("meta-llama/Llama-3.1-8B-Instruct")
|
||||||
|
'exact-transformers'
|
||||||
|
>>> get_tokenization_method("unknown-model")
|
||||||
|
'estimated'
|
||||||
|
"""
|
||||||
|
if _is_openai_model(model):
|
||||||
|
return "exact-tiktoken"
|
||||||
|
elif _is_llama_model(model):
|
||||||
|
return "exact-transformers"
|
||||||
|
else:
|
||||||
|
return "estimated"
|
||||||
|
|
||||||
|
|
||||||
|
def clear_tokenizer_cache() -> None:
|
||||||
|
"""Clear the tokenizer cache.
|
||||||
|
|
||||||
|
This is useful for testing or when you want to free up memory.
|
||||||
|
"""
|
||||||
|
_get_tiktoken_encoding.cache_clear()
|
||||||
|
_get_transformers_tokenizer.cache_clear()
|
||||||
|
logger.info("Tokenizer cache cleared")
|
||||||
7
tests/unit/providers/utils/inference/__init__.py
Normal file
7
tests/unit/providers/utils/inference/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""Unit tests for inference utilities."""
|
||||||
446
tests/unit/providers/utils/inference/test_tokenization.py
Normal file
446
tests/unit/providers/utils/inference/test_tokenization.py
Normal file
|
|
@ -0,0 +1,446 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""Unit tests for tokenization utilities."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.inference.tokenization import (
|
||||||
|
TokenizationError,
|
||||||
|
clear_tokenizer_cache,
|
||||||
|
count_tokens,
|
||||||
|
get_tokenization_method,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCountTokens:
|
||||||
|
"""Test suite for count_tokens function."""
|
||||||
|
|
||||||
|
def test_count_tokens_simple_text_openai(self):
|
||||||
|
"""Test token counting for simple text with OpenAI models."""
|
||||||
|
message = {"role": "user", "content": "Hello, world!"}
|
||||||
|
|
||||||
|
# Should work with GPT-4
|
||||||
|
token_count = count_tokens(message, model="gpt-4")
|
||||||
|
assert isinstance(token_count, int)
|
||||||
|
assert token_count > 0
|
||||||
|
# "Hello, world!" should be around 3-4 tokens
|
||||||
|
assert 2 <= token_count <= 5
|
||||||
|
|
||||||
|
def test_count_tokens_simple_text_gpt4o(self):
|
||||||
|
"""Test token counting for GPT-4o model."""
|
||||||
|
message = {"role": "user", "content": "This is a test message."}
|
||||||
|
|
||||||
|
token_count = count_tokens(message, model="gpt-4o")
|
||||||
|
assert isinstance(token_count, int)
|
||||||
|
assert token_count > 0
|
||||||
|
|
||||||
|
def test_count_tokens_empty_message(self):
|
||||||
|
"""Test token counting for empty message."""
|
||||||
|
message = {"role": "user", "content": ""}
|
||||||
|
|
||||||
|
token_count = count_tokens(message, model="gpt-4")
|
||||||
|
assert token_count == 0
|
||||||
|
|
||||||
|
def test_count_tokens_none_content(self):
|
||||||
|
"""Test token counting for None content."""
|
||||||
|
message = {"role": "user", "content": None}
|
||||||
|
|
||||||
|
token_count = count_tokens(message, model="gpt-4")
|
||||||
|
assert token_count == 0
|
||||||
|
|
||||||
|
def test_count_tokens_multiple_messages(self):
|
||||||
|
"""Test token counting for multiple messages."""
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
|
{"role": "user", "content": "What is the weather?"},
|
||||||
|
]
|
||||||
|
|
||||||
|
token_count = count_tokens(messages, model="gpt-4")
|
||||||
|
assert isinstance(token_count, int)
|
||||||
|
assert token_count > 0
|
||||||
|
# Should be more than single message
|
||||||
|
assert token_count >= 10
|
||||||
|
|
||||||
|
def test_count_tokens_long_text(self):
|
||||||
|
"""Test token counting for long text."""
|
||||||
|
long_text = " ".join(["word"] * 1000)
|
||||||
|
message = {"role": "user", "content": long_text}
|
||||||
|
|
||||||
|
token_count = count_tokens(message, model="gpt-4")
|
||||||
|
assert isinstance(token_count, int)
|
||||||
|
# 1000 words should be close to 1000 tokens
|
||||||
|
assert 900 <= token_count <= 1100
|
||||||
|
|
||||||
|
def test_count_tokens_multimodal_text_only(self):
|
||||||
|
"""Test token counting for multimodal message with text only."""
|
||||||
|
message = {
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "What's in this image?"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
token_count = count_tokens(message, model="gpt-4o")
|
||||||
|
assert isinstance(token_count, int)
|
||||||
|
assert token_count > 0
|
||||||
|
|
||||||
|
def test_count_tokens_multimodal_with_image_low_res(self):
|
||||||
|
"""Test token counting for multimodal message with low-res image."""
|
||||||
|
message = {
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "Describe this image."},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "https://example.com/image.jpg",
|
||||||
|
"detail": "low",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
token_count = count_tokens(message, model="gpt-4o")
|
||||||
|
assert isinstance(token_count, int)
|
||||||
|
# Should include text tokens + image tokens (85 for low-res)
|
||||||
|
assert token_count >= 85
|
||||||
|
|
||||||
|
def test_count_tokens_multimodal_with_image_high_res(self):
|
||||||
|
"""Test token counting for multimodal message with high-res image."""
|
||||||
|
message = {
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "Analyze this."},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "https://example.com/image.jpg",
|
||||||
|
"detail": "high",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
token_count = count_tokens(message, model="gpt-4o")
|
||||||
|
assert isinstance(token_count, int)
|
||||||
|
# Should include text tokens + image tokens (170 for high-res)
|
||||||
|
assert token_count >= 170
|
||||||
|
|
||||||
|
def test_count_tokens_multimodal_with_image_auto(self):
|
||||||
|
"""Test token counting for multimodal message with auto detail."""
|
||||||
|
message = {
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "What do you see?"},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": "https://example.com/image.jpg"},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
token_count = count_tokens(message, model="gpt-4o")
|
||||||
|
assert isinstance(token_count, int)
|
||||||
|
# Should use average of low and high
|
||||||
|
assert token_count >= 100
|
||||||
|
|
||||||
|
def test_count_tokens_multiple_images(self):
|
||||||
|
"""Test token counting for message with multiple images."""
|
||||||
|
message = {
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "Compare these images."},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "https://example.com/image1.jpg",
|
||||||
|
"detail": "low",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "https://example.com/image2.jpg",
|
||||||
|
"detail": "low",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
token_count = count_tokens(message, model="gpt-4o")
|
||||||
|
assert isinstance(token_count, int)
|
||||||
|
# Should include text + 2 * 85 tokens for images
|
||||||
|
assert token_count >= 170
|
||||||
|
|
||||||
|
def test_count_tokens_unknown_model_estimation(self):
|
||||||
|
"""Test token counting falls back to estimation for unknown models."""
|
||||||
|
message = {"role": "user", "content": "Hello, world!"}
|
||||||
|
|
||||||
|
# Unknown model should use character-based estimation
|
||||||
|
token_count = count_tokens(message, model="unknown-model-xyz")
|
||||||
|
assert isinstance(token_count, int)
|
||||||
|
assert token_count > 0
|
||||||
|
# "Hello, world!" is 13 chars, should estimate ~3-4 tokens
|
||||||
|
assert 2 <= token_count <= 5
|
||||||
|
|
||||||
|
def test_count_tokens_llama_model_fallback(self):
|
||||||
|
"""Test token counting for Llama models (may fall back to estimation)."""
|
||||||
|
message = {"role": "user", "content": "Hello from Llama!"}
|
||||||
|
|
||||||
|
# This may fail if transformers/model not available, should fall back
|
||||||
|
token_count = count_tokens(
|
||||||
|
message,
|
||||||
|
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
)
|
||||||
|
assert isinstance(token_count, int)
|
||||||
|
assert token_count > 0
|
||||||
|
|
||||||
|
def test_count_tokens_with_exact_false(self):
|
||||||
|
"""Test token counting with exact=False uses estimation."""
|
||||||
|
message = {"role": "user", "content": "This is a test."}
|
||||||
|
|
||||||
|
token_count = count_tokens(message, model="gpt-4", exact=False)
|
||||||
|
assert isinstance(token_count, int)
|
||||||
|
assert token_count > 0
|
||||||
|
# Should use character-based estimation
|
||||||
|
# "This is a test." is 15 chars, should estimate ~3-4 tokens
|
||||||
|
assert 3 <= token_count <= 5
|
||||||
|
|
||||||
|
def test_count_tokens_malformed_message(self):
|
||||||
|
"""Test token counting with malformed message."""
|
||||||
|
# Not a dict
|
||||||
|
token_count = count_tokens("not a message", model="gpt-4") # type: ignore
|
||||||
|
assert token_count == 0
|
||||||
|
|
||||||
|
# Missing content
|
||||||
|
token_count = count_tokens({"role": "user"}, model="gpt-4")
|
||||||
|
assert token_count == 0
|
||||||
|
|
||||||
|
# Malformed content list
|
||||||
|
message = {
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
"not a dict", # Invalid item
|
||||||
|
{"type": "text", "text": "valid text"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
token_count = count_tokens(message, model="gpt-4")
|
||||||
|
# Should only count valid items
|
||||||
|
assert token_count > 0
|
||||||
|
|
||||||
|
def test_count_tokens_empty_list(self):
|
||||||
|
"""Test token counting with empty message list."""
|
||||||
|
token_count = count_tokens([], model="gpt-4")
|
||||||
|
assert token_count == 0
|
||||||
|
|
||||||
|
def test_count_tokens_special_characters(self):
|
||||||
|
"""Test token counting with special characters."""
|
||||||
|
message = {"role": "user", "content": "Hello! @#$%^&*() 🎉"}
|
||||||
|
|
||||||
|
token_count = count_tokens(message, model="gpt-4")
|
||||||
|
assert isinstance(token_count, int)
|
||||||
|
assert token_count > 0
|
||||||
|
|
||||||
|
def test_count_tokens_very_long_text(self):
|
||||||
|
"""Test token counting with very long text (>1024 tokens)."""
|
||||||
|
# Create text that should be >1024 tokens
|
||||||
|
long_text = " ".join(["word"] * 2000)
|
||||||
|
message = {"role": "user", "content": long_text}
|
||||||
|
|
||||||
|
token_count = count_tokens(message, model="gpt-4")
|
||||||
|
assert isinstance(token_count, int)
|
||||||
|
# Should be close to 2000 tokens
|
||||||
|
assert token_count >= 1024 # At least cacheable threshold
|
||||||
|
assert 1800 <= token_count <= 2200
|
||||||
|
|
||||||
|
def test_count_tokens_fine_tuned_model(self):
|
||||||
|
"""Test token counting for fine-tuned OpenAI model."""
|
||||||
|
message = {"role": "user", "content": "Test fine-tuned model."}
|
||||||
|
|
||||||
|
# Fine-tuned models should still work
|
||||||
|
token_count = count_tokens(message, model="gpt-4-turbo-2024-04-09")
|
||||||
|
assert isinstance(token_count, int)
|
||||||
|
assert token_count > 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetTokenizationMethod:
|
||||||
|
"""Test suite for get_tokenization_method function."""
|
||||||
|
|
||||||
|
def test_get_tokenization_method_openai(self):
|
||||||
|
"""Test getting tokenization method for OpenAI models."""
|
||||||
|
assert get_tokenization_method("gpt-4") == "exact-tiktoken"
|
||||||
|
assert get_tokenization_method("gpt-4o") == "exact-tiktoken"
|
||||||
|
assert get_tokenization_method("gpt-3.5-turbo") == "exact-tiktoken"
|
||||||
|
assert get_tokenization_method("gpt-4-turbo") == "exact-tiktoken"
|
||||||
|
|
||||||
|
def test_get_tokenization_method_llama(self):
|
||||||
|
"""Test getting tokenization method for Llama models."""
|
||||||
|
assert (
|
||||||
|
get_tokenization_method("meta-llama/Llama-3.1-8B-Instruct")
|
||||||
|
== "exact-transformers"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
get_tokenization_method("meta-llama/Llama-4-Scout-17B-16E-Instruct")
|
||||||
|
== "exact-transformers"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
get_tokenization_method("meta-llama/Meta-Llama-3-8B")
|
||||||
|
== "exact-transformers"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_get_tokenization_method_unknown(self):
|
||||||
|
"""Test getting tokenization method for unknown models."""
|
||||||
|
assert get_tokenization_method("unknown-model") == "estimated"
|
||||||
|
assert get_tokenization_method("claude-3") == "estimated"
|
||||||
|
assert get_tokenization_method("random-model-xyz") == "estimated"
|
||||||
|
|
||||||
|
def test_get_tokenization_method_fine_tuned(self):
|
||||||
|
"""Test getting tokenization method for fine-tuned models."""
|
||||||
|
# Fine-tuned OpenAI models should still use tiktoken
|
||||||
|
assert (
|
||||||
|
get_tokenization_method("gpt-4-turbo-2024-04-09") == "exact-tiktoken"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestClearTokenizerCache:
|
||||||
|
"""Test suite for clear_tokenizer_cache function."""
|
||||||
|
|
||||||
|
def test_clear_tokenizer_cache(self):
|
||||||
|
"""Test clearing tokenizer cache."""
|
||||||
|
# Count tokens to populate cache
|
||||||
|
message = {"role": "user", "content": "Test cache clearing."}
|
||||||
|
count_tokens(message, model="gpt-4")
|
||||||
|
|
||||||
|
# Clear cache
|
||||||
|
clear_tokenizer_cache()
|
||||||
|
|
||||||
|
# Should still work after clearing
|
||||||
|
token_count = count_tokens(message, model="gpt-4")
|
||||||
|
assert token_count > 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestEdgeCases:
|
||||||
|
"""Test suite for edge cases and error handling."""
|
||||||
|
|
||||||
|
def test_empty_string_content(self):
|
||||||
|
"""Test with empty string content."""
|
||||||
|
message = {"role": "user", "content": ""}
|
||||||
|
token_count = count_tokens(message, model="gpt-4")
|
||||||
|
assert token_count == 0
|
||||||
|
|
||||||
|
def test_whitespace_only_content(self):
|
||||||
|
"""Test with whitespace-only content."""
|
||||||
|
message = {"role": "user", "content": " \n\t "}
|
||||||
|
token_count = count_tokens(message, model="gpt-4")
|
||||||
|
# Should count whitespace tokens
|
||||||
|
assert token_count >= 0
|
||||||
|
|
||||||
|
def test_unicode_content(self):
|
||||||
|
"""Test with unicode content."""
|
||||||
|
message = {"role": "user", "content": "Hello 世界 🌍"}
|
||||||
|
token_count = count_tokens(message, model="gpt-4")
|
||||||
|
assert token_count > 0
|
||||||
|
|
||||||
|
def test_multimodal_empty_text(self):
|
||||||
|
"""Test multimodal message with empty text."""
|
||||||
|
message = {
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": ""},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": "https://example.com/image.jpg"},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
token_count = count_tokens(message, model="gpt-4o")
|
||||||
|
# Should only count image tokens
|
||||||
|
assert token_count > 0
|
||||||
|
|
||||||
|
def test_multimodal_missing_text_field(self):
|
||||||
|
"""Test multimodal message with missing text field."""
|
||||||
|
message = {
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text"}, # Missing 'text' field
|
||||||
|
],
|
||||||
|
}
|
||||||
|
token_count = count_tokens(message, model="gpt-4o")
|
||||||
|
# Should handle gracefully
|
||||||
|
assert token_count == 0
|
||||||
|
|
||||||
|
def test_multimodal_unknown_type(self):
|
||||||
|
"""Test multimodal message with unknown content type."""
|
||||||
|
message = {
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "unknown", "data": "something"},
|
||||||
|
{"type": "text", "text": "Hello"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
token_count = count_tokens(message, model="gpt-4o")
|
||||||
|
# Should only count known types
|
||||||
|
assert token_count > 0
|
||||||
|
|
||||||
|
def test_nested_content_structures(self):
|
||||||
|
"""Test with nested content structures."""
|
||||||
|
message = {
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "First part",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Second part",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
token_count = count_tokens(message, model="gpt-4")
|
||||||
|
# Should count all text parts
|
||||||
|
assert token_count > 0
|
||||||
|
|
||||||
|
def test_consistency_across_calls(self):
|
||||||
|
"""Test that token counting is consistent across calls."""
|
||||||
|
message = {"role": "user", "content": "Consistency test message."}
|
||||||
|
|
||||||
|
count1 = count_tokens(message, model="gpt-4")
|
||||||
|
count2 = count_tokens(message, model="gpt-4")
|
||||||
|
|
||||||
|
assert count1 == count2
|
||||||
|
|
||||||
|
|
||||||
|
class TestPerformance:
|
||||||
|
"""Test suite for performance characteristics."""
|
||||||
|
|
||||||
|
def test_tokenizer_caching_works(self):
|
||||||
|
"""Test that tokenizer caching improves performance."""
|
||||||
|
message = {"role": "user", "content": "Test caching performance."}
|
||||||
|
|
||||||
|
# First call loads tokenizer
|
||||||
|
count_tokens(message, model="gpt-4")
|
||||||
|
|
||||||
|
# Subsequent calls should use cached tokenizer
|
||||||
|
# (We can't easily measure time in unit tests, but we verify it works)
|
||||||
|
for _ in range(5):
|
||||||
|
token_count = count_tokens(message, model="gpt-4")
|
||||||
|
assert token_count > 0
|
||||||
|
|
||||||
|
def test_cache_size_limit(self):
|
||||||
|
"""Test that cache size is limited (max 10 tokenizers)."""
|
||||||
|
# Load more than 10 different models (using estimation for most)
|
||||||
|
models = [f"model-{i}" for i in range(15)]
|
||||||
|
|
||||||
|
message = {"role": "user", "content": "Test"}
|
||||||
|
|
||||||
|
for model in models:
|
||||||
|
count_tokens(message, model=model, exact=False)
|
||||||
|
|
||||||
|
# Should still work (cache evicts oldest entries)
|
||||||
|
token_count = count_tokens(message, model="model-0", exact=False)
|
||||||
|
assert token_count > 0
|
||||||
Loading…
Add table
Add a link
Reference in a new issue