Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
chore: remove usage of load_tiktoken_bpe (#2276)
Parent: af65207ebd
Commit: 1c0c6e1e17
6 changed files with 234 additions and 17 deletions
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import os
 from collections.abc import Collection, Iterator, Sequence, Set
 from logging import getLogger
 from pathlib import Path
@@ -14,7 +13,8 @@ from typing import (
 )

 import tiktoken
-from tiktoken.load import load_tiktoken_bpe
+
+from llama_stack.models.llama.tokenizer_utils import load_bpe_file

 logger = getLogger(__name__)

@@ -48,19 +48,20 @@ class Tokenizer:
         global _INSTANCE

         if _INSTANCE is None:
-            _INSTANCE = Tokenizer(os.path.join(os.path.dirname(__file__), "tokenizer.model"))
+            _INSTANCE = Tokenizer(Path(__file__).parent / "tokenizer.model")
         return _INSTANCE

-    def __init__(self, model_path: str):
+    def __init__(self, model_path: Path):
         """
         Initializes the Tokenizer with a Tiktoken model.

         Args:
             model_path (str): The path to the Tiktoken model file.
         """
-        assert os.path.isfile(model_path), model_path
+        if not model_path.exists():
+            raise FileNotFoundError(f"Tokenizer model file not found: {model_path}")

-        mergeable_ranks = load_tiktoken_bpe(model_path)
+        mergeable_ranks = load_bpe_file(model_path)
         num_base_tokens = len(mergeable_ranks)
         special_tokens = [
             "<|begin_of_text|>",
@@ -83,7 +84,7 @@ class Tokenizer:

         self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)}
         self.model = tiktoken.Encoding(
-            name=Path(model_path).name,
+            name=model_path.name,
             pat_str=self.pat_str,
             mergeable_ranks=mergeable_ranks,
             special_tokens=self.special_tokens,
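For orientation, the hunks above (the Llama3 tokenizer, judging by the tests added later in this commit) switch construction from os.path strings to pathlib and swap load_tiktoken_bpe for the new load_bpe_file helper. A minimal usage sketch under those assumptions; get_instance, encode, and decode come from the diff and the tests, while the explicit-path construction is purely illustrative:

from pathlib import Path

from llama_stack.models.llama.llama3.tokenizer import Tokenizer

# Normal entry point: the singleton now resolves tokenizer.model next to the module
# via Path(__file__).parent instead of os.path.join.
tok = Tokenizer.get_instance()

# Direct construction with an explicit Path also works after this change; a missing
# file now raises FileNotFoundError instead of tripping an assert.
tok = Tokenizer(Path("tokenizer.model"))  # illustrative path, must point at a real model file

tokens = tok.encode("Hello, world!", bos=True, eos=True)
print(tok.decode(tokens))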
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import os
 from collections.abc import Collection, Iterator, Sequence, Set
 from logging import getLogger
 from pathlib import Path
@@ -14,7 +13,8 @@ from typing import (
 )

 import tiktoken
-from tiktoken.load import load_tiktoken_bpe
+
+from llama_stack.models.llama.tokenizer_utils import load_bpe_file

 logger = getLogger(__name__)

@@ -118,19 +118,20 @@ class Tokenizer:
         global _INSTANCE

         if _INSTANCE is None:
-            _INSTANCE = Tokenizer(os.path.join(os.path.dirname(__file__), "tokenizer.model"))
+            _INSTANCE = Tokenizer(Path(__file__).parent / "tokenizer.model")
         return _INSTANCE

-    def __init__(self, model_path: str):
+    def __init__(self, model_path: Path):
         """
         Initializes the Tokenizer with a Tiktoken model.

         Args:
-            model_path (str): The path to the Tiktoken model file.
+            model_path (Path): The path to the Tiktoken model file.
         """
-        assert os.path.isfile(model_path), model_path
+        if not model_path.exists():
+            raise FileNotFoundError(f"Tokenizer model file not found: {model_path}")

-        mergeable_ranks = load_tiktoken_bpe(model_path)
+        mergeable_ranks = load_bpe_file(model_path)
         num_base_tokens = len(mergeable_ranks)

         special_tokens = BASIC_SPECIAL_TOKENS + LLAMA4_SPECIAL_TOKENS
@@ -144,7 +145,7 @@ class Tokenizer:

         self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)}
         self.model = tiktoken.Encoding(
-            name=Path(model_path).name,
+            name=model_path.name,
             pat_str=self.O200K_PATTERN,
             mergeable_ranks=mergeable_ranks,
             special_tokens=self.special_tokens,
llama_stack/models/llama/tokenizer_utils.py (new file, 40 lines)
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import base64
+from pathlib import Path
+
+from llama_stack.log import get_logger
+
+logger = get_logger(__name__, "tokenizer_utils")
+
+
+def load_bpe_file(model_path: Path) -> dict[bytes, int]:
+    """
+    Load BPE file directly and return mergeable ranks.
+
+    Args:
+        model_path (Path): Path to the BPE model file.
+
+    Returns:
+        dict[bytes, int]: Dictionary mapping byte sequences to their ranks.
+    """
+    mergeable_ranks = {}
+
+    with open(model_path, encoding="utf-8") as f:
+        content = f.read()
+
+    for line in content.splitlines():
+        if not line.strip():  # Skip empty lines
+            continue
+        try:
+            token, rank = line.split()
+            mergeable_ranks[base64.b64decode(token)] = int(rank)
+        except Exception as e:
+            logger.warning(f"Failed to parse line '{line}': {e}")
+            continue
+
+    return mergeable_ranks
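A minimal sketch of how the ranks returned by load_bpe_file feed tiktoken, which is what the tokenizer hunks above do; the split pattern and the special-token map below are simplified placeholders for illustration, not the real pat_str/O200K_PATTERN or special-token lists:

from pathlib import Path

import tiktoken

from llama_stack.models.llama.tokenizer_utils import load_bpe_file

model_path = Path("tokenizer.model")         # a base64 "<token> <rank>" BPE file
mergeable_ranks = load_bpe_file(model_path)  # dict[bytes, int], same shape load_tiktoken_bpe returned

enc = tiktoken.Encoding(
    name=model_path.name,
    pat_str=r"\S+|\s+",                      # placeholder pattern for illustration only
    mergeable_ranks=mergeable_ranks,
    special_tokens={"<|begin_of_text|>": len(mergeable_ranks)},
)
print(enc.n_vocab)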
@@ -15,7 +15,6 @@ from llama_stack.providers.datatypes import (

 META_REFERENCE_DEPS = [
     "accelerate",
-    "blobfile",
     "fairscale",
     "torch",
     "torchvision",
@@ -20,7 +20,6 @@ def available_providers() -> list[ProviderSpec]:
             api=Api.tool_runtime,
             provider_type="inline::rag-runtime",
             pip_packages=[
-                "blobfile",
                 "chardet",
                 "pypdf",
                 "tqdm",
tests/unit/models/llama/test_tokenizer_utils.py (new file, 177 lines)
@@ -0,0 +1,177 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import base64
+import time
+from pathlib import Path
+from unittest.mock import patch
+
+import pytest
+from tiktoken.load import load_tiktoken_bpe
+
+from llama_stack.models.llama.tokenizer_utils import load_bpe_file
+
+
+@pytest.fixture
+def test_bpe_content():
+    """Sample BPE file content for testing."""
+    return """wA== 0
+wQ== 1
+9Q== 2
+9g== 3
+9w== 4
++A== 5
++Q== 6
++g== 7
++w== 8
+/A== 9
+/Q== 10
+/g== 11
+/w== 12
+AA== 13
+AQ== 14"""
+
+
+@pytest.fixture
+def test_bpe_file(tmp_path, test_bpe_content):
+    """Create a temporary BPE file for testing."""
+    bpe_file = tmp_path / "test_tokenizer.model"
+    bpe_file.write_text(test_bpe_content, encoding="utf-8")
+    return bpe_file
+
+
+@pytest.fixture
+def llama3_model_path():
+    """Path to Llama3 tokenizer model."""
+    return Path(__file__).parent / "../../../../llama_stack/models/llama/llama3/tokenizer.model"
+
+
+@pytest.fixture
+def llama4_model_path():
+    """Path to Llama4 tokenizer model."""
+    return Path(__file__).parent / "../../../../llama_stack/models/llama/llama4/tokenizer.model"
+
+
+def test_load_bpe_file_basic_functionality(test_bpe_file):
+    """Test that load_bpe_file correctly parses BPE files."""
+    result = load_bpe_file(test_bpe_file)
+
+    for key, value in result.items():
+        assert isinstance(key, bytes)
+        assert isinstance(value, int)
+
+    assert len(result) == 15
+
+    expected_first_token = base64.b64decode("wA==")
+    assert expected_first_token in result
+    assert result[expected_first_token] == 0
+
+
+def test_load_bpe_file_vs_tiktoken_with_real_model(llama3_model_path):
+    """Test that our implementation produces identical results to tiktoken on real model files."""
+    if not llama3_model_path.exists():
+        pytest.skip("Llama3 tokenizer model not found")
+
+    our_result = load_bpe_file(llama3_model_path)
+    tiktoken_result = load_tiktoken_bpe(llama3_model_path.as_posix())
+
+    # Compare results from our implementation and tiktoken
+    assert len(our_result) == len(tiktoken_result)
+    assert our_result == tiktoken_result
+
+    assert len(our_result) > 100000
+    ranks = list(our_result.values())
+    assert len(ranks) == len(set(ranks))
+
+
+def test_load_bpe_file_vs_tiktoken_with_llama4_model(llama4_model_path):
+    """Test that our implementation produces identical results to tiktoken on Llama4 model."""
+    if not llama4_model_path.exists():
+        pytest.skip("Llama4 tokenizer model not found")
+
+    our_result = load_bpe_file(llama4_model_path)
+    tiktoken_result = load_tiktoken_bpe(llama4_model_path.as_posix())
+
+    # Compare results from our implementation and tiktoken
+    assert len(our_result) == len(tiktoken_result)
+    assert our_result == tiktoken_result
+
+    assert len(our_result) > 100000
+    ranks = list(our_result.values())
+    assert len(ranks) == len(set(ranks))
+
+
+def test_load_bpe_file_malformed_lines(tmp_path):
+    """Test that load_bpe_file handles malformed lines gracefully."""
+    malformed_content = """wA== 0
+invalid_line_without_rank
+wQ== 1
+invalid_base64!!! 2
+9Q== 2"""
+
+    test_file = tmp_path / "malformed.model"
+    test_file.write_text(malformed_content, encoding="utf-8")
+
+    with patch("llama_stack.models.llama.tokenizer_utils.logger") as mock_logger:
+        result = load_bpe_file(test_file)
+
+        # Should have 3 valid entries (skipping malformed ones)
+        assert len(result) == 3
+
+        # Should have logged warnings for malformed lines
+        assert mock_logger.warning.called
+        assert mock_logger.warning.call_count > 0
+
+
+def test_load_bpe_file_nonexistent_file():
+    """Test that load_bpe_file raises appropriate error for nonexistent files."""
+    with pytest.raises(FileNotFoundError):
+        load_bpe_file("/nonexistent/path/to/file.model")
+
+
+def test_tokenizer_integration():
+    """Test that our load_bpe_file works correctly when used in actual tokenizers."""
+    try:
+        from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
+
+        tokenizer = Llama3Tokenizer.get_instance()
+
+        # Test basic functionality
+        test_text = "Hello, world! This is a test."
+        tokens = tokenizer.encode(test_text, bos=False, eos=False)
+        decoded = tokenizer.decode(tokens)
+
+        assert test_text == decoded
+        assert isinstance(tokens, list)
+        assert all(isinstance(token, int) for token in tokens)
+
+    except Exception as e:
+        pytest.skip(f"Llama3 tokenizer not available: {e}")
+
+
+def test_performance_comparison(llama3_model_path):
+    """Test that our implementation has reasonable performance compared to tiktoken."""
+    if not llama3_model_path.exists():
+        pytest.skip("Llama3 tokenizer model not found")
+
+    # Time our implementation
+    start_time = time.time()
+    our_result = load_bpe_file(llama3_model_path)
+    our_time = time.time() - start_time
+
+    # Time tiktoken implementation
+    start_time = time.time()
+    tiktoken_result = load_tiktoken_bpe(llama3_model_path.as_posix())
+    tiktoken_time = time.time() - start_time
+
+    # Verify results are identical
+    assert our_result == tiktoken_result
+
+    # Our implementation should be reasonably fast (within 10x of tiktoken)
+    # This is a loose bound since we're optimizing for correctness, not speed
+    assert our_time < tiktoken_time * 10, f"Our implementation took {our_time:.3f}s vs tiktoken's {tiktoken_time:.3f}s"
+
+    print(f"Performance comparison - Our: {our_time:.3f}s, Tiktoken: {tiktoken_time:.3f}s")
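Presumably the new tests run under a standard pytest invocation (not stated in the commit); for example, programmatically from the repository root:

import pytest

# Run only the new tokenizer-utils tests (path taken from the diff above).
pytest.main(["-v", "tests/unit/models/llama/test_tokenizer_utils.py"])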