chore: remove usage of load_tiktoken_bpe (#2276)

This commit is contained in:
Sébastien Han 2025-06-02 16:33:37 +02:00 committed by GitHub
parent af65207ebd
commit 1c0c6e1e17
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 234 additions and 17 deletions

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from collections.abc import Collection, Iterator, Sequence, Set
from logging import getLogger
from pathlib import Path
@ -14,7 +13,8 @@ from typing import (
)
import tiktoken
from tiktoken.load import load_tiktoken_bpe
from llama_stack.models.llama.tokenizer_utils import load_bpe_file
logger = getLogger(__name__)
@ -48,19 +48,20 @@ class Tokenizer:
global _INSTANCE
if _INSTANCE is None:
_INSTANCE = Tokenizer(os.path.join(os.path.dirname(__file__), "tokenizer.model"))
_INSTANCE = Tokenizer(Path(__file__).parent / "tokenizer.model")
return _INSTANCE
def __init__(self, model_path: str):
def __init__(self, model_path: Path):
"""
Initializes the Tokenizer with a Tiktoken model.
Args:
model_path (str): The path to the Tiktoken model file.
"""
assert os.path.isfile(model_path), model_path
if not model_path.exists():
raise FileNotFoundError(f"Tokenizer model file not found: {model_path}")
mergeable_ranks = load_tiktoken_bpe(model_path)
mergeable_ranks = load_bpe_file(model_path)
num_base_tokens = len(mergeable_ranks)
special_tokens = [
"<|begin_of_text|>",
@ -83,7 +84,7 @@ class Tokenizer:
self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)}
self.model = tiktoken.Encoding(
name=Path(model_path).name,
name=model_path.name,
pat_str=self.pat_str,
mergeable_ranks=mergeable_ranks,
special_tokens=self.special_tokens,

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from collections.abc import Collection, Iterator, Sequence, Set
from logging import getLogger
from pathlib import Path
@ -14,7 +13,8 @@ from typing import (
)
import tiktoken
from tiktoken.load import load_tiktoken_bpe
from llama_stack.models.llama.tokenizer_utils import load_bpe_file
logger = getLogger(__name__)
@ -118,19 +118,20 @@ class Tokenizer:
global _INSTANCE
if _INSTANCE is None:
_INSTANCE = Tokenizer(os.path.join(os.path.dirname(__file__), "tokenizer.model"))
_INSTANCE = Tokenizer(Path(__file__).parent / "tokenizer.model")
return _INSTANCE
def __init__(self, model_path: str):
def __init__(self, model_path: Path):
"""
Initializes the Tokenizer with a Tiktoken model.
Args:
model_path (str): The path to the Tiktoken model file.
model_path (Path): The path to the Tiktoken model file.
"""
assert os.path.isfile(model_path), model_path
if not model_path.exists():
raise FileNotFoundError(f"Tokenizer model file not found: {model_path}")
mergeable_ranks = load_tiktoken_bpe(model_path)
mergeable_ranks = load_bpe_file(model_path)
num_base_tokens = len(mergeable_ranks)
special_tokens = BASIC_SPECIAL_TOKENS + LLAMA4_SPECIAL_TOKENS
@ -144,7 +145,7 @@ class Tokenizer:
self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)}
self.model = tiktoken.Encoding(
name=Path(model_path).name,
name=model_path.name,
pat_str=self.O200K_PATTERN,
mergeable_ranks=mergeable_ranks,
special_tokens=self.special_tokens,

View file

@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
from pathlib import Path
from llama_stack.log import get_logger
logger = get_logger(__name__, "tokenizer_utils")
def load_bpe_file(model_path: Path) -> dict[bytes, int]:
"""
Load BPE file directly and return mergeable ranks.
Args:
model_path (Path): Path to the BPE model file.
Returns:
dict[bytes, int]: Dictionary mapping byte sequences to their ranks.
"""
mergeable_ranks = {}
with open(model_path, encoding="utf-8") as f:
content = f.read()
for line in content.splitlines():
if not line.strip(): # Skip empty lines
continue
try:
token, rank = line.split()
mergeable_ranks[base64.b64decode(token)] = int(rank)
except Exception as e:
logger.warning(f"Failed to parse line '{line}': {e}")
continue
return mergeable_ranks

View file

@ -15,7 +15,6 @@ from llama_stack.providers.datatypes import (
META_REFERENCE_DEPS = [
"accelerate",
"blobfile",
"fairscale",
"torch",
"torchvision",

View file

@ -20,7 +20,6 @@ def available_providers() -> list[ProviderSpec]:
api=Api.tool_runtime,
provider_type="inline::rag-runtime",
pip_packages=[
"blobfile",
"chardet",
"pypdf",
"tqdm",

View file

@ -0,0 +1,177 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import time
from pathlib import Path
from unittest.mock import patch
import pytest
from tiktoken.load import load_tiktoken_bpe
from llama_stack.models.llama.tokenizer_utils import load_bpe_file
@pytest.fixture
def test_bpe_content():
"""Sample BPE file content for testing."""
return """wA== 0
wQ== 1
9Q== 2
9g== 3
9w== 4
+A== 5
+Q== 6
+g== 7
+w== 8
/A== 9
/Q== 10
/g== 11
/w== 12
AA== 13
AQ== 14"""
@pytest.fixture
def test_bpe_file(tmp_path, test_bpe_content):
"""Create a temporary BPE file for testing."""
bpe_file = tmp_path / "test_tokenizer.model"
bpe_file.write_text(test_bpe_content, encoding="utf-8")
return bpe_file
@pytest.fixture
def llama3_model_path():
"""Path to Llama3 tokenizer model."""
return Path(__file__).parent / "../../../../llama_stack/models/llama/llama3/tokenizer.model"
@pytest.fixture
def llama4_model_path():
"""Path to Llama4 tokenizer model."""
return Path(__file__).parent / "../../../../llama_stack/models/llama/llama4/tokenizer.model"
def test_load_bpe_file_basic_functionality(test_bpe_file):
"""Test that load_bpe_file correctly parses BPE files."""
result = load_bpe_file(test_bpe_file)
for key, value in result.items():
assert isinstance(key, bytes)
assert isinstance(value, int)
assert len(result) == 15
expected_first_token = base64.b64decode("wA==")
assert expected_first_token in result
assert result[expected_first_token] == 0
def test_load_bpe_file_vs_tiktoken_with_real_model(llama3_model_path):
"""Test that our implementation produces identical results to tiktoken on real model files."""
if not llama3_model_path.exists():
pytest.skip("Llama3 tokenizer model not found")
our_result = load_bpe_file(llama3_model_path)
tiktoken_result = load_tiktoken_bpe(llama3_model_path.as_posix())
# Compare results from our implementation and tiktoken
assert len(our_result) == len(tiktoken_result)
assert our_result == tiktoken_result
assert len(our_result) > 100000
ranks = list(our_result.values())
assert len(ranks) == len(set(ranks))
def test_load_bpe_file_vs_tiktoken_with_llama4_model(llama4_model_path):
"""Test that our implementation produces identical results to tiktoken on Llama4 model."""
if not llama4_model_path.exists():
pytest.skip("Llama4 tokenizer model not found")
our_result = load_bpe_file(llama4_model_path)
tiktoken_result = load_tiktoken_bpe(llama4_model_path.as_posix())
# Compare results from our implementation and tiktoken
assert len(our_result) == len(tiktoken_result)
assert our_result == tiktoken_result
assert len(our_result) > 100000
ranks = list(our_result.values())
assert len(ranks) == len(set(ranks))
def test_load_bpe_file_malformed_lines(tmp_path):
"""Test that load_bpe_file handles malformed lines gracefully."""
malformed_content = """wA== 0
invalid_line_without_rank
wQ== 1
invalid_base64!!! 2
9Q== 2"""
test_file = tmp_path / "malformed.model"
test_file.write_text(malformed_content, encoding="utf-8")
with patch("llama_stack.models.llama.tokenizer_utils.logger") as mock_logger:
result = load_bpe_file(test_file)
# Should have 3 valid entries (skipping malformed ones)
assert len(result) == 3
# Should have logged warnings for malformed lines
assert mock_logger.warning.called
assert mock_logger.warning.call_count > 0
def test_load_bpe_file_nonexistent_file():
"""Test that load_bpe_file raises appropriate error for nonexistent files."""
with pytest.raises(FileNotFoundError):
load_bpe_file("/nonexistent/path/to/file.model")
def test_tokenizer_integration():
"""Test that our load_bpe_file works correctly when used in actual tokenizers."""
try:
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
tokenizer = Llama3Tokenizer.get_instance()
# Test basic functionality
test_text = "Hello, world! This is a test."
tokens = tokenizer.encode(test_text, bos=False, eos=False)
decoded = tokenizer.decode(tokens)
assert test_text == decoded
assert isinstance(tokens, list)
assert all(isinstance(token, int) for token in tokens)
except Exception as e:
pytest.skip(f"Llama3 tokenizer not available: {e}")
def test_performance_comparison(llama3_model_path):
"""Test that our implementation has reasonable performance compared to tiktoken."""
if not llama3_model_path.exists():
pytest.skip("Llama3 tokenizer model not found")
# Time our implementation
start_time = time.time()
our_result = load_bpe_file(llama3_model_path)
our_time = time.time() - start_time
# Time tiktoken implementation
start_time = time.time()
tiktoken_result = load_tiktoken_bpe(llama3_model_path.as_posix())
tiktoken_time = time.time() - start_time
# Verify results are identical
assert our_result == tiktoken_result
# Our implementation should be reasonably fast (within 10x of tiktoken)
# This is a loose bound since we're optimizing for correctness, not speed
assert our_time < tiktoken_time * 10, f"Our implementation took {our_time:.3f}s vs tiktoken's {tiktoken_time:.3f}s"
print(f"Performance comparison - Our: {our_time:.3f}s, Tiktoken: {tiktoken_time:.3f}s")