## What does this PR do?
In this PR, we refactor the meta reference inference logic to support:
- loading the model during model registration instead of at server startup
- running inference on a fine-tuned model checkpoint built on top of a native
llama model
## Why we need these changes
These changes address existing pain points:
- users cannot lazy-load the model or hot-swap the inference checkpoint after
the server is up
  - this blocks running inference and eval on the same server for a
fine-tuned checkpoint after post training
- users cannot run inference on a fine-tuned checkpoint built on top of
native llama models
## Expected user experience changes
- The inference model is no longer loaded when the server spins up. Instead,
it is loaded during model registration. If the user adds the model to the
`models` resource list in run.yaml, it is registered and loaded automatically
when the server starts. An optional `skip_initialize` flag in the model
metadata skips model loading during registration.
- An optional `llama_model` flag in the model metadata identifies the base
llama model (a `Model` descriptor) used to validate the registration and to
initialize the model architecture. The model identifier no longer needs to be
a native llama model (see the sketch after this list).
- The default inference model name changes from
'meta-llama/Llama-3.2-3B-Instruct' to 'Llama3.2-3B-Instruct'
  - it aligns with the checkpoint folder name produced by `llama model download`
  - it aligns with the descriptor name defined in the llama-models SKU list
bf5b0c4fe7/models/datatypes.py (L95)
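
For illustration, the new metadata flags can be exercised from the Python client. This is only a minimal sketch: the `llama_stack_client` SDK usage, the argument names, and the identifiers below are assumptions for illustration, not code from this PR.

```python
# Minimal sketch (assumed SDK surface and argument names) of registering a
# fine-tuned checkpoint whose identifier is not a native llama model name.
from llama_stack_client import LlamaStackClient

# Hypothetical endpoint; substitute the address printed by `llama stack run`.
client = LlamaStackClient(base_url="http://localhost:5000")

client.models.register(
    model_id="my-finetuned-3b",              # hypothetical identifier; need not be a native llama model
    provider_id="meta-reference-inference",  # hypothetical provider id from run.yaml
    metadata={
        "llama_model": "Llama3.2-3B-Instruct",  # base model used for validation and model-arch init
        "skip_initialize": True,                # optional: register without loading the checkpoint
    },
)
```

The same entry can instead be added to the `models` resource list in run.yaml so registration (and loading, unless `skip_initialize` is set) happens automatically at server start.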
## Test
Run `python llama_stack/scripts/distro_codegen.py`
**Run unit tests**
- `torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="Llama3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_text_inference.py`
- `torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="Llama3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_model_registration.py`
**Test the post-training experience**
On the server side, run: `llama stack run llama_stack/templates/experimental-post-training/run.yaml`
The server spins up without loading a model.
<img width="812" alt="Screenshot 2024-12-17 at 1 24 50 PM"
src="https://github.com/user-attachments/assets/ce1f606b-3b6f-452f-b48e-b3761ffd90f3"
/>
On the client side, run: `llama-stack-client --endpoint http://devgpu018.nha2.facebook.com:5000 models register Llama3.2-3B-Instruct`
The model registers successfully and is loaded.
<img width="1111" alt="Screenshot 2024-12-17 at 1 26 30 PM"
src="https://github.com/user-attachments/assets/56e02131-cf7d-4de5-8f63-fbdcb8c55c26"
/>
<img width="1541" alt="Screenshot 2024-12-17 at 1 26 09 PM"
src="https://github.com/user-attachments/assets/a83255a1-20f5-40a2-af51-55641410a115"
/>
If `skip_initialize` is set in the metadata, the model is registered but not
loaded.
On the client side, run: `llama-stack-client --endpoint http://devgpu018.nha2.facebook.com:5000 inference chat-completion --message "hello, what model are you?"`
Inference on the model succeeds.
<img width="1121" alt="Screenshot 2024-12-17 at 1 27 33 PM"
src="https://github.com/user-attachments/assets/8e708545-3fe7-4a73-8754-1470fa5f1e75"
/>
**Test the inference experience**
Run: `llama stack run llama_stack/templates/meta-reference-gpu/run.yaml`
The model is loaded since it appears in the `models` resource list in run.yaml.
<img width="1537" alt="Screenshot 2024-12-17 at 1 30 19 PM"
src="https://github.com/user-attachments/assets/5c8af817-66eb-43f8-bf4c-f5e24b0a12c6"
/>
On the client side, run: `llama-stack-client --endpoint http://devgpu018.nha2.facebook.com:5000 inference chat-completion --message "hello, what model are you?"`
Inference succeeds; a Python SDK equivalent is sketched after the screenshot.
<img width="1123" alt="Screenshot 2024-12-17 at 1 31 08 PM"
src="https://github.com/user-attachments/assets/471809aa-c65e-46dc-a37e-7094fb857f97"
/>
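
The same smoke test can be driven from the Python SDK rather than the CLI. The sketch below assumes the `llama_stack_client` package and the argument and response attribute names shown, which are not taken verbatim from this PR.

```python
# Minimal sketch (assumed SDK surface) of the chat-completion check above.
from llama_stack_client import LlamaStackClient

# Hypothetical endpoint; substitute the server address used above.
client = LlamaStackClient(base_url="http://localhost:5000")

response = client.inference.chat_completion(
    model_id="Llama3.2-3B-Instruct",  # must match a registered model identifier
    messages=[{"role": "user", "content": "hello, what model are you?"}],
)
print(response.completion_message.content)  # assumed response shape
```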
## Inference on a fine-tuned model
**Register a fine-tuned model produced by the post-training API (torchtune)**
- the model is registered and loaded successfully
- the model shows up in the model list
<img width="974" alt="Screenshot 2024-12-18 at 3 56 33 PM"
src="https://github.com/user-attachments/assets/2994b4f5-4fa9-40c6-acc6-4b971479f3e2"
/>
**Run inference**
<img width="977" alt="Screenshot 2024-12-18 at 3 57 59 PM"
src="https://github.com/user-attachments/assets/d117abbc-b2a0-41d8-a028-1a13128787b2"
/>
For reference, the meta reference inference generation module from this PR (494 lines, Python):
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

import json
import logging
import math
import os
import sys
import time
from pathlib import Path
from typing import Generator, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
    get_model_parallel_rank,
    initialize_model_parallel,
    model_parallel_is_initialized,
)
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
from llama_models.llama3.api.datatypes import Model
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer
from llama_models.llama3.reference_impl.multimodal.model import (
    CrossAttentionTransformer,
)
from llama_models.sku_list import resolve_model
from pydantic import BaseModel

from llama_stack.apis.inference import *  # noqa: F403

from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData

from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
    ChatCompletionRequestWithRawContent,
    CompletionRequestWithRawContent,
)

from .config import (
    Fp8QuantizationConfig,
    Int4QuantizationConfig,
    MetaReferenceInferenceConfig,
    MetaReferenceQuantizedInferenceConfig,
)

log = logging.getLogger(__name__)


def model_checkpoint_dir(model_id) -> str:
    checkpoint_dir = Path(model_local_dir(model_id))

    paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
    if not any(p.exists() for p in paths):
        checkpoint_dir = checkpoint_dir / "original"

    assert checkpoint_dir.exists(), (
        f"Could not find checkpoints in: {model_local_dir(model_id)}. "
        f"If you are trying to use a native llama model, please download it using `llama download --model-id {model_id}`. "
        f"Otherwise, please save your model checkpoint under {model_local_dir(model_id)}"
    )
    return str(checkpoint_dir)


class TokenResult(BaseModel):
    token: int
    text: str
    logprobs: Optional[List[float]] = None


class Llama:
    @staticmethod
    def build(
        config: Union[
            MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
        ],
        model_id: str,
        llama_model: Model,
    ):
        """
        Build a Llama instance by initializing and loading a model checkpoint.

        Note:
            This method initializes the distributed process group, sets the device to CUDA,
            and loads the pre-trained model and tokenizer.
        """
        llama_model_id = llama_model.core_model_id.value
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group("nccl")

        model_parallel_size = llama_model.pth_file_count

        if not model_parallel_is_initialized():
            initialize_model_parallel(model_parallel_size)

        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        torch.cuda.set_device(local_rank)

        # seed must be the same in all processes
        if config.torch_seed is not None:
            torch.manual_seed(config.torch_seed)

        if local_rank > 0:
            sys.stdout = open(os.devnull, "w")

        start_time = time.time()
        if config.checkpoint_dir and config.checkpoint_dir != "null":
            ckpt_dir = config.checkpoint_dir
        else:
            resolved_model = resolve_model(model_id)
            if resolved_model is None:
                # if the model is not a native llama model, get the default checkpoint_dir based on model id
                ckpt_dir = model_checkpoint_dir(model_id)
            else:
                # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
                ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())

        checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
        assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
        assert model_parallel_size == len(
            checkpoints
        ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
        ckpt_path = checkpoints[get_model_parallel_rank()]
        state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
        with open(Path(ckpt_dir) / "params.json", "r") as f:
            params = json.loads(f.read())

        if "model" in params:
            params = params["model"]

        model_args: ModelArgs = ModelArgs(
            max_seq_len=config.max_seq_len,
            max_batch_size=config.max_batch_size,
            **params,
        )

        tokenizer = Tokenizer.get_instance()
        assert (
            model_args.vocab_size == tokenizer.n_words
        ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"

        if isinstance(config, MetaReferenceQuantizedInferenceConfig):
            if isinstance(config.quantization, Fp8QuantizationConfig):
                from .quantization.loader import convert_to_fp8_quantized_model

                # load on CPU in bf16 so that fp8 conversion does not find an
                # unexpected (fp32, e.g.) datatype
                torch.set_default_tensor_type(torch.BFloat16Tensor)
                if model_args.vision_chunk_size > 0:
                    model = CrossAttentionTransformer(model_args)
                    model.setup_cache(model_args.max_batch_size, torch.bfloat16)
                else:
                    model = Transformer(model_args)
                model.load_state_dict(state_dict, strict=False)
                model = convert_to_fp8_quantized_model(model, config, ckpt_dir)
            elif isinstance(config.quantization, Int4QuantizationConfig):
                from .quantization.loader import convert_to_int4_quantized_model

                model = Transformer(model_args)
                model = convert_to_int4_quantized_model(model, model_args, config)
                model.load_state_dict(state_dict, strict=True)

                if (
                    model_args.quantization_args is not None
                    and model_args.quantization_args.spinquant
                ):
                    # Add a wrapper for adding hadamard transform for spinquant.
                    # This needs to be done after loading the state dict otherwise an error will be raised while
                    # loading the state dict.
                    from .quantization.hadamard_utils import (
                        add_hadamard_transform_for_spinquant,
                    )

                    add_hadamard_transform_for_spinquant(model)
            else:
                raise NotImplementedError(
                    "Currently int4 and fp8 are the only supported quantization methods."
                )
        else:
            if torch.cuda.is_bf16_supported():
                torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
            else:
                torch.set_default_tensor_type(torch.cuda.HalfTensor)
            if model_args.vision_chunk_size > 0:
                model = CrossAttentionTransformer(model_args)
                model.setup_cache(model_args.max_batch_size, torch.bfloat16)
            else:
                model = Transformer(model_args)
            model.load_state_dict(state_dict, strict=False)

        log.info(f"Loaded in {time.time() - start_time:.2f} seconds")
        return Llama(model, tokenizer, model_args, llama_model_id)

    def __init__(
        self,
        model: Transformer,
        tokenizer: Tokenizer,
        args: ModelArgs,
        llama_model: str,
    ):
        self.args = args
        self.model = model
        self.tokenizer = tokenizer
        self.formatter = ChatFormat(tokenizer)
        self.llama_model = llama_model

    @torch.inference_mode()
    def generate(
        self,
        model_input: LLMInput,
        max_gen_len: int,
        temperature: float = 0.6,
        top_p: float = 0.9,
        logprobs: bool = False,
        echo: bool = False,
        include_stop_token: bool = False,
        print_input_tokens: bool = False,
        logits_processor: Optional["LogitsProcessor"] = None,
    ) -> Generator:
        params = self.model.params

        if print_input_tokens:
            input_tokens = [
                self.formatter.vision_token if t == 128256 else t
                for t in model_input.tokens
            ]
            log.info("Input to model -> " + self.tokenizer.decode(input_tokens))
        prompt_tokens = [model_input.tokens]

        bsz = 1
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        min_prompt_len = min(len(t) for t in prompt_tokens)
        max_prompt_len = max(len(t) for t in prompt_tokens)

        if max_prompt_len >= params.max_seq_len:
            log.error(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}")
            return

        total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)

        is_vision = isinstance(self.model, CrossAttentionTransformer)
        if is_vision:
            images = model_input.vision.images if model_input.vision is not None else []
            mask = model_input.vision.mask if model_input.vision is not None else []

            # the method works for bsz > 1 so add a batch dimension
            xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = (
                self.model.compute_vision_tokens_masks(
                    batch_images=[images],
                    batch_masks=[mask],
                    total_len=total_len,
                )
            )

        pad_id = self.tokenizer.pad_id
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
        if logprobs:
            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

        prev_pos = 0
        eos_reached = torch.tensor([False] * bsz, device="cuda")
        input_text_mask = tokens != pad_id
        if min_prompt_len == total_len:
            # TODO(ashwin): unify this branch with the one below and figure out multimodal crap
            logits = self.model.forward(tokens, prev_pos)
            token_logprobs = -F.cross_entropy(
                input=logits.transpose(1, 2),
                target=tokens,
                reduction="none",
                ignore_index=pad_id,
            )

        stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")
        for cur_pos in range(min_prompt_len, total_len):
            if is_vision:
                position_ids = torch.arange(
                    prev_pos, cur_pos, dtype=torch.long, device="cuda"
                )
                logits = self.model.forward(
                    position_ids,
                    tokens,
                    cross_attention_masks,
                    full_text_row_masked_out_mask,
                    xattn_caches,
                )
            else:
                logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)

            if logits_processor is not None:
                logits = logits_processor.process_logits(tokens[:, :cur_pos], logits)

            if temperature > 0:
                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits[:, -1], dim=-1)

            next_token = next_token.reshape(-1)
            # only replace token if prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token

            target = tokens[:, prev_pos + 1 : cur_pos + 1]
            if is_vision:
                # the logits space (num_classes) is designed to never contain a media_token
                # however our input token stream does contain them. we need to nuke them here
                # or else the CUDA kernels will crash with an illegal memory access
                vision_tokens = [self.tokenizer.special_tokens["<|image|>"], 128256]
                masks = [target.eq(t) for t in vision_tokens]
                if len(masks) > 1:
                    mask = torch.logical_or(*masks)
                else:
                    mask = masks[0]
                target[mask] = 0

            if logprobs:
                token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                    input=logits.transpose(1, 2),
                    target=tokens[:, prev_pos + 1 : cur_pos + 1],
                    reduction="none",
                    ignore_index=pad_id,
                )
            eos_reached |= (~input_text_mask[:, cur_pos]) & (
                torch.isin(next_token, stop_tokens)
            )
            yield TokenResult(
                token=next_token[0].item(),
                text=self.tokenizer.decode(next_token.tolist()),
                logprobs=(
                    token_logprobs[:, cur_pos : cur_pos + 1][0].tolist()
                    if logprobs
                    else None
                ),
            )

            prev_pos = cur_pos
            if all(eos_reached):
                break

    def completion(
        self,
        request: CompletionRequestWithRawContent,
    ) -> Generator:
        sampling_params = request.sampling_params
        max_gen_len = sampling_params.max_tokens
        if (
            max_gen_len is None
            or max_gen_len == 0
            or max_gen_len >= self.model.params.max_seq_len
        ):
            max_gen_len = self.model.params.max_seq_len - 1

        model_input = self.formatter.encode_content(request.content)
        yield from self.generate(
            model_input=model_input,
            max_gen_len=max_gen_len,
            temperature=sampling_params.temperature,
            top_p=sampling_params.top_p,
            logprobs=bool(request.logprobs),
            include_stop_token=True,
            logits_processor=get_logits_processor(
                self.tokenizer,
                self.args.vocab_size,
                request.response_format,
            ),
        )

    def chat_completion(
        self,
        request: ChatCompletionRequestWithRawContent,
    ) -> Generator:
        sampling_params = request.sampling_params
        max_gen_len = sampling_params.max_tokens
        if (
            max_gen_len is None
            or max_gen_len == 0
            or max_gen_len >= self.model.params.max_seq_len
        ):
            max_gen_len = self.model.params.max_seq_len - 1

        yield from self.generate(
            model_input=self.formatter.encode_dialog_prompt(
                request.messages,
                request.tool_prompt_format,
            ),
            max_gen_len=max_gen_len,
            temperature=sampling_params.temperature,
            top_p=sampling_params.top_p,
            logprobs=bool(request.logprobs),
            include_stop_token=True,
            logits_processor=get_logits_processor(
                self.tokenizer,
                self.args.vocab_size,
                request.response_format,
            ),
        )


def sample_top_p(probs, p):
    """
    Perform top-p (nucleus) sampling on a probability distribution.

    Args:
        probs (torch.Tensor): Probability distribution tensor.
        p (float): Probability threshold for top-p sampling.

    Returns:
        torch.Tensor: Sampled token indices.

    Note:
        Top-p sampling selects the smallest set of tokens whose cumulative probability mass
        exceeds the threshold p. The distribution is renormalized based on the selected tokens.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token


class LogitsProcessor:
    def __init__(self, token_enforcer: TokenEnforcer):
        self.token_enforcer = token_enforcer
        self.mask: Optional[torch.Tensor] = None

    def process_logits(
        self, tokens: torch.Tensor, scores: torch.Tensor
    ) -> torch.Tensor:
        token_sequence = tokens[0, :].tolist()
        allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)

        if self.mask is not None:
            self.mask.fill_(-math.inf)
        else:
            self.mask = torch.full_like(scores, -math.inf)

        self.mask[:, :, allowed_tokens] = 0
        scores = scores + self.mask
        return scores


def get_logits_processor(
    tokenizer: Tokenizer,
    vocab_size: int,
    response_format: Optional[ResponseFormat],
) -> Optional["LogitsProcessor"]:
    if response_format is None:
        return None

    if response_format.type != ResponseFormatType.json_schema.value:
        raise ValueError(f"Unsupported response format type {response_format.type}")

    parser = JsonSchemaParser(response_format.json_schema)
    data = TokenEnforcerTokenizerData(
        _build_regular_tokens_list(tokenizer, vocab_size),
        tokenizer.decode,
        tokenizer.stop_tokens,
    )
    token_enforcer = TokenEnforcer(data, parser)
    return LogitsProcessor(token_enforcer)


def _build_regular_tokens_list(
    tokenizer: Tokenizer, vocab_size: int
) -> List[Tuple[int, str, bool]]:
    token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
    regular_tokens = []

    special_token_ids = set(tokenizer.special_tokens.values())
    for token_idx in range(vocab_size):
        if token_idx in special_token_ids:
            continue

        # We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
        decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:]
        decoded_regular = tokenizer.decode([token_idx])
        is_word_start_token = len(decoded_after_0) > len(decoded_regular)
        regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
    return regular_tokens
```