diff --git a/llama_stack/providers/inline/inference/meta_reference/hadamard_utils.py b/llama_stack/models/llama/hadamard_utils.py similarity index 100% rename from llama_stack/providers/inline/inference/meta_reference/hadamard_utils.py rename to llama_stack/models/llama/hadamard_utils.py diff --git a/llama_stack/providers/inline/inference/meta_reference/llama3/args.py b/llama_stack/models/llama/llama3/args.py similarity index 100% rename from llama_stack/providers/inline/inference/meta_reference/llama3/args.py rename to llama_stack/models/llama/llama3/args.py diff --git a/llama_stack/models/llama/llama3/chat_format.py b/llama_stack/models/llama/llama3/chat_format.py index 2862f8558..8ae911fc3 100644 --- a/llama_stack/models/llama/llama3/chat_format.py +++ b/llama_stack/models/llama/llama3/chat_format.py @@ -19,7 +19,7 @@ from typing import Dict, List, Optional, Tuple from PIL import Image as PIL_Image -from llama_stack.models.llama.datatypes import ( +from ..datatypes import ( BuiltinTool, RawContent, RawMediaItem, @@ -30,7 +30,6 @@ from llama_stack.models.llama.datatypes import ( ToolCall, ToolPromptFormat, ) - from .tokenizer import Tokenizer from .tool_utils import ToolUtils diff --git a/llama_stack/models/llama/llama3/generation.py b/llama_stack/models/llama/llama3/generation.py new file mode 100644 index 000000000..b4e0d39b9 --- /dev/null +++ b/llama_stack/models/llama/llama3/generation.py @@ -0,0 +1,447 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. 
+ +import json +import os +import sys +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Callable, Generator, List, Optional + +import torch +import torch.nn.functional as F +from fairscale.nn.model_parallel.initialize import ( + get_model_parallel_rank, + initialize_model_parallel, + model_parallel_is_initialized, +) +from termcolor import cprint + +from ..datatypes import RawContent, RawMessage, StopReason, ToolPromptFormat +from .args import ModelArgs +from .chat_format import ChatFormat, LLMInput +from .model import Transformer +from .tokenizer import Tokenizer + + +@dataclass +class CompletionPrediction: + generation: str + decoded_tokens: Optional[List[str]] = None + logprobs: Optional[List[List[float]]] = None + + +@dataclass +class ChatPrediction: + generation: RawMessage + decoded_tokens: Optional[List[str]] = None + logprobs: Optional[List[List[float]]] = None + + +@dataclass +class TokenResult: + token: int + text: str + logprobs: Optional[List[float]] = None + + +# TODO: make this completely parallel to the llama4 generation.py file and share common code +# from llama-models also +class Llama3: + @staticmethod + def build( + ckpt_dir: str, + max_seq_len: int, + max_batch_size: int, + world_size: Optional[int] = None, + tokenizer_path: Optional[str] = None, + seed: int = 1, + device: str = "cuda", + ): + device = torch.device(device) + if ( + device.type == "cuda" + and not torch.cuda.is_available() + or device.type == "xpu" + and not torch.xpu.is_available() + ): + raise RuntimeError(f"PyTorch backend for {device.type} device type is not available") + + if not torch.distributed.is_initialized(): + if device.type == "cuda": + torch.distributed.init_process_group("nccl") + else: + torch.distributed.init_process_group("gloo") + + if not model_parallel_is_initialized(): + if world_size is None: + world_size = int(os.environ.get("WORLD_SIZE", 1)) + initialize_model_parallel(world_size) + + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + if device.type == "cuda": + torch.cuda.set_device(local_rank) + elif device.type == "xpu": + torch.xpu.set_device(local_rank) + + torch.manual_seed(seed) + + if local_rank > 0: + sys.stdout = open(os.devnull, "w") + + start_time = time.time() + + checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) + assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" + assert world_size == len(checkpoints), ( + f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}" + ) + ckpt_path = checkpoints[get_model_parallel_rank()] + checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True) + with open(Path(ckpt_dir) / "params.json", "r") as f: + params = json.loads(f.read()) + + model_args: ModelArgs = ModelArgs( + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + **params, + ) + if tokenizer_path: + tokenizer = Tokenizer(model_path=tokenizer_path) + else: + tokenizer = Tokenizer.get_instance() + + assert model_args.vocab_size == tokenizer.n_words + torch.set_default_device(device) + if device.type == "cuda": + if torch.cuda.is_bf16_supported(): + torch.set_default_dtype(torch.bfloat16) + else: + torch.set_default_dtype(torch.half) + elif device.type == "xpu": + if torch.xpu.is_bf16_supported(): + torch.set_default_dtype(torch.bfloat16) + else: + torch.set_default_dtype(torch.half) + else: + torch.set_default_dtype(torch.half) + + if model_args.vision_chunk_size > 0: + from .multimodal.model import CrossAttentionTransformer + + model = 
CrossAttentionTransformer(model_args) + model.setup_cache(model_args.max_batch_size, torch.get_default_dtype()) + else: + model = Transformer(model_args) + model.load_state_dict(checkpoint, strict=True) + model.to(device) + print(f"Loaded in {time.time() - start_time:.2f} seconds") + + return Llama(model, tokenizer, model_args) + + def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs): + self.args = args + self.model = model + self.tokenizer = tokenizer + self.formatter = ChatFormat(tokenizer) + + @torch.inference_mode() + def generate( + self, + model_input: LLMInput, + max_gen_len: int, + temperature: float = 0.6, + top_p: float = 0.9, + logprobs: bool = False, + echo: bool = False, + print_model_input: bool = False, + logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + ) -> Generator: + params = self.model.params + + if print_model_input: + tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in model_input.tokens] + cprint( + "Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n", + "red", + ) + prompt_tokens = [model_input.tokens] + + bsz = 1 + assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) + + min_prompt_len = min(len(t) for t in prompt_tokens) + max_prompt_len = max(len(t) for t in prompt_tokens) + + if max_prompt_len >= params.max_seq_len: + cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red") + return + + total_len = min(max_gen_len + max_prompt_len, params.max_seq_len) + + is_vision = not isinstance(self.model, Transformer) + if is_vision: + images = model_input.vision.images if model_input.vision is not None else [] + mask = model_input.vision.mask if model_input.vision is not None else [] + + # the method works for bsz > 1 so add a batch dimension + xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks( + batch_images=[images], + batch_masks=[mask], + total_len=total_len, + ) + + pad_id = self.tokenizer.pad_id + tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long) + for k, t in enumerate(prompt_tokens): + tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long) + if logprobs: + token_logprobs = torch.zeros_like(tokens, dtype=torch.float) + + prev_pos = 0 + eos_reached = torch.tensor([False] * bsz) + input_text_mask = tokens != pad_id + + if echo: + for i, t in enumerate(model_input.tokens): + yield TokenResult( + token=t, + text=self.tokenizer.decode([t]), + logprobs=(token_logprobs[0, i : i + 1].tolist() if logprobs else None), + ) + + stop_tokens = torch.tensor(self.tokenizer.stop_tokens) + for cur_pos in range(min_prompt_len, total_len): + if is_vision: + position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long) + text_only_inference = model_input.vision is None + logits = self.model.forward( + position_ids, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + text_only_inference, + ) + else: + logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos) + + if logits_processor is not None: + logits = logits_processor(tokens[:, :cur_pos], logits) + + if temperature > 0: + probs = torch.softmax(logits[:, -1] / temperature, dim=-1) + next_token = sample_top_p(probs, top_p) + else: + next_token = torch.argmax(logits[:, -1], dim=-1) + + next_token = next_token.reshape(-1) + # only replace token if prompt has already been generated + next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token) + tokens[:, 
cur_pos] = next_token + + target = tokens[:, prev_pos + 1 : cur_pos + 1] + if is_vision: + # the logits space (num_classes) is designed to never contain a media_token + # however our input token stream does contain them. we need to nuke them here + # or else the CUDA kernels will crash with an illegal memory access + vision_tokens = [self.tokenizer.special_tokens["<|image|>"], 128256] + masks = [target.eq(t) for t in vision_tokens] + if len(masks) > 1: + mask = torch.logical_or(*masks) + else: + mask = masks[0] + target[mask] = 0 + + if logprobs: + token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy( + input=logits.transpose(1, 2), + target=target, + reduction="none", + ignore_index=pad_id, + ) + eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens)) + yield TokenResult( + token=next_token[0].item(), + text=self.tokenizer.decode(next_token.tolist()), + logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None), + ) + + prev_pos = cur_pos + if all(eos_reached): + break + + def text_completion( + self, + content: RawContent, + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len: Optional[int] = None, + logprobs: bool = False, + echo: bool = False, + ) -> CompletionPrediction: + if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len: + max_gen_len = self.model.params.max_seq_len - 1 + + model_input = self.formatter.encode_content(content) + + tokens = [] + token_logprobs = [] + decoded_tokens = [] + for result in self.generate( + model_input=model_input, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + logprobs=logprobs, + echo=echo, + ): + tokens.append(result.token) + if logprobs: + decoded_tokens.append(result.text) + token_logprobs.append(result.logprobs) + + generation = self.tokenizer.decode(tokens) + if logprobs: + return CompletionPrediction( + generation=generation, + logprobs=token_logprobs, + decoded_tokens=decoded_tokens, + ) + + return CompletionPrediction(generation=generation) + + def chat_completion( + self, + messages: List[RawMessage], + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len: Optional[int] = None, + logprobs: bool = False, + tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json, + echo: bool = False, + ) -> ChatPrediction: + if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len: + max_gen_len = self.model.params.max_seq_len - 1 + + tokens = [] + token_logprobs = [] + decoded_tokens = [] + + stop_reason = None + for result in self.generate( + model_input=self.formatter.encode_dialog_prompt(messages, tool_prompt_format), + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + logprobs=logprobs, + echo=echo, + ): + tokens.append(result.token) + if result.text == "<|eot_id|>": + stop_reason = StopReason.end_of_turn + elif result.text == "<|eom_id|>": + stop_reason = StopReason.end_of_message + + if logprobs: + decoded_tokens.append(result.text) + token_logprobs.append(result.logprobs) + + if stop_reason is None: + stop_reason = StopReason.out_of_tokens + + message = self.formatter.decode_assistant_message(tokens, stop_reason) + + if logprobs: + return ChatPrediction( + generation=message, + logprobs=token_logprobs, + decoded_tokens=decoded_tokens, + ) + + return ChatPrediction(generation=message) + + def chat_completion_raw( + self, + messages: List[RawMessage], + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len: Optional[int] = None, + 
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json, + ) -> List[int]: + if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len: + max_gen_len = self.model.params.max_seq_len - 1 + + output_tokens = [] + model_input = self.formatter.encode_dialog_prompt(messages, tool_prompt_format) + input_tokens = model_input.tokens + for result in self.generate( + model_input=model_input, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + logprobs=False, + ): + output_tokens.append(result.token) + + return input_tokens, output_tokens + + def text_completion_raw( + self, + content: RawContent, + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len: Optional[int] = None, + ): + if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len: + max_gen_len = self.model.params.max_seq_len - 1 + + model_input = self.formatter.encode_content(content) + input_tokens = model_input.tokens + + output_tokens = [] + for result in self.generate( + model_input=model_input, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + logprobs=False, + ): + output_tokens.append(result.token) + + return input_tokens, output_tokens + + +def sample_top_p(probs, p): + """ + Perform top-p (nucleus) sampling on a probability distribution. + + Args: + probs (torch.Tensor): Probability distribution tensor. + p (float): Probability threshold for top-p sampling. + + Returns: + torch.Tensor: Sampled token indices. + + Note: + Top-p sampling selects the smallest set of tokens whose cumulative probability mass + exceeds the threshold p. The distribution is renormalized based on the selected tokens. + """ + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort[mask] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token diff --git a/llama_stack/models/llama/llama3/interface.py b/llama_stack/models/llama/llama3/interface.py index 2579ab6c8..8684237df 100644 --- a/llama_stack/models/llama/llama3/interface.py +++ b/llama_stack/models/llama/llama3/interface.py @@ -16,7 +16,7 @@ from typing import List, Optional from termcolor import colored -from llama_stack.models.llama.datatypes import ( +from ..datatypes import ( BuiltinTool, RawMessage, StopReason, @@ -24,7 +24,6 @@ from llama_stack.models.llama.datatypes import ( ToolDefinition, ToolPromptFormat, ) - from . 
import template_data from .chat_format import ChatFormat from .prompt_templates import ( diff --git a/llama_stack/providers/inline/inference/meta_reference/llama3/model.py b/llama_stack/models/llama/llama3/model.py similarity index 100% rename from llama_stack/providers/inline/inference/meta_reference/llama3/model.py rename to llama_stack/models/llama/llama3/model.py diff --git a/llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/__init__.py b/llama_stack/models/llama/llama3/multimodal/__init__.py similarity index 100% rename from llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/__init__.py rename to llama_stack/models/llama/llama3/multimodal/__init__.py diff --git a/llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/encoder_utils.py b/llama_stack/models/llama/llama3/multimodal/encoder_utils.py similarity index 100% rename from llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/encoder_utils.py rename to llama_stack/models/llama/llama3/multimodal/encoder_utils.py diff --git a/llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/image_transform.py b/llama_stack/models/llama/llama3/multimodal/image_transform.py similarity index 100% rename from llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/image_transform.py rename to llama_stack/models/llama/llama3/multimodal/image_transform.py diff --git a/llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/model.py b/llama_stack/models/llama/llama3/multimodal/model.py similarity index 100% rename from llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/model.py rename to llama_stack/models/llama/llama3/multimodal/model.py diff --git a/llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/utils.py b/llama_stack/models/llama/llama3/multimodal/utils.py similarity index 100% rename from llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/utils.py rename to llama_stack/models/llama/llama3/multimodal/utils.py diff --git a/llama_stack/providers/inline/inference/meta_reference/llama3/quantization/loader.py b/llama_stack/models/llama/llama3/quantization/loader.py similarity index 98% rename from llama_stack/providers/inline/inference/meta_reference/llama3/quantization/loader.py rename to llama_stack/models/llama/llama3/quantization/loader.py index 5109130b4..f4d94c382 100644 --- a/llama_stack/providers/inline/inference/meta_reference/llama3/quantization/loader.py +++ b/llama_stack/models/llama/llama3/quantization/loader.py @@ -20,16 +20,16 @@ from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear from llama_stack.apis.inference import QuantizationType from llama_stack.log import get_logger -from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat from llama_stack.models.llama.sku_list import resolve_model -from llama_stack.providers.inline.inference.meta_reference.quantize_impls import ( + +from ...config import MetaReferenceQuantizedInferenceConfig +from ...datatypes import CheckpointQuantizationFormat +from ...quantize_impls import ( Fp8ScaledWeights, ffn_swiglu, load_fp8, quantize_fp8, ) - -from ...config import MetaReferenceQuantizedInferenceConfig from ..args import ModelArgs from ..model import Transformer, TransformerBlock @@ -292,7 +292,6 @@ def _prepare_model_int4_weight_int8_dynamic_activation( def convert_to_int4_quantized_model( model: Transformer, model_args: ModelArgs, - config: MetaReferenceQuantizedInferenceConfig, 
) -> Transformer: """Convert the model to int4 quantized model.""" diff --git a/llama_stack/models/llama/llama3/template_data.py b/llama_stack/models/llama/llama3/template_data.py index 076b4adb4..efca8397e 100644 --- a/llama_stack/models/llama/llama3/template_data.py +++ b/llama_stack/models/llama/llama3/template_data.py @@ -12,8 +12,7 @@ # the top-level of this source tree. -from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall - +from ..datatypes import BuiltinTool, StopReason, ToolCall from .prompt_templates import ( BuiltinToolGenerator, JsonCustomToolGenerator, diff --git a/llama_stack/models/llama/llama3/tool_utils.py b/llama_stack/models/llama/llama3/tool_utils.py index 71018898c..fc8287eb6 100644 --- a/llama_stack/models/llama/llama3/tool_utils.py +++ b/llama_stack/models/llama/llama3/tool_utils.py @@ -16,7 +16,8 @@ import re from typing import Optional, Tuple from llama_stack.log import get_logger -from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat + +from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat logger = get_logger(name=__name__, category="inference") diff --git a/llama_stack/providers/inline/inference/meta_reference/llama4/args.py b/llama_stack/models/llama/llama4/args.py similarity index 100% rename from llama_stack/providers/inline/inference/meta_reference/llama4/args.py rename to llama_stack/models/llama/llama4/args.py diff --git a/llama_stack/models/llama/llama4/chat_format.py b/llama_stack/models/llama/llama4/chat_format.py index c873012d6..ebae2b8e5 100644 --- a/llama_stack/models/llama/llama4/chat_format.py +++ b/llama_stack/models/llama/llama4/chat_format.py @@ -12,8 +12,7 @@ from typing import Dict, List, Optional, Tuple import torch from PIL import Image as PIL_Image -# TODO: either fork these or move them to the common package -from llama_stack.models.llama.datatypes import ( +from ..datatypes import ( BuiltinTool, RawContent, RawMediaItem, @@ -24,16 +23,13 @@ from llama_stack.models.llama.datatypes import ( ToolCall, ToolPromptFormat, ) -from llama_stack.models.llama.llama3.tool_utils import ToolUtils -from llama_stack.providers.inline.inference.meta_reference.llama4.args import VisionArgs -from llama_stack.providers.inline.inference.meta_reference.llama4.datatypes import ( - LLMInput, -) -from llama_stack.providers.inline.inference.meta_reference.llama4.preprocess import ( +from ..llama3.tool_utils import ToolUtils +from .args import VisionArgs +from .datatypes import LLMInput +from .preprocess import ( ResizeNormalizeImageTransform, VariableSizeImageTransform, ) - from .tokenizer import Tokenizer diff --git a/llama_stack/providers/inline/inference/meta_reference/llama4/datatypes.py b/llama_stack/models/llama/llama4/datatypes.py similarity index 100% rename from llama_stack/providers/inline/inference/meta_reference/llama4/datatypes.py rename to llama_stack/models/llama/llama4/datatypes.py diff --git a/llama_stack/providers/inline/inference/meta_reference/llama4/ffn.py b/llama_stack/models/llama/llama4/ffn.py similarity index 100% rename from llama_stack/providers/inline/inference/meta_reference/llama4/ffn.py rename to llama_stack/models/llama/llama4/ffn.py diff --git a/llama_stack/providers/inline/inference/meta_reference/llama4/generation.py b/llama_stack/models/llama/llama4/generation.py similarity index 98% rename from llama_stack/providers/inline/inference/meta_reference/llama4/generation.py rename to llama_stack/models/llama/llama4/generation.py index 
de900ce8d..9c516d967 100644 --- a/llama_stack/providers/inline/inference/meta_reference/llama4/generation.py +++ b/llama_stack/models/llama/llama4/generation.py @@ -23,17 +23,16 @@ from fairscale.nn.model_parallel.initialize import ( ) from termcolor import cprint -from llama_stack.models.llama.llama4.chat_format import ( +from ..common import TokenResult +from .args import ModelArgs +from .chat_format import ( ChatFormat, RawContent, RawMessage, ) -from llama_stack.models.llama.llama4.tokenizer import Tokenizer - -from ..common import TokenResult -from .args import ModelArgs from .datatypes import LLMInput, MaskedEmbedding, TransformerInput from .model import Transformer +from .tokenizer import Tokenizer torch.serialization.add_safe_globals([io.BytesIO, codecs.encode]) diff --git a/llama_stack/providers/inline/inference/meta_reference/llama4/model.py b/llama_stack/models/llama/llama4/model.py similarity index 100% rename from llama_stack/providers/inline/inference/meta_reference/llama4/model.py rename to llama_stack/models/llama/llama4/model.py diff --git a/llama_stack/providers/inline/inference/meta_reference/llama4/moe.py b/llama_stack/models/llama/llama4/moe.py similarity index 100% rename from llama_stack/providers/inline/inference/meta_reference/llama4/moe.py rename to llama_stack/models/llama/llama4/moe.py diff --git a/llama_stack/providers/inline/inference/meta_reference/llama4/preprocess.py b/llama_stack/models/llama/llama4/preprocess.py similarity index 100% rename from llama_stack/providers/inline/inference/meta_reference/llama4/preprocess.py rename to llama_stack/models/llama/llama4/preprocess.py diff --git a/llama_stack/models/llama/llama4/prompts.py b/llama_stack/models/llama/llama4/prompts.py index 97f573ef8..d4e48e80a 100644 --- a/llama_stack/models/llama/llama4/prompts.py +++ b/llama_stack/models/llama/llama4/prompts.py @@ -16,8 +16,8 @@ from io import BytesIO from pathlib import Path from typing import List -from llama_stack.models.llama.datatypes import RawMediaItem, RawMessage, RawTextItem -from llama_stack.models.llama.prompt_format import ( +from ..datatypes import RawMediaItem, RawMessage, RawTextItem +from ..prompt_format import ( Llama4UseCase, TextCompletionContent, UseCase, diff --git a/llama_stack/providers/inline/inference/meta_reference/llama4/quantization/loader.py b/llama_stack/models/llama/llama4/quantization/loader.py similarity index 100% rename from llama_stack/providers/inline/inference/meta_reference/llama4/quantization/loader.py rename to llama_stack/models/llama/llama4/quantization/loader.py diff --git a/llama_stack/providers/inline/inference/meta_reference/llama4/vision/embedding.py b/llama_stack/models/llama/llama4/vision/embedding.py similarity index 100% rename from llama_stack/providers/inline/inference/meta_reference/llama4/vision/embedding.py rename to llama_stack/models/llama/llama4/vision/embedding.py diff --git a/llama_stack/providers/inline/inference/meta_reference/llama4/vision/encoder.py b/llama_stack/models/llama/llama4/vision/encoder.py similarity index 100% rename from llama_stack/providers/inline/inference/meta_reference/llama4/vision/encoder.py rename to llama_stack/models/llama/llama4/vision/encoder.py diff --git a/llama_stack/providers/inline/inference/meta_reference/quantize_impls.py b/llama_stack/models/llama/quantize_impls.py similarity index 100% rename from llama_stack/providers/inline/inference/meta_reference/quantize_impls.py rename to llama_stack/models/llama/quantize_impls.py diff --git 
a/llama_stack/providers/inline/inference/meta_reference/generators.py b/llama_stack/providers/inline/inference/meta_reference/generators.py index 4b0ed7ecd..809351164 100644 --- a/llama_stack/providers/inline/inference/meta_reference/generators.py +++ b/llama_stack/providers/inline/inference/meta_reference/generators.py @@ -22,7 +22,9 @@ from llama_stack.models.llama.datatypes import ( SamplingParams, TopPSamplingStrategy, ) +from llama_stack.models.llama.llama3.generation import Llama3 from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer +from llama_stack.models.llama.llama4.generation import Llama4 from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer from llama_stack.providers.utils.inference.prompt_adapter import ( ChatCompletionRequestWithRawContent, @@ -33,8 +35,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .common import model_checkpoint_dir from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig from .inference import resolve_model -from .llama3.generation import Llama3 -from .llama4.generation import Llama4 Tokenizer = Llama4Tokenizer | Llama3Tokenizer @@ -212,14 +212,34 @@ class Llama3Generator: model_id: str, llama_model: Model, ): + if config.checkpoint_dir and config.checkpoint_dir != "null": + ckpt_dir = config.checkpoint_dir + else: + resolved_model = resolve_model(model_id) + if resolved_model is None: + # if the model is not a native llama model, get the default checkpoint_dir based on model id + ckpt_dir = model_checkpoint_dir(model_id) + else: + # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value + ckpt_dir = model_checkpoint_dir(resolved_model.descriptor()) + + if isinstance(config, MetaReferenceQuantizedInferenceConfig): + if isinstance(config.quantization, Fp8QuantizationConfig): + quantization_mode = "fp8_mixed" + elif isinstance(config.quantization, Int4QuantizationConfig): + quantization_mode = "int4_mixed" + else: + raise ValueError(f"Unsupported quantization mode {config.quantization}") + else: + quantization_mode = None + self.inner_generator = Llama3.build( - config=config, - model_id=model_id, - llama_model=llama_model, + ckpt_dir=ckpt_dir, + max_seq_len=config.max_seq_len, + max_batch_size=config.max_batch_size, + world_size=llama_model.pth_file_count, + quantization_mode=quantization_mode, ) - self.tokenizer = self.inner_generator.tokenizer - self.args = self.inner_generator.args - self.formatter = self.inner_generator.formatter def completion( self, diff --git a/llama_stack/providers/inline/inference/meta_reference/llama3/generation.py b/llama_stack/providers/inline/inference/meta_reference/llama3/generation.py deleted file mode 100644 index 3805e4310..000000000 --- a/llama_stack/providers/inline/inference/meta_reference/llama3/generation.py +++ /dev/null @@ -1,346 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- - -import json -import os -import sys -import time -from pathlib import Path -from typing import Callable, Generator, Optional, Union - -import torch -import torch.nn.functional as F -from fairscale.nn.model_parallel.initialize import ( - get_model_parallel_rank, - initialize_model_parallel, - model_parallel_is_initialized, -) - -from llama_stack.apis.inference import ( - Fp8QuantizationConfig, - Int4QuantizationConfig, -) -from llama_stack.log import get_logger -from llama_stack.models.llama.datatypes import Model -from llama_stack.models.llama.llama3.chat_format import ChatFormat, LLMInput -from llama_stack.models.llama.llama3.tokenizer import Tokenizer -from llama_stack.models.llama.sku_list import resolve_model - -from ..common import TokenResult, model_checkpoint_dir -from ..config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig -from .args import ModelArgs -from .model import Transformer -from .multimodal.model import CrossAttentionTransformer - -log = get_logger(__name__, category="inference") - - -class Llama3: - @staticmethod - def build( - config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig], - model_id: str, - llama_model: Model, - ): - """ - Build a Llama instance by initializing and loading a model checkpoint. - - Note: - This method initializes the distributed process group, sets the device to CUDA, - and loads the pre-trained model and tokenizer. - """ - if "DEVICE" in os.environ: - device = os.environ.get("DEVICE") - if device == "cuda": - assert torch.cuda.is_available(), "PyTorch CUDA backend not available" - if device == "xpu": - assert torch.xpu.is_available(), "PyTorch XPU backend not available" - else: - if torch.cuda.is_available(): - device = "cuda" - elif torch.xpu.is_available(): - device = "xpu" - else: - device = "cpu" - log.info(f"Using {device} device") - - llama_model_id = llama_model.core_model_id.value - if not torch.distributed.is_initialized(): - if device == "cuda": - torch.distributed.init_process_group("nccl") - else: - torch.distributed.init_process_group("gloo") - - model_parallel_size = llama_model.pth_file_count - - if not model_parallel_is_initialized(): - initialize_model_parallel(model_parallel_size) - - local_rank = int(os.environ.get("LOCAL_RANK", 0)) - if device == "cuda": - torch.cuda.set_device(local_rank) - elif device == "xpu": - torch.xpu.set_device(local_rank) - - # seed must be the same in all processes - if config.torch_seed is not None: - torch.manual_seed(config.torch_seed) - - if local_rank > 0: - sys.stdout = open(os.devnull, "w") - - start_time = time.time() - if config.checkpoint_dir and config.checkpoint_dir != "null": - ckpt_dir = config.checkpoint_dir - else: - resolved_model = resolve_model(model_id) - if resolved_model is None: - # if the model is not a native llama model, get the default checkpoint_dir based on model id - ckpt_dir = model_checkpoint_dir(model_id) - else: - # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value - ckpt_dir = model_checkpoint_dir(resolved_model.descriptor()) - - checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) - assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" - assert model_parallel_size == len(checkpoints), ( - f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}" - ) - ckpt_path = checkpoints[get_model_parallel_rank()] - state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True) - with open(Path(ckpt_dir) / 
"params.json", "r") as f: - params = json.loads(f.read()) - - if "model" in params: - params = params["model"] - - model_args: ModelArgs = ModelArgs( - max_seq_len=config.max_seq_len, - max_batch_size=config.max_batch_size, - **params, - ) - - tokenizer = Tokenizer.get_instance() - assert model_args.vocab_size == tokenizer.n_words, ( - f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}" - ) - - if isinstance(config, MetaReferenceQuantizedInferenceConfig): - if isinstance(config.quantization, Fp8QuantizationConfig): - from .quantization.loader import convert_to_fp8_quantized_model - - # load on CPU in bf16 so that fp8 conversion does not find an - # unexpected (fp32, e.g.) datatype - torch.set_default_tensor_type(torch.BFloat16Tensor) - if model_args.vision_chunk_size > 0: - model = CrossAttentionTransformer(model_args) - model.setup_cache(model_args.max_batch_size, torch.bfloat16) - else: - model = Transformer(model_args) - model.load_state_dict(state_dict, strict=False) - model = convert_to_fp8_quantized_model(model, config, ckpt_dir) - elif isinstance(config.quantization, Int4QuantizationConfig): - from .quantization.loader import convert_to_int4_quantized_model - - model = Transformer(model_args) - model = convert_to_int4_quantized_model(model, model_args, config) - model.load_state_dict(state_dict, strict=True) - - if model_args.quantization_args is not None and model_args.quantization_args.spinquant: - # Add a wrapper for adding hadamard transform for spinquant. - # This needs to be done after loading the state dict otherwise an error will be raised while - # loading the state dict. - from ..hadamard_utils import ( - add_hadamard_transform_for_spinquant, - ) - - add_hadamard_transform_for_spinquant(model) - else: - raise NotImplementedError("Currently int4 and fp8 are the only supported quantization methods.") - else: - if device == "cuda": - if torch.cuda.is_bf16_supported(): - torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) - else: - torch.set_default_tensor_type(torch.cuda.HalfTensor) - else: - torch.set_default_device(device) - if device == "xpu" and torch.xpu.is_bf16_supported(): - torch.set_default_dtype(torch.bfloat16) - else: - torch.set_default_dtype(torch.half) - if model_args.vision_chunk_size > 0: - model = CrossAttentionTransformer(model_args) - model.setup_cache(model_args.max_batch_size, torch.bfloat16) - else: - model = Transformer(model_args) - model.load_state_dict(state_dict, strict=False) - - model.to(device) - - log.info(f"Loaded in {time.time() - start_time:.2f} seconds") - return Llama3(model, tokenizer, model_args, llama_model_id) - - def __init__( - self, - model: Transformer, - tokenizer: Tokenizer, - args: ModelArgs, - llama_model: str, - ): - self.args = args - self.model = model - self.tokenizer = tokenizer - self.formatter = ChatFormat(tokenizer) - self.llama_model = llama_model - - @torch.inference_mode() - def generate( - self, - model_input: LLMInput, - max_gen_len: int, - temperature: float = 0.6, - top_p: float = 0.9, - logprobs: bool = False, - echo: bool = False, - print_input_tokens: bool = False, - logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, - ) -> Generator: - params = self.model.params - - if print_input_tokens: - input_tokens = [self.formatter.vision_token if t == 128256 else t for t in model_input.tokens] - log.info("Input to model -> " + self.tokenizer.decode(input_tokens)) - prompt_tokens = [model_input.tokens] - - bsz = 1 - assert bsz <= 
params.max_batch_size, (bsz, params.max_batch_size) - - min_prompt_len = min(len(t) for t in prompt_tokens) - max_prompt_len = max(len(t) for t in prompt_tokens) - - if max_prompt_len >= params.max_seq_len: - log.error(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}") - return - - total_len = min(max_gen_len + max_prompt_len, params.max_seq_len) - - is_vision = isinstance(self.model, CrossAttentionTransformer) - if is_vision: - images = model_input.vision.images if model_input.vision is not None else [] - mask = model_input.vision.mask if model_input.vision is not None else [] - - # the method works for bsz > 1 so add a batch dimension - xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks( - batch_images=[images], - batch_masks=[mask], - total_len=total_len, - ) - - pad_id = self.tokenizer.pad_id - tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long) - for k, t in enumerate(prompt_tokens): - tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long) - if logprobs: - token_logprobs = torch.zeros_like(tokens) - - prev_pos = 0 - eos_reached = torch.tensor([False] * bsz) - input_text_mask = tokens != pad_id - if min_prompt_len == total_len: - # TODO(ashwin): unify this branch with the one below and figure out multimodal crap - logits = self.model.forward(tokens, prev_pos) - token_logprobs = -F.cross_entropy( - input=logits.transpose(1, 2), - target=tokens, - reduction="none", - ignore_index=pad_id, - ) - - stop_tokens = torch.tensor(self.tokenizer.stop_tokens) - for cur_pos in range(min_prompt_len, total_len): - if is_vision: - position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long) - logits = self.model.forward( - position_ids, - tokens, - cross_attention_masks, - full_text_row_masked_out_mask, - xattn_caches, - ) - else: - logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos) - - if logits_processor is not None: - logits = logits_processor(tokens[:, :cur_pos], logits) - - if temperature > 0: - probs = torch.softmax(logits[:, -1] / temperature, dim=-1) - next_token = sample_top_p(probs, top_p) - else: - next_token = torch.argmax(logits[:, -1], dim=-1) - - next_token = next_token.reshape(-1) - # only replace token if prompt has already been generated - next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token) - tokens[:, cur_pos] = next_token - - target = tokens[:, prev_pos + 1 : cur_pos + 1] - if is_vision: - # the logits space (num_classes) is designed to never contain a media_token - # however our input token stream does contain them. 
we need to nuke them here - # or else the CUDA kernels will crash with an illegal memory access - vision_tokens = [self.tokenizer.special_tokens["<|image|>"], 128256] - masks = [target.eq(t) for t in vision_tokens] - if len(masks) > 1: - mask = torch.logical_or(*masks) - else: - mask = masks[0] - target[mask] = 0 - - if logprobs: - token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy( - input=logits.transpose(1, 2), - target=tokens[:, prev_pos + 1 : cur_pos + 1], - reduction="none", - ignore_index=pad_id, - ) - eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens)) - yield TokenResult( - token=next_token[0].item(), - text=self.tokenizer.decode(next_token.tolist()), - logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None), - ) - - prev_pos = cur_pos - if all(eos_reached): - break - - -def sample_top_p(probs, p): - """ - Perform top-p (nucleus) sampling on a probability distribution. - - Args: - probs (torch.Tensor): Probability distribution tensor. - p (float): Probability threshold for top-p sampling. - - Returns: - torch.Tensor: Sampled token indices. - - Note: - Top-p sampling selects the smallest set of tokens whose cumulative probability mass - exceeds the threshold p. The distribution is renormalized based on the selected tokens. - """ - probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) - probs_sum = torch.cumsum(probs_sort, dim=-1) - mask = probs_sum - probs_sort > p - probs_sort[mask] = 0.0 - probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) - next_token = torch.multinomial(probs_sort, num_samples=1) - next_token = torch.gather(probs_idx, -1, next_token) - return next_token
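The `sample_top_p` helper is carried over unchanged from the deleted file into the new `llama_stack/models/llama/llama3/generation.py`. A small worked example of the nucleus-sampling behaviour, assuming the relocated module path from this diff (the distribution values are illustrative only):

# Toy nucleus-sampling call against sample_top_p (illustrative distribution, batch size 1).
import torch

from llama_stack.models.llama.llama3.generation import sample_top_p

torch.manual_seed(0)
probs = torch.tensor([[0.40, 0.30, 0.15, 0.10, 0.05]])  # already softmax-normalized

# With p=0.9 the first four tokens carry 0.95 cumulative mass, so only the fifth
# (0.05) token gets masked out; the survivors are renormalized before
# torch.multinomial draws one index per batch row.
next_token = sample_top_p(probs, p=0.9)
print(next_token.shape)  # torch.Size([1, 1])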
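For reviewers who want to exercise the relocated class directly, here is a minimal usage sketch of the new `Llama3.build` / `chat_completion` surface shown in this patch. Assumptions: the script runs under `torchrun` (or with `WORLD_SIZE`/`LOCAL_RANK` set) so `torch.distributed` can initialize, `/path/to/llama3-checkpoint` is a hypothetical single-shard checkpoint directory containing `params.json` and one `*.pth` file, and `RawMessage` takes `role`/`content` keyword arguments as defined in `llama_stack.models.llama.datatypes`:

# Minimal driver sketch for the relocated Llama3 class; not the meta-reference
# provider's own entrypoint. Launch under e.g. `torchrun --nproc-per-node=1 example.py`.
from llama_stack.models.llama.datatypes import RawMessage
from llama_stack.models.llama.llama3.generation import Llama3

CKPT_DIR = "/path/to/llama3-checkpoint"  # hypothetical path

generator = Llama3.build(
    ckpt_dir=CKPT_DIR,
    max_seq_len=2048,
    max_batch_size=1,
    world_size=1,   # must match the number of *.pth shards in CKPT_DIR
    device="cuda",  # build() also accepts other device types such as "xpu"
)

prediction = generator.chat_completion(
    messages=[RawMessage(role="user", content="Write a haiku about tokenizers.")],
    temperature=0.6,
    top_p=0.9,
    max_gen_len=128,
)
print(prediction.generation.content)  # ChatPrediction.generation is a RawMessage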