Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 02:53:30 +00:00)
feat: introduce llama4 support (#1877)
As title says. Details in README, elsewhere.
parent 23a99a4b22 · commit b8f1561956
61 changed files with 205222 additions and 6439 deletions
llama_stack/models/llama/llama4/tokenizer.py (new file, 255 lines)
@@ -0,0 +1,255 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

import os
from logging import getLogger
from pathlib import Path
from typing import (
    AbstractSet,
    Collection,
    Dict,
    Iterator,
    List,
    Literal,
    Optional,
    Sequence,
    Union,
    cast,
)

import tiktoken
from tiktoken.load import load_tiktoken_bpe

logger = getLogger(__name__)


# The tiktoken tokenizer can handle <=400k chars without
# pyo3_runtime.PanicException.
TIKTOKEN_MAX_ENCODE_CHARS = 400_000

# https://github.com/openai/tiktoken/issues/195
# Here we iterate over subsequences and split if we exceed the limit
# of max consecutive non-whitespace or whitespace characters.
MAX_NO_WHITESPACES_CHARS = 25_000


_INSTANCE = None


def get_reserved_special_tokens(name, count, start_index=0):
    return [f"<|{name}_reserved_special_token_{i}|>" for i in range(start_index, start_index + count)]


# 200005, ..., 200079
LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS = [
    "<|header_start|>",
    "<|header_end|>",
    "<|eom|>",
    "<|eot|>",
    "<|step|>",
    "<|text_post_train_reserved_special_token_0|>",
    "<|text_post_train_reserved_special_token_1|>",
    "<|text_post_train_reserved_special_token_2|>",
    "<|text_post_train_reserved_special_token_3|>",
    "<|text_post_train_reserved_special_token_4|>",
    "<|text_post_train_reserved_special_token_5|>",
    "<|python_start|>",
    "<|python_end|>",
    "<|finetune_right_pad|>",
] + get_reserved_special_tokens(
    "text_post_train", 61, 6
)  # <|text_post_train_reserved_special_token_6|>, ..., <|text_post_train_reserved_special_token_66|>

# 200080, ..., 201133
LLAMA4_VISION_SPECIAL_TOKENS = [
    "<|image_start|>",
    "<|image_end|>",
    "<|vision_reserved_special_token_0|>",
    "<|vision_reserved_special_token_1|>",
    "<|tile_x_separator|>",
    "<|tile_y_separator|>",
    "<|vision_reserved_special_token_2|>",
    "<|vision_reserved_special_token_3|>",
    "<|vision_reserved_special_token_4|>",
    "<|vision_reserved_special_token_5|>",
    "<|image|>",
    "<|vision_reserved_special_token_6|>",
    "<|patch|>",
] + get_reserved_special_tokens(
    "vision", 1041, 7
)  # <|vision_reserved_special_token_7|>, ..., <|vision_reserved_special_token_1047|>


LLAMA4_SPECIAL_TOKENS = LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS + LLAMA4_VISION_SPECIAL_TOKENS

BASIC_SPECIAL_TOKENS = [
    "<|begin_of_text|>",
    "<|end_of_text|>",
    "<|fim_prefix|>",
    "<|fim_middle|>",
    "<|fim_suffix|>",
]

class Tokenizer:
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
    """

    special_tokens: Dict[str, int]

    num_reserved_special_tokens = 2048

    O200K_PATTERN = r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+"""  # noqa: E501

    @classmethod
    def get_instance(cls):
        global _INSTANCE

        if _INSTANCE is None:
            _INSTANCE = Tokenizer(os.path.join(os.path.dirname(__file__), "tokenizer.model"))
        return _INSTANCE

    def __init__(self, model_path: str):
        """
        Initializes the Tokenizer with a Tiktoken model.

        Args:
            model_path (str): The path to the Tiktoken model file.
        """
        assert os.path.isfile(model_path), model_path

        mergeable_ranks = load_tiktoken_bpe(model_path)
        num_base_tokens = len(mergeable_ranks)

        special_tokens = BASIC_SPECIAL_TOKENS + LLAMA4_SPECIAL_TOKENS
        assert len(set(special_tokens)) == len(special_tokens)
        assert len(special_tokens) <= self.num_reserved_special_tokens

        reserved_tokens = [
            f"<|reserved_special_token_{i}|>" for i in range(self.num_reserved_special_tokens - len(special_tokens))
        ]
        special_tokens = special_tokens + reserved_tokens

        self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)}
        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.O200K_PATTERN,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )

        self.n_words: int = num_base_tokens + len(special_tokens)

        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self.eos_id: int = self.special_tokens["<|end_of_text|>"]

        self.pad_id: int = self.special_tokens["<|finetune_right_pad|>"]
        self.eot_id: int = self.special_tokens["<|eot|>"]
        self.eom_id: int = self.special_tokens["<|eom|>"]

        self.stop_tokens = [
            self.eos_id,
            self.special_tokens["<|eom|>"],
            self.special_tokens["<|eot|>"],
        ]

    def encode(
        self,
        s: str,
        *,
        bos: bool,
        eos: bool,
        allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
        disallowed_special: Union[Literal["all"], Collection[str]] = (),
    ) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.
            allowed_special ("all"|set[str]): allowed special tokens in string
            disallowed_special ("all"|set[str]): special tokens that raise an error when in string

        Returns:
            list[int]: A list of token IDs.

        By default, setting disallowed_special=() encodes a string by ignoring
        special tokens. Specifically:
        - Setting `disallowed_special` to () will cause all text corresponding
          to special tokens to be encoded as natural text (instead of raising
          an error).
        - Setting `allowed_special` to "all" will cause all text corresponding
          to special tokens to be encoded as special tokens.
        """
        if allowed_special is None:
            allowed_special = set()
        assert type(s) is str

        substrs = (
            substr
            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
            for substr in self._split_whitespaces_or_nonwhitespaces(
                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
            )
        )
        t: List[int] = []
        for substr in substrs:
            t.extend(
                self.model.encode(
                    substr,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )
        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t

    def decode(self, t: Sequence[int]) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            t (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
        return self.model.decode(cast(List[int], t))

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
        consecutive whitespaces or consecutive non-whitespaces.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            if current_slice_is_space ^ is_now_space:
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]
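
For reviewers who want to exercise the new tokenizer, below is a minimal usage sketch (not part of the diff). It assumes llama_stack is importable and that the bundled tokenizer.model sits next to tokenizer.py, which is what get_instance() expects; the example strings are illustrative only.

from llama_stack.models.llama.llama4.tokenizer import Tokenizer

# Resolve the singleton; loads tokenizer.model from this file's directory.
tok = Tokenizer.get_instance()

# Round trip: BOS is prepended, and decode() renders special tokens as literal text.
ids = tok.encode("Hello, Llama 4!", bos=True, eos=False)
assert ids[0] == tok.bos_id
assert tok.decode(ids).endswith("Hello, Llama 4!")

# Special-token text is encoded as plain text unless explicitly allowed.
assert tok.encode("<|eot|>", bos=False, eos=False, allowed_special="all") == [tok.eot_id]

# Long inputs are chunked internally (TIKTOKEN_MAX_ENCODE_CHARS,
# MAX_NO_WHITESPACES_CHARS), so a 1M-char string encodes without a tiktoken panic.
assert len(tok.encode("a" * 1_000_000, bos=False, eos=False)) > 0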