Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 09:53:45 +00:00)
# What does this PR do?

Delete ~2,000 lines of dead code from the old bespoke inference API that was replaced by the OpenAI-only API. This includes removing unused type-conversion functions, dead provider methods, and event_logger.py, and cleaning up imports across the codebase to remove references to the deleted types. This eliminates unnecessary code and dependencies and helps isolate the API package as a self-contained module: this was the last interdependency between the .api package and "exterior" packages, so now every other package in llama stack imports the API, not the other way around.

## Test Plan

This is a structural change; no tests are needed.

---------

Signed-off-by: Charlie Doern <cdoern@redhat.com>
378 lines · 14 KiB · Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

import json
import os
import sys
import time
from collections.abc import Callable, Generator
from pathlib import Path

import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
    initialize_model_parallel,
    model_parallel_is_initialized,
)
from termcolor import cprint

from llama_stack.models.llama.datatypes import ToolPromptFormat

from ..checkpoint import maybe_reshard_state_dict
from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage
from .args import ModelArgs
from .chat_format import ChatFormat, LLMInput
from .model import Transformer
from .multimodal.model import CrossAttentionTransformer
from .tokenizer import Tokenizer


class Llama3:
    @staticmethod
    def build(
        ckpt_dir: str,
        max_seq_len: int,
        max_batch_size: int,
        world_size: int | None = None,
        quantization_mode: QuantizationMode | None = None,
        seed: int = 1,
        device: str = "cuda",
    ):
        device = torch.device(device)
        if (
            device.type == "cuda"
            and not torch.cuda.is_available()
            or device.type == "xpu"
            and not torch.xpu.is_available()
        ):
            raise RuntimeError(f"PyTorch backend for {device.type} device type is not available")

        if not torch.distributed.is_initialized():
            if device.type == "cuda":
                torch.distributed.init_process_group("nccl")
            else:
                torch.distributed.init_process_group("gloo")

        if not model_parallel_is_initialized():
            if world_size is None:
                world_size = int(os.environ.get("WORLD_SIZE", 1))
            initialize_model_parallel(world_size)

        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        if device.type == "cuda":
            torch.cuda.set_device(local_rank)
        elif device.type == "xpu":
            torch.xpu.set_device(local_rank)

        torch.manual_seed(seed)

        if local_rank > 0:
            sys.stdout = open(os.devnull, "w")

        start_time = time.time()

        ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
        assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
        print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
        with open(Path(ckpt_dir) / "params.json") as f:
            params = json.loads(f.read())

        model_args: ModelArgs = ModelArgs(
            max_seq_len=max_seq_len,
            max_batch_size=max_batch_size,
            **params,
        )
        tokenizer = Tokenizer.get_instance()

        state_dict = maybe_reshard_state_dict(
            ckpt_paths,
            n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
        )

        assert model_args.vocab_size == tokenizer.n_words

        def build_model():
            if model_args.vision_chunk_size > 0:
                model = CrossAttentionTransformer(model_args)
                model.setup_cache(model_args.max_batch_size, device=device, dtype=torch.get_default_dtype())
            else:
                model = Transformer(model_args)
            return model

        if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
            from .quantization.loader import convert_to_quantized_model

            torch.set_default_tensor_type(torch.BFloat16Tensor)
            model = build_model()
            print("Loading state dict...")
            model.load_state_dict(state_dict, strict=False)
            print("Done...")
            model = convert_to_quantized_model(model, ckpt_dir, quantization_mode, device=device)
            torch.set_default_device(device)
        else:
            print(f"Setting default device to {device}")
            if device.type == "cuda":
                if torch.cuda.is_bf16_supported():
                    torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
                else:
                    torch.set_default_tensor_type(torch.cuda.HalfTensor)
            elif device.type == "xpu":
                if torch.xpu.is_bf16_supported():
                    torch.set_default_tensor_type(torch.xpu.BFloat16Tensor)
                else:
                    torch.set_default_tensor_type(torch.xpu.HalfTensor)

            model = build_model()
            print("Loading state dict...")
            model.load_state_dict(state_dict, strict=True)
            model.to(device)
            print("Done...")

        print(f"Loaded in {time.time() - start_time:.2f} seconds")

        return Llama3(model, tokenizer, model_args)

    def __init__(
        self,
        model: Transformer | CrossAttentionTransformer,
        tokenizer: Tokenizer,
        args: ModelArgs,
    ):
        self.args = args
        self.model = model
        self.tokenizer = tokenizer
        self.formatter = ChatFormat(tokenizer)

    @torch.inference_mode()
    def generate(
        self,
        llm_inputs: list[LLMInput],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: int | None = None,
        logprobs: bool = False,
        echo: bool = False,
        print_model_input: bool = False,
        logits_processor: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
    ) -> Generator[list[GenerationResult], None, None]:
        if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
            max_gen_len = self.args.max_seq_len - 1
        params = self.model.params

        print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
        if print_model_input:
            for inp in llm_inputs:
                tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
                cprint(
                    "Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
                    "red",
                    file=sys.stderr,
                )
        prompt_tokens = [inp.tokens for inp in llm_inputs]

        bsz = len(llm_inputs)
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        min_prompt_len = min(len(t) for t in prompt_tokens)
        max_prompt_len = max(len(t) for t in prompt_tokens)

        if max_prompt_len >= params.max_seq_len:
            cprint(
                f"Out of token budget {max_prompt_len} vs {params.max_seq_len}",
                color="red",
                file=sys.stderr,
            )
            return

        total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)

        pad_id = self.tokenizer.pad_id
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)
        if logprobs:
            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

        is_vision = not isinstance(self.model, Transformer)
        if is_vision:
            images = [inp.vision.images if inp.vision is not None else [] for inp in llm_inputs]
            mask = [inp.vision.mask if inp.vision is not None else [] for inp in llm_inputs]

            xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
                batch_images=images,
                batch_masks=mask,
                total_len=total_len,
                device=tokens.device,
            )

        eos_reached = torch.tensor([False] * bsz)
        input_text_mask = tokens != pad_id

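        # When echo=True, the prompt tokens are streamed back first, tagged
        # source="input", before any newly generated tokens are yielded below.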
        if echo:
            for i in range(max_prompt_len):
                results = []
                for j, t in enumerate(tokens[:, i]):
                    results.append(
                        GenerationResult(
                            token=t.item(),
                            text=self.tokenizer.decode([t.item()]),
                            source="input",
                            logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
                            batch_idx=j,
                            finished=False,
                            ignore_token=t.item() == pad_id,
                        )
                    )
                yield results

        stop_tokens = torch.tensor(self.tokenizer.stop_tokens)

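        # Incremental decoding: the first iteration feeds the whole prompt prefix
        # tokens[:, 0:min_prompt_len]; after that prev_pos == cur_pos - 1, so each
        # step passes only the single newly filled position to the model.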
        prev_pos = 0
        for cur_pos in range(min_prompt_len, total_len):
            if is_vision:
                position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
                text_only_inference = all(inp.vision is None for inp in llm_inputs)
                logits = self.model.forward(
                    position_ids,
                    tokens,
                    cross_attention_masks,
                    full_text_row_masked_out_mask,
                    xattn_caches,
                    text_only_inference,
                )
            else:
                logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)

            if logits_processor is not None:
                logits = logits_processor(tokens[:, :cur_pos], logits)

            if temperature > 0:
                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits[:, -1], dim=-1)

            next_token = next_token.reshape(-1)
            # only replace token if prompt has already been generated
            next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
            tokens[:, cur_pos] = next_token

            target = tokens[:, prev_pos + 1 : cur_pos + 1]
            if is_vision:
                # the logits space (num_classes) is designed to never contain a media_token
                # however our input token stream does contain them. we need to nuke them here
                # or else the CUDA kernels will crash with an illegal memory access
                vision_tokens = [self.tokenizer.special_tokens["<|image|>"], 128256]
                masks = [target.eq(t) for t in vision_tokens]
                if len(masks) > 1:
                    mask = torch.logical_or(*masks)
                else:
                    mask = masks[0]
                target[mask] = 0

            if logprobs:
                token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                    input=logits.transpose(1, 2),
                    target=target,
                    reduction="none",
                    ignore_index=pad_id,
                )
            eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
            results = []
            for idx, t in enumerate(next_token):
                results.append(
                    GenerationResult(
                        token=t.item(),
                        text=self.tokenizer.decode([t.item()]),
                        source="output",
                        logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
                        batch_idx=idx,
                        finished=eos_reached[idx].item(),
                        ignore_token=cur_pos < len(prompt_tokens[idx]),
                    )
                )
            yield results

            prev_pos = cur_pos
            if all(eos_reached):
                break

    def completion(
        self,
        contents: list[RawContent],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: int | None = None,
        logprobs: bool = False,
        echo: bool = False,
    ) -> Generator[list[GenerationResult], None, None]:
        model_inputs = [self.formatter.encode_content(c) for c in contents]
        for result in self.generate(
            llm_inputs=model_inputs,
            temperature=temperature,
            top_p=top_p,
            max_gen_len=max_gen_len,
            logprobs=logprobs,
            echo=echo,
        ):
            yield result
            if all(r.finished for r in result):
                break

    def chat_completion(
        self,
        messages_batch: list[list[RawMessage]],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: int | None = None,
        logprobs: bool = False,
        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
        echo: bool = False,
    ) -> Generator[list[GenerationResult], None, None]:
        model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
        for result in self.generate(
            llm_inputs=model_inputs,
            temperature=temperature,
            top_p=top_p,
            max_gen_len=max_gen_len,
            logprobs=logprobs,
            echo=echo,
        ):
            yield result
            if all(r.finished for r in result):
                break


def sample_top_p(probs, p):
    """
    Perform top-p (nucleus) sampling on a probability distribution.

    Args:
        probs (torch.Tensor): Probability distribution tensor.
        p (float): Probability threshold for top-p sampling.

    Returns:
        torch.Tensor: Sampled token indices.

    Note:
        Top-p sampling selects the smallest set of tokens whose cumulative probability mass
        exceeds the threshold p. The distribution is renormalized based on the selected tokens.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
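For context, a minimal driver sketch showing how this class is typically exercised. It is not part of the file above: the import path for `Llama3`, the checkpoint directory, and the `RawMessage(role=..., content=...)` construction are assumptions, and since `Llama3.build` initializes `torch.distributed`, such a script would normally be launched with `torchrun`.

```python
# Hypothetical usage sketch -- paths and message construction are illustrative.
from llama_stack.models.llama.datatypes import RawMessage
from llama_stack.models.llama.llama3.generation import Llama3  # assumed module path

llama = Llama3.build(
    ckpt_dir="/path/to/checkpoints/Llama3.2-3B-Instruct",  # hypothetical dir containing *.pth shards and params.json
    max_seq_len=2048,
    max_batch_size=1,
)

# chat_completion yields one list[GenerationResult] per decoding step,
# with one entry per batch element.
for step_results in llama.chat_completion(
    messages_batch=[[RawMessage(role="user", content="Write a haiku about token streaming.")]],
    temperature=0.6,
    top_p=0.9,
):
    for r in step_results:
        if not r.ignore_token:
            print(r.text, end="", flush=True)
print()
```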