chore: remove usage of load_tiktoken_bpe (#2276)

Sébastien Han, 2025-06-02 16:33:37 +02:00 (committed by GitHub)
parent af65207ebd
commit 1c0c6e1e17
6 changed files with 234 additions and 17 deletions
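Context: tiktoken's load_tiktoken_bpe reads model files through the blobfile package. Replacing it with a local load_bpe_file helper (new file below) removes that indirect dependency, which is why blobfile also disappears from the two provider dependency lists at the end of this diff.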


@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import os
 from collections.abc import Collection, Iterator, Sequence, Set
 from logging import getLogger
 from pathlib import Path
@@ -14,7 +13,8 @@ from typing import (
 )

 import tiktoken
-from tiktoken.load import load_tiktoken_bpe
+
+from llama_stack.models.llama.tokenizer_utils import load_bpe_file

 logger = getLogger(__name__)
@@ -48,19 +48,20 @@ class Tokenizer:
         global _INSTANCE

         if _INSTANCE is None:
-            _INSTANCE = Tokenizer(os.path.join(os.path.dirname(__file__), "tokenizer.model"))
+            _INSTANCE = Tokenizer(Path(__file__).parent / "tokenizer.model")
         return _INSTANCE

-    def __init__(self, model_path: str):
+    def __init__(self, model_path: Path):
         """
         Initializes the Tokenizer with a Tiktoken model.

         Args:
             model_path (str): The path to the Tiktoken model file.
         """
-        assert os.path.isfile(model_path), model_path
+        if not model_path.exists():
+            raise FileNotFoundError(f"Tokenizer model file not found: {model_path}")

-        mergeable_ranks = load_tiktoken_bpe(model_path)
+        mergeable_ranks = load_bpe_file(model_path)
         num_base_tokens = len(mergeable_ranks)

         special_tokens = [
             "<|begin_of_text|>",
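With this change, Tokenizer.get_instance() resolves tokenizer.model relative to the module via pathlib, and the constructor requires a Path rather than a str (a plain string has no .exists() method and would fail with an AttributeError instead of the intended FileNotFoundError). A minimal usage sketch; the module path in the import is assumed, not shown in this diff:

    from pathlib import Path

    # Module path assumed from the import style used above.
    from llama_stack.models.llama.llama3.tokenizer import Tokenizer

    tok = Tokenizer.get_instance()  # loads tokenizer.model next to the module

    # Direct construction now takes a Path, not a str:
    tok = Tokenizer(Path("/models") / "tokenizer.model")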


@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import os
 from collections.abc import Collection, Iterator, Sequence, Set
 from logging import getLogger
 from pathlib import Path
@@ -14,7 +13,8 @@ from typing import (
 )

 import tiktoken
-from tiktoken.load import load_tiktoken_bpe
+
+from llama_stack.models.llama.tokenizer_utils import load_bpe_file

 logger = getLogger(__name__)
@@ -118,19 +118,20 @@ class Tokenizer:
         global _INSTANCE

         if _INSTANCE is None:
-            _INSTANCE = Tokenizer(os.path.join(os.path.dirname(__file__), "tokenizer.model"))
+            _INSTANCE = Tokenizer(Path(__file__).parent / "tokenizer.model")
         return _INSTANCE

-    def __init__(self, model_path: str):
+    def __init__(self, model_path: Path):
         """
         Initializes the Tokenizer with a Tiktoken model.

         Args:
-            model_path (str): The path to the Tiktoken model file.
+            model_path (Path): The path to the Tiktoken model file.
         """
-        assert os.path.isfile(model_path), model_path
+        if not model_path.exists():
+            raise FileNotFoundError(f"Tokenizer model file not found: {model_path}")

-        mergeable_ranks = load_tiktoken_bpe(model_path)
+        mergeable_ranks = load_bpe_file(model_path)
         num_base_tokens = len(mergeable_ranks)

         special_tokens = BASIC_SPECIAL_TOKENS + LLAMA4_SPECIAL_TOKENS
@@ -144,7 +145,7 @@ class Tokenizer:
         self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)}

         self.model = tiktoken.Encoding(
-            name=Path(model_path).name,
+            name=model_path.name,
             pat_str=self.O200K_PATTERN,
             mergeable_ranks=mergeable_ranks,
             special_tokens=self.special_tokens,
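In both tokenizers, special tokens are assigned ids immediately after the base vocabulary (num_base_tokens + i). A toy illustration of that id layout with tiktoken; the ranks, pattern, and token name here are placeholders, not the Llama ones:

    import tiktoken

    mergeable_ranks = {b"a": 0, b"b": 1}
    num_base_tokens = len(mergeable_ranks)
    special_tokens = {"<|eot|>": num_base_tokens}  # first special token gets id 2

    enc = tiktoken.Encoding(
        name="toy",
        pat_str=r"\S+|\s+",  # placeholder pattern, not the Llama pat_str
        mergeable_ranks=mergeable_ranks,
        special_tokens=special_tokens,
    )
    assert enc.encode("<|eot|>", allowed_special="all") == [2]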


@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import base64
+from pathlib import Path
+
+from llama_stack.log import get_logger
+
+logger = get_logger(__name__, "tokenizer_utils")
+
+
+def load_bpe_file(model_path: Path) -> dict[bytes, int]:
+    """
+    Load BPE file directly and return mergeable ranks.
+
+    Args:
+        model_path (Path): Path to the BPE model file.
+
+    Returns:
+        dict[bytes, int]: Dictionary mapping byte sequences to their ranks.
+    """
+    mergeable_ranks = {}
+
+    with open(model_path, encoding="utf-8") as f:
+        content = f.read()
+
+    for line in content.splitlines():
+        if not line.strip():  # Skip empty lines
+            continue
+        try:
+            token, rank = line.split()
+            mergeable_ranks[base64.b64decode(token)] = int(rank)
+        except Exception as e:
+            logger.warning(f"Failed to parse line '{line}': {e}")
+            continue
+
+    return mergeable_ranks
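For reference, the format load_bpe_file expects is the standard tiktoken .model layout: one entry per line, a base64-encoded token followed by its integer rank, separated by whitespace. A minimal round-trip sketch (the demo file path is hypothetical):

    import base64
    from pathlib import Path

    from llama_stack.models.llama.tokenizer_utils import load_bpe_file

    # Write a two-entry file in the "<base64 token> <rank>" format parsed above.
    demo = Path("demo.model")  # hypothetical path
    entries = {b"hello": 0, b" world": 1}
    demo.write_text(
        "\n".join(f"{base64.b64encode(tok).decode()} {rank}" for tok, rank in entries.items()),
        encoding="utf-8",
    )

    ranks = load_bpe_file(demo)
    assert ranks == entries  # byte tokens map back to their ranks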


@@ -15,7 +15,6 @@ from llama_stack.providers.datatypes import (
 META_REFERENCE_DEPS = [
     "accelerate",
-    "blobfile",
     "fairscale",
     "torch",
     "torchvision",

@@ -20,7 +20,6 @@ def available_providers() -> list[ProviderSpec]:
         api=Api.tool_runtime,
         provider_type="inline::rag-runtime",
         pip_packages=[
-            "blobfile",
             "chardet",
             "pypdf",
             "tqdm",