From 530d4bdfe130ace4b31f09cb0334195928d4bc08 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 7 Apr 2025 15:03:58 -0700 Subject: [PATCH] refactor: move all llama code to models/llama out of meta reference (#1887) # What does this PR do? Move around bits. This makes the copies from llama-models _much_ easier to maintain and ensures we don't entangle meta-reference specific tidbits into llama-models code even by accident. Also, kills the meta-reference-quantized-gpu distro and rolls quantization deps into meta-reference-gpu. ## Test Plan ``` LLAMA_MODELS_DEBUG=1 \ with-proxy llama stack run meta-reference-gpu \ --env INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct \ --env INFERENCE_CHECKPOINT_DIR= \ --env MODEL_PARALLEL_SIZE=4 \ --env QUANTIZATION_TYPE=fp8_mixed ``` Start a server with and without quantization. Point integration tests to it using: ``` pytest -s -v tests/integration/inference/test_text_inference.py \ --stack-config http://localhost:8321 --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct ``` --- llama_stack/apis/inference/inference.py | 63 ++- llama_stack/cli/download.py | 2 +- llama_stack/cli/model/describe.py | 11 - llama_stack/cli/model/prompt_format.py | 2 +- llama_stack/cli/model/safety_models.py | 5 +- llama_stack/models/llama/checkpoint.py | 164 ++++++++ llama_stack/models/llama/datatypes.py | 339 ++-------------- .../llama}/hadamard_utils.py | 0 .../llama}/llama3/args.py | 7 - .../models/llama/llama3/chat_format.py | 10 +- llama_stack/models/llama/llama3/generation.py | 367 ++++++++++++++++++ llama_stack/models/llama/llama3/interface.py | 3 +- .../llama}/llama3/model.py | 20 +- .../llama}/llama3/multimodal/__init__.py | 0 .../llama}/llama3/multimodal/encoder_utils.py | 0 .../llama3/multimodal/image_transform.py | 0 .../llama}/llama3/multimodal/model.py | 73 ++-- .../llama}/llama3/multimodal/utils.py | 0 .../llama3/prompt_templates/system_prompts.py | 2 +- .../llama/llama3/quantization}/__init__.py | 2 - .../llama}/llama3/quantization/loader.py | 73 ++-- .../models/llama/llama3/template_data.py | 3 +- llama_stack/models/llama/llama3/tokenizer.py | 10 - llama_stack/models/llama/llama3/tool_utils.py | 3 +- llama_stack/models/llama/llama3_2/__init__.py | 7 - .../models/llama/llama3_2/prompts_text.py | 6 - .../models/llama/llama3_2/prompts_vision.py | 7 - .../llama}/llama4/args.py | 7 - .../models/llama/llama4/chat_format.py | 20 +- .../llama}/llama4/datatypes.py | 7 - .../llama}/llama4/ffn.py | 0 .../llama}/llama4/generation.py | 153 ++++---- .../llama}/llama4/model.py | 11 - .../llama}/llama4/moe.py | 16 +- .../llama}/llama4/preprocess.py | 0 llama_stack/models/llama/llama4/prompts.py | 11 +- .../llama/llama4/quantization/__init__.py | 5 + .../llama}/llama4/quantization/loader.py | 98 +++-- llama_stack/models/llama/llama4/tokenizer.py | 25 +- .../llama}/llama4/vision/embedding.py | 7 - .../llama}/llama4/vision/encoder.py | 0 llama_stack/models/llama/prompt_format.py | 58 ++- .../llama}/quantize_impls.py | 0 llama_stack/models/llama/sku_list.py | 55 +-- llama_stack/models/llama/sku_types.py | 229 +++++++++++ .../agents/meta_reference/agent_instance.py | 2 +- .../inference/meta_reference/__init__.py | 6 +- .../inline/inference/meta_reference/common.py | 9 - .../inline/inference/meta_reference/config.py | 26 +- .../inference/meta_reference/generators.py | 93 +++-- .../inference/meta_reference/inference.py | 12 +- .../meta_reference/llama3/generation.py | 346 ----------------- .../meta_reference/parallel_utils.py | 5 +- 
.../inline/inference/vllm/openai_utils.py | 3 +- .../providers/inline/inference/vllm/vllm.py | 4 +- .../post_training/torchtune/common/utils.py | 2 +- .../inline/safety/llama_guard/llama_guard.py | 3 +- llama_stack/providers/registry/inference.py | 9 +- .../remote/inference/bedrock/models.py | 2 +- .../remote/inference/cerebras/cerebras.py | 2 +- .../remote/inference/cerebras/models.py | 2 +- .../remote/inference/databricks/databricks.py | 2 +- .../remote/inference/fireworks/models.py | 2 +- .../remote/inference/nvidia/models.py | 2 +- .../remote/inference/nvidia/nvidia.py | 6 +- .../remote/inference/nvidia/openai_utils.py | 4 +- .../remote/inference/ollama/models.py | 2 +- .../remote/inference/sambanova/models.py | 2 +- .../remote/inference/sambanova/sambanova.py | 6 +- .../remote/inference/together/models.py | 2 +- .../remote/post_training/nvidia/models.py | 2 +- llama_stack/providers/tests/report.py | 2 +- .../providers/utils/inference/__init__.py | 2 +- .../utils/inference/openai_compat.py | 8 +- .../utils/inference/prompt_adapter.py | 3 +- llama_stack/templates/dependencies.json | 45 +-- .../meta-reference-gpu/run-with-safety.yaml | 6 + .../templates/meta-reference-gpu/run.yaml | 3 + .../meta-reference-quantized-gpu/build.yaml | 32 -- .../doc_template.md | 113 ------ .../meta_reference.py | 115 ------ .../meta-reference-quantized-gpu/run.yaml | 134 ------- pyproject.toml | 6 +- scripts/generate_prompt_format.py | 42 +- tests/integration/report.py | 2 +- 85 files changed, 1267 insertions(+), 1683 deletions(-) create mode 100644 llama_stack/models/llama/checkpoint.py rename llama_stack/{providers/inline/inference/meta_reference => models/llama}/hadamard_utils.py (100%) rename llama_stack/{providers/inline/inference/meta_reference => models/llama}/llama3/args.py (88%) create mode 100644 llama_stack/models/llama/llama3/generation.py rename llama_stack/{providers/inline/inference/meta_reference => models/llama}/llama3/model.py (94%) rename llama_stack/{providers/inline/inference/meta_reference => models/llama}/llama3/multimodal/__init__.py (100%) rename llama_stack/{providers/inline/inference/meta_reference => models/llama}/llama3/multimodal/encoder_utils.py (100%) rename llama_stack/{providers/inline/inference/meta_reference => models/llama}/llama3/multimodal/image_transform.py (100%) rename llama_stack/{providers/inline/inference/meta_reference => models/llama}/llama3/multimodal/model.py (95%) rename llama_stack/{providers/inline/inference/meta_reference => models/llama}/llama3/multimodal/utils.py (100%) rename llama_stack/{templates/meta-reference-quantized-gpu => models/llama/llama3/quantization}/__init__.py (74%) rename llama_stack/{providers/inline/inference/meta_reference => models/llama}/llama3/quantization/loader.py (84%) rename llama_stack/{providers/inline/inference/meta_reference => models/llama}/llama4/args.py (91%) rename llama_stack/{providers/inline/inference/meta_reference => models/llama}/llama4/datatypes.py (85%) rename llama_stack/{providers/inline/inference/meta_reference => models/llama}/llama4/ffn.py (100%) rename llama_stack/{providers/inline/inference/meta_reference => models/llama}/llama4/generation.py (72%) rename llama_stack/{providers/inline/inference/meta_reference => models/llama}/llama4/model.py (97%) rename llama_stack/{providers/inline/inference/meta_reference => models/llama}/llama4/moe.py (87%) rename llama_stack/{providers/inline/inference/meta_reference => models/llama}/llama4/preprocess.py (100%) create mode 100644 
llama_stack/models/llama/llama4/quantization/__init__.py rename llama_stack/{providers/inline/inference/meta_reference => models/llama}/llama4/quantization/loader.py (70%) rename llama_stack/{providers/inline/inference/meta_reference => models/llama}/llama4/vision/embedding.py (96%) rename llama_stack/{providers/inline/inference/meta_reference => models/llama}/llama4/vision/encoder.py (100%) rename llama_stack/{providers/inline/inference/meta_reference => models/llama}/quantize_impls.py (100%) create mode 100644 llama_stack/models/llama/sku_types.py delete mode 100644 llama_stack/providers/inline/inference/meta_reference/llama3/generation.py delete mode 100644 llama_stack/templates/meta-reference-quantized-gpu/build.yaml delete mode 100644 llama_stack/templates/meta-reference-quantized-gpu/doc_template.md delete mode 100644 llama_stack/templates/meta-reference-quantized-gpu/meta_reference.py delete mode 100644 llama_stack/templates/meta-reference-quantized-gpu/run.yaml diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 1d4012c19..e59132e33 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -25,15 +25,64 @@ from llama_stack.apis.models import Model from llama_stack.apis.telemetry.telemetry import MetricResponseMixin from llama_stack.models.llama.datatypes import ( BuiltinTool, - SamplingParams, StopReason, ToolCall, ToolDefinition, + ToolParamDefinition, ToolPromptFormat, ) from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, register_schema, webmethod +register_schema(ToolCall) +register_schema(ToolParamDefinition) +register_schema(ToolDefinition) + + +@json_schema_type +class GreedySamplingStrategy(BaseModel): + type: Literal["greedy"] = "greedy" + + +@json_schema_type +class TopPSamplingStrategy(BaseModel): + type: Literal["top_p"] = "top_p" + temperature: Optional[float] = Field(..., gt=0.0) + top_p: Optional[float] = 0.95 + + +@json_schema_type +class TopKSamplingStrategy(BaseModel): + type: Literal["top_k"] = "top_k" + top_k: int = Field(..., ge=1) + + +SamplingStrategy = Annotated[ + Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy], + Field(discriminator="type"), +] +register_schema(SamplingStrategy, name="SamplingStrategy") + + +@json_schema_type +class SamplingParams(BaseModel): + """Sampling parameters. + + :param strategy: The sampling strategy. + :param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of + your prompt plus max_tokens cannot exceed the model's context length. + :param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens + based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. + :param stop: Up to 4 sequences where the API will stop generating further tokens. + The returned text will not contain the stop sequence. + """ + + strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy) + + max_tokens: Optional[int] = 0 + repetition_penalty: Optional[float] = 1.0 + stop: Optional[List[str]] = None + class LogProbConfig(BaseModel): """ @@ -48,18 +97,18 @@ class QuantizationType(Enum): """Type of model quantization to run inference with. 
:cvar bf16: BFloat16 typically this means _no_ quantization - :cvar fp8: 8-bit floating point quantization - :cvar int4: 4-bit integer quantization + :cvar fp8_mixed: 8-bit floating point quantization with mixed precision + :cvar int4_mixed: 4-bit integer quantization with mixed precision """ bf16 = "bf16" - fp8 = "fp8" - int4 = "int4" + fp8_mixed = "fp8_mixed" + int4_mixed = "int4_mixed" @json_schema_type class Fp8QuantizationConfig(BaseModel): - type: Literal["fp8"] = "fp8" + type: Literal["fp8_mixed"] = "fp8_mixed" @json_schema_type @@ -75,7 +124,7 @@ class Int4QuantizationConfig(BaseModel): :param scheme: Quantization scheme to use. Defaults to "int4_weight_int8_dynamic_activation" """ - type: Literal["int4"] = "int4" + type: Literal["int4_mixed"] = "int4_mixed" scheme: Optional[str] = "int4_weight_int8_dynamic_activation" diff --git a/llama_stack/cli/download.py b/llama_stack/cli/download.py index fc3e7008f..9694bf22d 100644 --- a/llama_stack/cli/download.py +++ b/llama_stack/cli/download.py @@ -29,8 +29,8 @@ from rich.progress import ( from termcolor import cprint from llama_stack.cli.subcommand import Subcommand -from llama_stack.models.llama.datatypes import Model from llama_stack.models.llama.sku_list import LlamaDownloadInfo +from llama_stack.models.llama.sku_types import Model class Download(Subcommand): diff --git a/llama_stack/cli/model/describe.py b/llama_stack/cli/model/describe.py index f347bdf8d..62dde36e8 100644 --- a/llama_stack/cli/model/describe.py +++ b/llama_stack/cli/model/describe.py @@ -63,17 +63,6 @@ class ModelDescribe(Subcommand): ("Model params.json", json.dumps(model.arch_args, indent=4)), ] - if model.recommended_sampling_params is not None: - sampling_params = model.recommended_sampling_params.model_dump() - for k in ("max_tokens", "repetition_penalty"): - del sampling_params[k] - rows.append( - ( - "Recommended sampling params", - json.dumps(sampling_params, indent=4), - ) - ) - print_table( rows, headers, diff --git a/llama_stack/cli/model/prompt_format.py b/llama_stack/cli/model/prompt_format.py index 3ce77655b..673487812 100644 --- a/llama_stack/cli/model/prompt_format.py +++ b/llama_stack/cli/model/prompt_format.py @@ -11,7 +11,7 @@ from pathlib import Path from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.table import print_table -from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family +from llama_stack.models.llama.sku_types import CoreModelId, ModelFamily, is_multimodal, model_family ROOT_DIR = Path(__file__).parent.parent.parent diff --git a/llama_stack/cli/model/safety_models.py b/llama_stack/cli/model/safety_models.py index c81783f60..131d055aa 100644 --- a/llama_stack/cli/model/safety_models.py +++ b/llama_stack/cli/model/safety_models.py @@ -4,12 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
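The sampling types relocated into `llama_stack/apis/inference/inference.py` above are plain Pydantic models. A short construction sketch (editor's illustration, importing from the module path shown in the hunk; greedy is the default strategy via `default_factory`, and `temperature` is a required, strictly positive field for top-p):

```python
from llama_stack.apis.inference.inference import (
    GreedySamplingStrategy,
    SamplingParams,
    TopPSamplingStrategy,
)

# Default: greedy decoding, no explicit token cap.
default_params = SamplingParams()
assert isinstance(default_params.strategy, GreedySamplingStrategy)

# Top-p sampling; temperature is required and constrained to > 0 by the Field definition above.
params = SamplingParams(
    strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.9),
    max_tokens=256,
    stop=["</answer>"],
)
```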
-from typing import Any, Dict, Optional +from typing import Any, Dict from pydantic import BaseModel, ConfigDict, Field -from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat, SamplingParams from llama_stack.models.llama.sku_list import LlamaDownloadInfo +from llama_stack.models.llama.sku_types import CheckpointQuantizationFormat class PromptGuardModel(BaseModel): @@ -23,7 +23,6 @@ class PromptGuardModel(BaseModel): is_instruct_model: bool = False quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16 arch_args: Dict[str, Any] = Field(default_factory=dict) - recommended_sampling_params: Optional[SamplingParams] = None def descriptor(self) -> str: return self.model_id diff --git a/llama_stack/models/llama/checkpoint.py b/llama_stack/models/llama/checkpoint.py new file mode 100644 index 000000000..2bae08a69 --- /dev/null +++ b/llama_stack/models/llama/checkpoint.py @@ -0,0 +1,164 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import concurrent.futures +import re +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import torch +from fairscale.nn.model_parallel.initialize import get_model_parallel_rank, get_model_parallel_world_size + + +def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[int]: + """Map a new MP rank to a list of old MP ranks given a change in MP size.""" + if new_mp_size % old_mp_size == 0: + # Read old MP shard and split it into smaller ones + return [new_mp_rank * old_mp_size // new_mp_size] + elif old_mp_size % new_mp_size == 0: + # Merge old MP shards into a single one + mp_factor = old_mp_size // new_mp_size + return list(range(new_mp_rank * mp_factor, (new_mp_rank + 1) * mp_factor)) + else: + raise ValueError( + f"Either old MP size or new MP size should be a multiple of the other: " + f"{old_mp_size} % {new_mp_size} != 0 and {new_mp_size} % {old_mp_size} != 0" + ) + + +def maybe_reshard_state_dict( + ckpt_paths: List[Path], + n_kv_heads: int, + moe_num_experts: Optional[int] = None, + map_location: Union[str, torch.device] = "cpu", + mmap: bool = True, +) -> Dict[str, torch.Tensor]: + if str(map_location) == "cpu": + torch.set_default_tensor_type(torch.BFloat16Tensor) + else: + torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) + + ckpt_paths = np.array(sorted(ckpt_paths)) + + new_mp_size, new_mp_rank = get_model_parallel_world_size(), get_model_parallel_rank() + old_mp_size = len(ckpt_paths) + old_mp_ranks = map_mp_rank(old_mp_size, new_mp_size, new_mp_rank) + + print(f"Loading checkpoint shards:\n{str(ckpt_paths[old_mp_ranks])}") # type: ignore + paths = ckpt_paths[old_mp_ranks] # type: ignore + state_dicts = [torch.load(str(p), map_location=map_location, mmap=mmap) for p in paths] + + if new_mp_size == old_mp_size: + return state_dicts[0] # type: ignore + + if moe_num_experts is not None: + state_dicts = [convert_moe_weights(d, moe_num_experts) for d in state_dicts] + + print(f"Resharding {len(state_dicts)} state dicts from MP size {old_mp_size} to MP size {new_mp_size}") + return reshard_mp( + state_dicts, + size=max(new_mp_size // old_mp_size, 1), + rank=new_mp_rank % max(new_mp_size // old_mp_size, 1), + repeat_qk_qv=max(new_mp_size // n_kv_heads, 1), + ) + + +_WEIGHT_ROW_KEY = { + "feed_forward.w2", + "feed_forward.mlp.fc2", + "attention.wo", + 
"feed_forward.mlp.fc2_weight", + "feed_forward.w_out_shared_DF.weight", + "attn.wo.weight", + "mlp.c_proj.weight", +} +_MOE_WEIGHT_ROW_KEY = {"feed_forward.experts.(moe_w_in_eD_F|moe_w_swiglu_eD_F)"} + +_WEIGHT_COLUMN_KEY = { + "output", + "feed_forward.(w1|w3)", + "feed_forward.mlp.(fc1|fc3)", + "feed_forward.mlp.fc1_weight", + "attention.(wk|wq|wv|wqkv).weight", + "feed_forward.(w_in_shared_FD|w_swiglu_FD)", + "attn.(wk|wq|wv).weight", + "attn.(wk|wq|wv).bias", + "mlp.c_fc.weight", + "mlp.c_fc.bias", + "conv1._linear.weight", + "tok_embeddings.weight", + "vision_projection.weight", +} +_MOE_WEIGHT_COLUMN_KEY = {"feed_forward.experts.moe_w_out_eF_D"} + + +def reshard_mp( + state_dicts: List[Dict[str, torch.Tensor]], + size: int, + rank: int, + repeat_qk_qv: int = 1, +) -> Dict[str, torch.Tensor]: + """ + Reshard a list of state dicts into a single state dict given a change in MP size. + If the list has more than one state dict, we concatenate the values of the same + key across all state dicts. Otherwise, we just slice it for the current MP rank. + """ + + def concat_or_chunk(tensors: List[torch.Tensor], dim: int) -> torch.Tensor: + if len(tensors) > 1: + return torch.cat(tensors, dim=dim) + return tensors[0].chunk(size, dim=dim)[rank].clone() + + def process_key(key: str) -> torch.Tensor: + if row_regex.search(key): + return concat_or_chunk([s[key] for s in state_dicts], dim=-1) + elif column_regex.search(key): + if "w13" in key or "fc1_weight" in key: + dims = state_dicts[0][key].size() + values = [s[key].view(2, dims[0] // 2, *dims[1:]) for s in state_dicts] + return concat_or_chunk(values, dim=1).flatten(0, 1) + elif "qkv" in key: + q_dim = state_dicts[0][key.replace("qkv", "o")].size(1) + kv_dim = (state_dicts[0][key].size(0) - q_dim) // 2 + values = [s[key].split((q_dim, kv_dim, kv_dim)) for s in state_dicts] + return torch.cat([concat_or_chunk(x, dim=0) for x in zip(*values, strict=False)]) # type: ignore + elif "wk.weight" in key or "wv.weight" in key: + # Support MP > #kv_head + return concat_or_chunk([s[key].repeat(repeat_qk_qv, 1) for s in state_dicts], dim=0) + elif key == "output.bias" or key == "fc.weight": + return concat_or_chunk([s[key] for s in state_dicts], dim=0) + elif "w_" in key: + return concat_or_chunk([s[key] for s in state_dicts], dim=-2) + else: + return concat_or_chunk([s[key] for s in state_dicts], dim=0) + else: + return state_dicts[0][key].clone() + + row_keys = _WEIGHT_ROW_KEY | _MOE_WEIGHT_ROW_KEY + column_keys = _WEIGHT_COLUMN_KEY | _MOE_WEIGHT_COLUMN_KEY + + column_regex = re.compile("|".join(column_keys)) + row_regex = re.compile("|".join(row_keys)) + + output: Dict[str, torch.Tensor] = {} + with concurrent.futures.ThreadPoolExecutor() as executor: + # Note: only processes keys in the first state dict. + # Assumes keys are the same across all state dicts. 
+ mappings = {executor.submit(process_key, key): key for key in state_dicts[0]} + for future in concurrent.futures.as_completed(mappings): + output[mappings[future]] = future.result() + return output + + +def convert_moe_weights(state_dict: Dict[str, Any], num_experts: int) -> Dict[str, Any]: + routed_keys = _MOE_WEIGHT_ROW_KEY | _MOE_WEIGHT_COLUMN_KEY + routed_regex = re.compile("|".join(routed_keys)) + keys = list(state_dict.keys()) + for key in keys: + if routed_regex.search(key): + state_dict[key] = state_dict.pop(key).unflatten(0, (num_experts, -1)).squeeze(dim=0) + return state_dict diff --git a/llama_stack/models/llama/datatypes.py b/llama_stack/models/llama/datatypes.py index ef791da8f..48cb51005 100644 --- a/llama_stack/models/llama/datatypes.py +++ b/llama_stack/models/llama/datatypes.py @@ -4,13 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. - import base64 from enum import Enum from io import BytesIO @@ -19,8 +12,6 @@ from typing import Any, Dict, List, Literal, Optional, Union from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator from typing_extensions import Annotated -from llama_stack.schema_utils import json_schema_type, register_schema - # The goal is that these set of types are relevant for all Llama models. # That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to # the llama3 series of models. @@ -98,6 +89,29 @@ class StopReason(Enum): out_of_tokens = "out_of_tokens" +class ToolParamDefinition(BaseModel): + param_type: str + description: Optional[str] = None + required: Optional[bool] = True + default: Optional[Any] = None + + +class ToolDefinition(BaseModel): + tool_name: Union[BuiltinTool, str] + description: Optional[str] = None + parameters: Optional[Dict[str, ToolParamDefinition]] = None + + @field_validator("tool_name", mode="before") + @classmethod + def validate_field(cls, v): + if isinstance(v, str): + try: + return BuiltinTool(v) + except ValueError: + return v + return v + + class RawMediaItem(BaseModel): type: Literal["image"] = "image" data: bytes | BytesIO @@ -140,292 +154,25 @@ class RawMessage(BaseModel): tool_calls: List[ToolCall] = Field(default_factory=list) -register_schema(ToolCall) +class GenerationResult(BaseModel): + token: int + text: str + logprobs: Optional[List[float]] = None + + source: Literal["input"] | Literal["output"] + + # index within the batch + batch_idx: int + # whether generation for this item is already finished. note that tokens can + # get returned even afterwards since other items in the batch can still be generating tokens + finished: bool + # because a batch is parallel processed, useful decoding for one item can correspond to processing + # pad tokens or tokens beyond EOS for other items. 
we could have decided to return None for this case + # but it's more convenient to return a list of GenerationResult and filter out the ignored tokens + ignore_token: bool -@json_schema_type -class ToolParamDefinition(BaseModel): - param_type: str - description: Optional[str] = None - required: Optional[bool] = True - default: Optional[Any] = None - - -@json_schema_type -class ToolDefinition(BaseModel): - tool_name: Union[BuiltinTool, str] - description: Optional[str] = None - parameters: Optional[Dict[str, ToolParamDefinition]] = None - - @field_validator("tool_name", mode="before") - @classmethod - def validate_field(cls, v): - if isinstance(v, str): - try: - return BuiltinTool(v) - except ValueError: - return v - return v - - -@json_schema_type -class GreedySamplingStrategy(BaseModel): - type: Literal["greedy"] = "greedy" - - -@json_schema_type -class TopPSamplingStrategy(BaseModel): - type: Literal["top_p"] = "top_p" - temperature: Optional[float] = Field(..., gt=0.0) - top_p: Optional[float] = 0.95 - - -@json_schema_type -class TopKSamplingStrategy(BaseModel): - type: Literal["top_k"] = "top_k" - top_k: int = Field(..., ge=1) - - -SamplingStrategy = Annotated[ - Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy], - Field(discriminator="type"), -] -register_schema(SamplingStrategy, name="SamplingStrategy") - - -@json_schema_type -class SamplingParams(BaseModel): - """Sampling parameters. - - :param strategy: The sampling strategy. - :param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of - your prompt plus max_tokens cannot exceed the model's context length. - :param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens - based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. - :param stop: Up to 4 sequences where the API will stop generating further tokens. - The returned text will not contain the stop sequence. - """ - - strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy) - - max_tokens: Optional[int] = 0 - repetition_penalty: Optional[float] = 1.0 - stop: Optional[List[str]] = None - - -class CheckpointQuantizationFormat(Enum): - # default format - bf16 = "bf16" - - # used for enabling fp8_rowwise inference, some weights are bf16 - fp8_mixed = "fp8-mixed" - - int8 = "int8" - - int4 = "int4" - - -class ModelFamily(Enum): - llama2 = "llama2" - llama3 = "llama3" - llama3_1 = "llama3_1" - llama3_2 = "llama3_2" - llama3_3 = "llama3_3" - llama4 = "llama4" - safety = "safety" - - -class CoreModelId(Enum): - """Each of these models is a unique "SKU". 
These root models can be served in various garbs (especially by quantizing them)""" - - # Llama 2 family - llama2_7b = "Llama-2-7b" - llama2_13b = "Llama-2-13b" - llama2_70b = "Llama-2-70b" - llama2_7b_chat = "Llama-2-7b-chat" - llama2_13b_chat = "Llama-2-13b-chat" - llama2_70b_chat = "Llama-2-70b-chat" - - # Llama 3 family - llama3_8b = "Llama-3-8B" - llama3_70b = "Llama-3-70B" - llama3_8b_instruct = "Llama-3-8B-Instruct" - llama3_70b_instruct = "Llama-3-70B-Instruct" - - # Llama 3.1 family - llama3_1_8b = "Llama3.1-8B" - llama3_1_70b = "Llama3.1-70B" - llama3_1_405b = "Llama3.1-405B" - llama3_1_8b_instruct = "Llama3.1-8B-Instruct" - llama3_1_70b_instruct = "Llama3.1-70B-Instruct" - llama3_1_405b_instruct = "Llama3.1-405B-Instruct" - - # Llama 3.2 family - llama3_2_1b = "Llama3.2-1B" - llama3_2_3b = "Llama3.2-3B" - llama3_2_1b_instruct = "Llama3.2-1B-Instruct" - llama3_2_3b_instruct = "Llama3.2-3B-Instruct" - llama3_2_11b_vision = "Llama3.2-11B-Vision" - llama3_2_90b_vision = "Llama3.2-90B-Vision" - llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct" - llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct" - - # Llama 3.3 family - llama3_3_70b_instruct = "Llama3.3-70B-Instruct" - - # Llama 4 family - llama4_scout_17b_16e = "Llama-4-Scout-17B-16E" - llama4_scout_17b_16e_instruct = "Llama-4-Scout-17B-16E-Instruct" - llama4_maverick_17b_128e = "Llama-4-Maverick-17B-128E" - llama4_maverick_17b_128e_instruct = "Llama-4-Maverick-17B-128E-Instruct" - - # Safety models - llama_guard_3_8b = "Llama-Guard-3-8B" - llama_guard_2_8b = "Llama-Guard-2-8B" - llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision" - llama_guard_3_1b = "Llama-Guard-3-1B" - - -def is_multimodal(model_id) -> bool: - if model_id in [ - CoreModelId.llama3_2_11b_vision, - CoreModelId.llama3_2_90b_vision, - CoreModelId.llama3_2_11b_vision_instruct, - CoreModelId.llama3_2_90b_vision_instruct, - ]: - return True - else: - return False - - -def model_family(model_id) -> ModelFamily: - if model_id in [ - CoreModelId.llama2_7b, - CoreModelId.llama2_13b, - CoreModelId.llama2_70b, - CoreModelId.llama2_7b_chat, - CoreModelId.llama2_13b_chat, - CoreModelId.llama2_70b_chat, - ]: - return ModelFamily.llama2 - elif model_id in [ - CoreModelId.llama3_8b, - CoreModelId.llama3_70b, - CoreModelId.llama3_8b_instruct, - CoreModelId.llama3_70b_instruct, - ]: - return ModelFamily.llama3 - elif model_id in [ - CoreModelId.llama3_1_8b, - CoreModelId.llama3_1_70b, - CoreModelId.llama3_1_405b, - CoreModelId.llama3_1_8b_instruct, - CoreModelId.llama3_1_70b_instruct, - CoreModelId.llama3_1_405b_instruct, - ]: - return ModelFamily.llama3_1 - elif model_id in [ - CoreModelId.llama3_2_1b, - CoreModelId.llama3_2_3b, - CoreModelId.llama3_2_1b_instruct, - CoreModelId.llama3_2_3b_instruct, - CoreModelId.llama3_2_11b_vision, - CoreModelId.llama3_2_90b_vision, - CoreModelId.llama3_2_11b_vision_instruct, - CoreModelId.llama3_2_90b_vision_instruct, - ]: - return ModelFamily.llama3_2 - elif model_id in [ - CoreModelId.llama3_3_70b_instruct, - ]: - return ModelFamily.llama3_3 - elif model_id in [ - CoreModelId.llama4_scout_17b_16e, - CoreModelId.llama4_scout_17b_16e_instruct, - CoreModelId.llama4_maverick_17b_128e, - CoreModelId.llama4_maverick_17b_128e_instruct, - ]: - return ModelFamily.llama4 - elif model_id in [ - CoreModelId.llama_guard_3_8b, - CoreModelId.llama_guard_2_8b, - CoreModelId.llama_guard_3_11b_vision, - CoreModelId.llama_guard_3_1b, - ]: - return ModelFamily.safety - else: - raise ValueError(f"Unknown model family for {model_id}") - - 
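These SKU definitions are not deleted outright: per the diffstat they move into the new `llama_stack/models/llama/sku_types.py`, and call sites such as `cli/model/prompt_format.py` above switch to that import path. A minimal usage sketch, assuming the moved definitions keep these names:

```python
# Sketch only: SKU metadata now comes from sku_types rather than datatypes.
from llama_stack.models.llama.sku_types import CoreModelId, ModelFamily, is_multimodal, model_family

model_id = CoreModelId.llama3_2_11b_vision_instruct
assert model_family(model_id) == ModelFamily.llama3_2
assert is_multimodal(model_id)
```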
-class Model(BaseModel): - core_model_id: CoreModelId - description: str - huggingface_repo: Optional[str] = None - recommended_sampling_params: Optional[SamplingParams] = None - arch_args: Dict[str, Any] - variant: str = "" - - quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16 - pth_file_count: int - metadata: Optional[Dict[str, Any]] = Field(default_factory=dict) - - # silence pydantic until we remove the `model_` fields - model_config = ConfigDict(protected_namespaces=()) - - @property - def model_family(self) -> ModelFamily: - return model_family(self.core_model_id) - - # The SKU is uniquely identified by (model_id, variant) combo - def descriptor(self, shorten_default_variant: bool = True) -> str: - if not self.variant: - return self.core_model_id.value - return f"{self.core_model_id.value}:{self.variant}" - - @property - def is_instruct_model(self) -> bool: - return "instruct" in self.id.name - - # Featured models are shown in the non-exhaustive model list - @property - def is_featured(self) -> bool: - return self.model_family in [ - ModelFamily.llama3_1, - ModelFamily.llama3_2, - ModelFamily.llama3_3, - ModelFamily.llama4, - ModelFamily.safety, - ] - - @property - def max_seq_length(self) -> int: - if self.model_family == ModelFamily.llama2: - return 4096 - elif self.core_model_id == CoreModelId.llama_guard_2_8b: - return 4096 - elif self.model_family == ModelFamily.llama3: - return 8192 - elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]: - return 131072 - elif self.model_family == ModelFamily.llama3_2: - if self.quantization_format == CheckpointQuantizationFormat.int4: - return 8192 - return 131072 - elif self.model_family == ModelFamily.llama4: - if self.core_model_id in { - CoreModelId.llama4_scout_17b_16e, - CoreModelId.llama4_maverick_17b_128e, - }: - return 262144 - if self.core_model_id == CoreModelId.llama4_scout_17b_16e_instruct: - return 10485760 - if self.core_model_id == CoreModelId.llama4_maverick_17b_128e_instruct: - return 1048576 - elif self.core_model_id in [ - CoreModelId.llama_guard_3_8b, - CoreModelId.llama_guard_3_11b_vision, - CoreModelId.llama_guard_3_1b, - ]: - return 131072 - else: - raise ValueError(f"Unknown max_seq_len for {self.core_model_id}") +class QuantizationMode(str, Enum): + none = "none" + fp8_mixed = "fp8_mixed" + int4_mixed = "int4_mixed" diff --git a/llama_stack/providers/inline/inference/meta_reference/hadamard_utils.py b/llama_stack/models/llama/hadamard_utils.py similarity index 100% rename from llama_stack/providers/inline/inference/meta_reference/hadamard_utils.py rename to llama_stack/models/llama/hadamard_utils.py diff --git a/llama_stack/providers/inline/inference/meta_reference/llama3/args.py b/llama_stack/models/llama/llama3/args.py similarity index 88% rename from llama_stack/providers/inline/inference/meta_reference/llama3/args.py rename to llama_stack/models/llama/llama3/args.py index e96eaca61..f7e4b4557 100644 --- a/llama_stack/providers/inline/inference/meta_reference/llama3/args.py +++ b/llama_stack/models/llama/llama3/args.py @@ -4,13 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. 
- from dataclasses import dataclass from enum import Enum from typing import Optional diff --git a/llama_stack/models/llama/llama3/chat_format.py b/llama_stack/models/llama/llama3/chat_format.py index 2862f8558..f55cd5e1c 100644 --- a/llama_stack/models/llama/llama3/chat_format.py +++ b/llama_stack/models/llama/llama3/chat_format.py @@ -4,13 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. - import io import json import uuid @@ -19,7 +12,7 @@ from typing import Dict, List, Optional, Tuple from PIL import Image as PIL_Image -from llama_stack.models.llama.datatypes import ( +from ..datatypes import ( BuiltinTool, RawContent, RawMediaItem, @@ -30,7 +23,6 @@ from llama_stack.models.llama.datatypes import ( ToolCall, ToolPromptFormat, ) - from .tokenizer import Tokenizer from .tool_utils import ToolUtils diff --git a/llama_stack/models/llama/llama3/generation.py b/llama_stack/models/llama/llama3/generation.py new file mode 100644 index 000000000..ee99a07ba --- /dev/null +++ b/llama_stack/models/llama/llama3/generation.py @@ -0,0 +1,367 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. 
+ +import json +import os +import sys +import time +from pathlib import Path +from typing import Callable, Generator, List, Optional + +import torch +import torch.nn.functional as F +from fairscale.nn.model_parallel.initialize import ( + initialize_model_parallel, + model_parallel_is_initialized, +) +from termcolor import cprint + +from ..checkpoint import maybe_reshard_state_dict +from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat +from .args import ModelArgs +from .chat_format import ChatFormat, LLMInput +from .model import Transformer +from .multimodal.model import CrossAttentionTransformer +from .tokenizer import Tokenizer + + +class Llama3: + @staticmethod + def build( + ckpt_dir: str, + max_seq_len: int, + max_batch_size: int, + world_size: Optional[int] = None, + quantization_mode: Optional[QuantizationMode] = None, + seed: int = 1, + device: str = "cuda", + ): + device = torch.device(device) + if ( + device.type == "cuda" + and not torch.cuda.is_available() + or device.type == "xpu" + and not torch.xpu.is_available() + ): + raise RuntimeError(f"PyTorch backend for {device.type} device type is not available") + + if not torch.distributed.is_initialized(): + if device.type == "cuda": + torch.distributed.init_process_group("nccl") + else: + torch.distributed.init_process_group("gloo") + + if not model_parallel_is_initialized(): + if world_size is None: + world_size = int(os.environ.get("WORLD_SIZE", 1)) + initialize_model_parallel(world_size) + + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + if device.type == "cuda": + torch.cuda.set_device(local_rank) + elif device.type == "xpu": + torch.xpu.set_device(local_rank) + + torch.manual_seed(seed) + + if local_rank > 0: + sys.stdout = open(os.devnull, "w") + + start_time = time.time() + + ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth")) + assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}" + print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})") + with open(Path(ckpt_dir) / "params.json", "r") as f: + params = json.loads(f.read()) + + model_args: ModelArgs = ModelArgs( + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + **params, + ) + tokenizer = Tokenizer.get_instance() + + state_dict = maybe_reshard_state_dict( + ckpt_paths, + n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads, + ) + + assert model_args.vocab_size == tokenizer.n_words + + def build_model(): + if model_args.vision_chunk_size > 0: + model = CrossAttentionTransformer(model_args) + model.setup_cache(model_args.max_batch_size, device=device, dtype=torch.get_default_dtype()) + else: + model = Transformer(model_args) + return model + + if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed: + from .quantization.loader import convert_to_quantized_model + + torch.set_default_tensor_type(torch.BFloat16Tensor) + model = build_model() + print("Loading state dict...") + model.load_state_dict(state_dict, strict=False) + print("Done...") + model = convert_to_quantized_model(model, ckpt_dir, quantization_mode, device=device) + torch.set_default_device(device) + else: + print(f"Setting default device to {device}") + torch.set_default_device(device) + if device.type == "cuda": + if torch.cuda.is_bf16_supported(): + torch.set_default_dtype(torch.bfloat16) + else: + torch.set_default_dtype(torch.half) + elif device.type == "xpu": + if torch.xpu.is_bf16_supported(): + 
torch.set_default_dtype(torch.bfloat16) + else: + torch.set_default_dtype(torch.half) + + model = build_model() + print("Loading state dict...") + model.load_state_dict(state_dict, strict=True) + model.to(device) + print("Done...") + + print(f"Loaded in {time.time() - start_time:.2f} seconds") + + return Llama3(model, tokenizer, model_args) + + def __init__(self, model: Transformer | CrossAttentionTransformer, tokenizer: Tokenizer, args: ModelArgs): + self.args = args + self.model = model + self.tokenizer = tokenizer + self.formatter = ChatFormat(tokenizer) + + @torch.inference_mode() + def generate( + self, + model_inputs: List[LLMInput], + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len: Optional[int] = None, + logprobs: bool = False, + echo: bool = False, + print_model_input: bool = False, + logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + ) -> Generator[List[GenerationResult], None, None]: + if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len: + max_gen_len = self.args.max_seq_len - 1 + params = self.model.params + + print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1" + if print_model_input: + for inp in model_inputs: + tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens] + cprint( + "Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n", + "red", + ) + prompt_tokens = [inp.tokens for inp in model_inputs] + + bsz = len(model_inputs) + assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) + + min_prompt_len = min(len(t) for t in prompt_tokens) + max_prompt_len = max(len(t) for t in prompt_tokens) + + if max_prompt_len >= params.max_seq_len: + cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red") + return + + total_len = min(max_gen_len + max_prompt_len, params.max_seq_len) + + pad_id = self.tokenizer.pad_id + tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long) + for k, t in enumerate(prompt_tokens): + tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long) + if logprobs: + token_logprobs = torch.zeros_like(tokens, dtype=torch.float) + + is_vision = not isinstance(self.model, Transformer) + if is_vision: + images = [inp.vision.images if inp.vision is not None else [] for inp in model_inputs] + mask = [inp.vision.mask if inp.vision is not None else [] for inp in model_inputs] + + xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks( + batch_images=images, + batch_masks=mask, + total_len=total_len, + device=tokens.device, + ) + + eos_reached = torch.tensor([False] * bsz) + input_text_mask = tokens != pad_id + + if echo: + for i in range(max_prompt_len): + results = [] + for j, t in enumerate(tokens[:, i]): + results.append( + GenerationResult( + token=t.item(), + text=self.tokenizer.decode([t.item()]), + source="input", + logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None), + batch_idx=j, + finished=False, + ignore_token=t.item() == pad_id, + ) + ) + yield results + + stop_tokens = torch.tensor(self.tokenizer.stop_tokens) + + prev_pos = 0 + for cur_pos in range(min_prompt_len, total_len): + if is_vision: + position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long) + text_only_inference = all(inp.vision is None for inp in model_inputs) + logits = self.model.forward( + position_ids, + tokens, + cross_attention_masks, + full_text_row_masked_out_mask, + xattn_caches, + text_only_inference, + ) + 
else: + logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos) + + if logits_processor is not None: + logits = logits_processor(tokens[:, :cur_pos], logits) + + if temperature > 0: + probs = torch.softmax(logits[:, -1] / temperature, dim=-1) + next_token = sample_top_p(probs, top_p) + else: + next_token = torch.argmax(logits[:, -1], dim=-1) + + next_token = next_token.reshape(-1) + # only replace token if prompt has already been generated + next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token) + tokens[:, cur_pos] = next_token + + target = tokens[:, prev_pos + 1 : cur_pos + 1] + if is_vision: + # the logits space (num_classes) is designed to never contain a media_token + # however our input token stream does contain them. we need to nuke them here + # or else the CUDA kernels will crash with an illegal memory access + vision_tokens = [self.tokenizer.special_tokens["<|image|>"], 128256] + masks = [target.eq(t) for t in vision_tokens] + if len(masks) > 1: + mask = torch.logical_or(*masks) + else: + mask = masks[0] + target[mask] = 0 + + if logprobs: + token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy( + input=logits.transpose(1, 2), + target=target, + reduction="none", + ignore_index=pad_id, + ) + eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens)) + results = [] + for idx, t in enumerate(next_token): + results.append( + GenerationResult( + token=t.item(), + text=self.tokenizer.decode([t.item()]), + source="output", + logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None), + batch_idx=idx, + finished=eos_reached[idx], + ignore_token=cur_pos < len(prompt_tokens[idx]), + ) + ) + yield results + + prev_pos = cur_pos + if all(eos_reached): + break + + def completion( + self, + contents: List[RawContent], + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len: Optional[int] = None, + logprobs: bool = False, + echo: bool = False, + ) -> Generator[List[GenerationResult], None, None]: + model_inputs = [self.formatter.encode_content(c) for c in contents] + for result in self.generate( + model_inputs=model_inputs, + temperature=temperature, + top_p=top_p, + max_gen_len=max_gen_len, + logprobs=logprobs, + echo=echo, + ): + yield result + if all(r.finished for r in result): + break + + def chat_completion( + self, + messages_batch: List[List[RawMessage]], + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len: Optional[int] = None, + logprobs: bool = False, + tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json, + echo: bool = False, + ) -> Generator[List[GenerationResult], None, None]: + model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch] + for result in self.generate( + model_inputs=model_inputs, + temperature=temperature, + top_p=top_p, + max_gen_len=max_gen_len, + logprobs=logprobs, + echo=echo, + ): + yield result + if all(r.finished for r in result): + break + + +def sample_top_p(probs, p): + """ + Perform top-p (nucleus) sampling on a probability distribution. + + Args: + probs (torch.Tensor): Probability distribution tensor. + p (float): Probability threshold for top-p sampling. + + Returns: + torch.Tensor: Sampled token indices. + + Note: + Top-p sampling selects the smallest set of tokens whose cumulative probability mass + exceeds the threshold p. The distribution is renormalized based on the selected tokens. 
+ """ + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort[mask] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token diff --git a/llama_stack/models/llama/llama3/interface.py b/llama_stack/models/llama/llama3/interface.py index 2579ab6c8..8684237df 100644 --- a/llama_stack/models/llama/llama3/interface.py +++ b/llama_stack/models/llama/llama3/interface.py @@ -16,7 +16,7 @@ from typing import List, Optional from termcolor import colored -from llama_stack.models.llama.datatypes import ( +from ..datatypes import ( BuiltinTool, RawMessage, StopReason, @@ -24,7 +24,6 @@ from llama_stack.models.llama.datatypes import ( ToolDefinition, ToolPromptFormat, ) - from . import template_data from .chat_format import ChatFormat from .prompt_templates import ( diff --git a/llama_stack/providers/inline/inference/meta_reference/llama3/model.py b/llama_stack/models/llama/llama3/model.py similarity index 94% rename from llama_stack/providers/inline/inference/meta_reference/llama3/model.py rename to llama_stack/models/llama/llama3/model.py index a49167980..2562673e2 100644 --- a/llama_stack/providers/inline/inference/meta_reference/llama3/model.py +++ b/llama_stack/models/llama/llama3/model.py @@ -4,16 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. - import math from typing import Optional, Tuple @@ -29,6 +19,10 @@ from torch import nn from .args import ModelArgs +# **NOTE**: This code is not runnable without installing `torch` and `fairscale` +# dependencies. These dependencies are not part of the default dependencies +# (requirements.txt) of the `llama-models` package. 
+ class RMSNorm(torch.nn.Module): def __init__(self, dim: int, eps: float = 1e-6): @@ -111,9 +105,9 @@ class Attention(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads - model_parallel_size = fs_init.get_model_parallel_world_size() - self.n_local_heads = args.n_heads // model_parallel_size - self.n_local_kv_heads = self.n_kv_heads // model_parallel_size + world_size = fs_init.get_model_parallel_world_size() + self.n_local_heads = args.n_heads // world_size + self.n_local_kv_heads = self.n_kv_heads // world_size self.n_rep = self.n_local_heads // self.n_local_kv_heads self.head_dim = args.dim // args.n_heads diff --git a/llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/__init__.py b/llama_stack/models/llama/llama3/multimodal/__init__.py similarity index 100% rename from llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/__init__.py rename to llama_stack/models/llama/llama3/multimodal/__init__.py diff --git a/llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/encoder_utils.py b/llama_stack/models/llama/llama3/multimodal/encoder_utils.py similarity index 100% rename from llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/encoder_utils.py rename to llama_stack/models/llama/llama3/multimodal/encoder_utils.py diff --git a/llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/image_transform.py b/llama_stack/models/llama/llama3/multimodal/image_transform.py similarity index 100% rename from llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/image_transform.py rename to llama_stack/models/llama/llama3/multimodal/image_transform.py diff --git a/llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/model.py b/llama_stack/models/llama/llama3/multimodal/model.py similarity index 95% rename from llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/model.py rename to llama_stack/models/llama/llama3/multimodal/model.py index 3d0d77c87..0cb18b948 100644 --- a/llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/model.py +++ b/llama_stack/models/llama/llama3/multimodal/model.py @@ -4,16 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
- import logging import math from functools import partial @@ -180,14 +170,14 @@ class ImageAttention(nn.Module): n_heads, ): super().__init__() - model_parallel_size = fs_init.get_model_parallel_world_size() + world_size = fs_init.get_model_parallel_world_size() qkvo_replication = 1 - if model_parallel_size > 16: - qkvo_replication = model_parallel_size // 8 + if world_size > 16: + qkvo_replication = world_size // 8 self.n_kv_heads = n_heads - self.n_local_heads = n_heads * qkvo_replication // model_parallel_size - self.n_local_kv_heads = self.n_kv_heads * qkvo_replication // model_parallel_size + self.n_local_heads = n_heads * qkvo_replication // world_size + self.n_local_kv_heads = self.n_kv_heads * qkvo_replication // world_size self.n_rep = self.n_local_heads // self.n_local_kv_heads self.head_dim = dim // n_heads @@ -536,16 +526,16 @@ class Attention(nn.Module): cache_v (torch.Tensor): Cached values for attention. """ super().__init__() - model_parallel_size = fs_init.get_model_parallel_world_size() + world_size = fs_init.get_model_parallel_world_size() replication_factor = 1 - if model_parallel_size > 8: - replication_factor = model_parallel_size // MP_SCALE + if world_size > 8: + replication_factor = world_size // MP_SCALE self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads self.n_kv_heads *= replication_factor - self.n_local_heads = args.n_heads // model_parallel_size - self.n_local_kv_heads = self.n_kv_heads // model_parallel_size + self.n_local_heads = args.n_heads // world_size + self.n_local_kv_heads = self.n_kv_heads // world_size self.n_rep = self.n_local_heads // self.n_local_kv_heads self.head_dim = args.dim // args.n_heads self.max_seq_len = args.max_seq_len @@ -587,13 +577,11 @@ class Attention(nn.Module): self.n_local_kv_heads, self.head_dim, ) - device = next(self.parameters()).device self.register_buffer( "key_cache", torch.zeros( cache_shape, dtype=dtype, - device=device, ), persistent=False, ) @@ -602,7 +590,6 @@ class Attention(nn.Module): torch.zeros( cache_shape, dtype=dtype, - device=device, ), persistent=False, ) @@ -614,6 +601,9 @@ class Attention(nn.Module): freqs_cis: torch.Tensor, position_ids: torch.LongTensor, ): + self.key_cache = self.key_cache.to(x.device) + self.value_cache = self.value_cache.to(x.device) + xq, xk, xv = [F.linear(x, w) for w in [self.wq.weight, self.wk.weight, self.wv.weight]] bs, slen, _ = xq.shape @@ -832,10 +822,10 @@ class CrossAttention(torch.nn.Module): norm_eps: float, ): super().__init__() - self.model_parallel_size = fs_init.get_model_parallel_world_size() + self.world_size = fs_init.get_model_parallel_world_size() replication_factor = 1 - if self.model_parallel_size > 8: - replication_factor = self.model_parallel_size // MP_SCALE + if self.world_size > 8: + replication_factor = self.world_size // MP_SCALE n_kv_heads *= replication_factor assert n_heads % n_kv_heads == 0 @@ -889,10 +879,10 @@ class CrossAttention(torch.nn.Module): # trunk LLM (i.e., group query attention) -- @dubeya # local heads assert self.n_heads % self.n_kv_heads == 0 - assert self.n_heads % self.model_parallel_size == 0 - assert self.n_kv_heads % self.model_parallel_size == 0 - self.n_local_heads = self.n_heads // self.model_parallel_size - self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size + assert self.n_heads % self.world_size == 0 + assert self.n_kv_heads % self.world_size == 0 + self.n_local_heads = self.n_heads // self.world_size + self.n_local_kv_heads = self.n_kv_heads // self.world_size self.n_rep = 
self.n_local_heads // self.n_local_kv_heads def _compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor: @@ -1041,7 +1031,7 @@ class CrossAttentionTransformerVision(torch.nn.Module): self.image_res = args.vision_chunk_size self.max_num_chunks = args.vision_max_num_chunks if return_intermediate is not None: - return_intermediate = [int(level) for level in return_intermediate.split(",")] + return_intermediate = [int(layer) for layer in return_intermediate.split(",")] self.vision_input_dim = (len(return_intermediate) + 1) * self.vision_input_dim self.patch_size = 14 self.vision_encoder = VisionEncoder( @@ -1076,15 +1066,15 @@ class CrossAttentionTransformerText(torch.nn.Module): def __init__(self, args: ModelArgs) -> None: super().__init__() - self.model_parallel_size = fs_init.get_model_parallel_world_size() + self.world_size = fs_init.get_model_parallel_world_size() assert args.vocab_size > 0 self.vocab_size = args.vocab_size self.n_layers = args.n_layers self.dim = args.dim self.head_dim = args.dim // args.n_heads self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads - self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size - assert self.vocab_size % self.model_parallel_size == 0 + self.n_local_kv_heads = self.n_kv_heads // self.world_size + assert self.vocab_size % self.world_size == 0 self.tok_embeddings = VocabParallelEmbedding(args.vocab_size, args.dim, init_method=lambda x: x) self.pos_embeddings = None # final norm layer (not necessary for post-norm) @@ -1184,6 +1174,8 @@ class CrossAttentionTransformerText(torch.nn.Module): text_only_inference: bool = False, ): assert self.cache_is_setup, "Please set up cache before calling forward" + self.mask_cache = self.mask_cache.to(h.device) + self.freqs_cis = self.freqs_cis.to(h.device) mask = self.mask_cache.index_select(2, position_ids) freqs_cis = self.freqs_cis.index_select(0, position_ids) @@ -1212,9 +1204,8 @@ class CrossAttentionTransformerText(torch.nn.Module): output = gather_from_tensor_model_parallel_region(output) return output.float() - def setup_cache(self, max_batch_size: int, dtype=torch.bfloat16): + def setup_cache(self, max_batch_size: int, device: torch.device, dtype=torch.bfloat16): # Set up the text kv caches - device = next(self.parameters()).device ones = torch.ones( (self.max_seq_len, self.max_seq_len), dtype=torch.bool, @@ -1265,7 +1256,7 @@ class CrossAttentionTransformerText(torch.nn.Module): return ( cross_attention_masks.to(device=text_device, dtype=text_dtype), - full_text_row_masked_out_mask, + full_text_row_masked_out_mask.to(device=text_device), ) @@ -1284,14 +1275,15 @@ class CrossAttentionTransformer(torch.nn.Module): max_num_chunks=args.vision_max_num_chunks, ) - def setup_cache(self, max_batch_size: int, dtype: torch.dtype): - self.text_model.setup_cache(max_batch_size, dtype) + def setup_cache(self, max_batch_size: int, device: torch.device, dtype: torch.dtype): + self.text_model.setup_cache(max_batch_size, device, dtype) def compute_vision_tokens_masks( self, batch_images: List[List[PIL_Image.Image]], batch_masks: List[List[List[int]]], total_len: int, + device: torch.device, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: skip_vision_encoder = False @@ -1318,6 +1310,7 @@ class CrossAttentionTransformer(torch.nn.Module): image_res=self.params.vision_chunk_size, max_num_images=max_num_images, ) + stacked_images = stacked_images.to(device=device) if skip_vision_encoder: vision_tokens = torch.zeros( @@ -1330,7 +1323,7 @@ class 
CrossAttentionTransformer(torch.nn.Module): ), ) else: - vision_tokens = self.vision_model(stacked_images, aspect_ratios) + vision_tokens = self.vision_model(stacked_images, aspect_ratios).to(device=device) bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape) xattn_caches = torch.stack( diff --git a/llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/utils.py b/llama_stack/models/llama/llama3/multimodal/utils.py similarity index 100% rename from llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/utils.py rename to llama_stack/models/llama/llama3/multimodal/utils.py diff --git a/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py b/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py index e03fcfc93..d4e825a22 100644 --- a/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py +++ b/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py @@ -15,7 +15,7 @@ import textwrap from datetime import datetime from typing import Any, List, Optional -from llama_stack.models.llama.datatypes import ( +from llama_stack.apis.inference import ( BuiltinTool, ToolDefinition, ToolParamDefinition, diff --git a/llama_stack/templates/meta-reference-quantized-gpu/__init__.py b/llama_stack/models/llama/llama3/quantization/__init__.py similarity index 74% rename from llama_stack/templates/meta-reference-quantized-gpu/__init__.py rename to llama_stack/models/llama/llama3/quantization/__init__.py index 1cfdb2c6a..756f351d8 100644 --- a/llama_stack/templates/meta-reference-quantized-gpu/__init__.py +++ b/llama_stack/models/llama/llama3/quantization/__init__.py @@ -3,5 +3,3 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. - -from .meta_reference import get_distribution_template # noqa: F401 diff --git a/llama_stack/providers/inline/inference/meta_reference/llama3/quantization/loader.py b/llama_stack/models/llama/llama3/quantization/loader.py similarity index 84% rename from llama_stack/providers/inline/inference/meta_reference/llama3/quantization/loader.py rename to llama_stack/models/llama/llama3/quantization/loader.py index 5109130b4..771fd02be 100644 --- a/llama_stack/providers/inline/inference/meta_reference/llama3/quantization/loader.py +++ b/llama_stack/models/llama/llama3/quantization/loader.py @@ -4,9 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. 
- # type: ignore import os from typing import Any, Dict, List, Optional, cast @@ -18,22 +15,15 @@ from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_regi from torch import Tensor, nn from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear -from llama_stack.apis.inference import QuantizationType -from llama_stack.log import get_logger -from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat -from llama_stack.models.llama.sku_list import resolve_model -from llama_stack.providers.inline.inference.meta_reference.quantize_impls import ( +from ...datatypes import QuantizationMode +from ...quantize_impls import ( Fp8ScaledWeights, ffn_swiglu, load_fp8, quantize_fp8, ) - -from ...config import MetaReferenceQuantizedInferenceConfig -from ..args import ModelArgs from ..model import Transformer, TransformerBlock - -log = get_logger(__name__, category="quantization") +from ..multimodal.model import CrossAttentionTransformer def swiglu_wrapper( @@ -44,30 +34,34 @@ def swiglu_wrapper( return reduce_from_model_parallel_region(out) +def convert_to_quantized_model( + model: Transformer | CrossAttentionTransformer, + checkpoint_dir: str, + quantization_mode: Optional[str] = None, + fp8_activation_scale_ub: Optional[float] = 1200.0, + device: Optional[torch.device] = None, +) -> Transformer | CrossAttentionTransformer: + if quantization_mode == QuantizationMode.fp8_mixed: + return convert_to_fp8_quantized_model(model, checkpoint_dir, fp8_activation_scale_ub, device) + elif quantization_mode == QuantizationMode.int4_mixed: + return convert_to_int4_quantized_model(model, checkpoint_dir, device) + else: + raise ValueError(f"Unsupported quantization mode: {quantization_mode}") + + def convert_to_fp8_quantized_model( model: Transformer, - config: MetaReferenceQuantizedInferenceConfig, checkpoint_dir: str, fp8_activation_scale_ub: Optional[float] = 1200.0, + device: Optional[torch.device] = None, ) -> Transformer: - if config.quantization.type == QuantizationType.bf16.value: - return model - - elif config.quantization.type != QuantizationType.fp8.value: - raise ValueError("Only FP8 quantization is supported") - - assert config.model is not None, "Model must be specified for quantized inference" - llama_model = resolve_model(config.model) - assert llama_model is not None, f"Model {config.model} not found" - # Move weights to GPU with quantization - if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value: - log.info("Loading fp8 scales...") - fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt") - assert os.path.isfile(fp8_scales_path), f"fp8_scales_path not found for rank {get_model_parallel_rank()}" + fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt") + if os.path.isfile(fp8_scales_path): + print("Loading fp8 scales...") fp8_scales = torch.load(fp8_scales_path, weights_only=True) - for block in model.layers: + for _, block in model.named_modules(): if isinstance(block, TransformerBlock): if block.layer_id == 0 or block.layer_id == (model.n_layers - 1): continue @@ -81,8 +75,8 @@ def convert_to_fp8_quantized_model( fp8_activation_scale_ub, ) else: - log.info("Quantizing fp8 weights from bf16...") - for block in model.layers: + print("Quantizing fp8 weights from bf16...") + for _, block in model.named_modules(): if isinstance(block, TransformerBlock): if block.layer_id == 0 or block.layer_id == (model.n_layers - 1): continue @@ -92,12 +86,12 @@ def 
convert_to_fp8_quantized_model( param.weight = quantize_fp8( param.weight, fp8_activation_scale_ub, - output_device=torch.device("cuda"), + output_device=device, ) for _, parameter in model.named_parameters(): if not isinstance(parameter, Fp8ScaledWeights): - parameter.data = parameter.to(device="cuda") + parameter.data = parameter.to(device=device) return model @@ -290,12 +284,12 @@ def _prepare_model_int4_weight_int8_dynamic_activation( def convert_to_int4_quantized_model( - model: Transformer, - model_args: ModelArgs, - config: MetaReferenceQuantizedInferenceConfig, -) -> Transformer: + model: Transformer | CrossAttentionTransformer, + checkpoint_dir: str, + device: Optional[torch.device] = None, +) -> Transformer | CrossAttentionTransformer: """Convert the model to int4 quantized model.""" - + model_args = model.params assert model_args.quantization_args is not None, "Quantization args must be specified." quantization_args = model_args.quantization_args if quantization_args.scheme is None: @@ -319,5 +313,4 @@ def convert_to_int4_quantized_model( lora_scale = model_args.lora_args.scale _prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale) - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - return cast(Transformer, model.to(device)) + return cast(Transformer | CrossAttentionTransformer, model.to(device=device)) diff --git a/llama_stack/models/llama/llama3/template_data.py b/llama_stack/models/llama/llama3/template_data.py index 076b4adb4..efca8397e 100644 --- a/llama_stack/models/llama/llama3/template_data.py +++ b/llama_stack/models/llama/llama3/template_data.py @@ -12,8 +12,7 @@ # the top-level of this source tree. -from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall - +from ..datatypes import BuiltinTool, StopReason, ToolCall from .prompt_templates import ( BuiltinToolGenerator, JsonCustomToolGenerator, diff --git a/llama_stack/models/llama/llama3/tokenizer.py b/llama_stack/models/llama/llama3/tokenizer.py index b240fa246..d3cc4fc07 100644 --- a/llama_stack/models/llama/llama3/tokenizer.py +++ b/llama_stack/models/llama/llama3/tokenizer.py @@ -4,16 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. 
- import os from logging import getLogger from pathlib import Path diff --git a/llama_stack/models/llama/llama3/tool_utils.py b/llama_stack/models/llama/llama3/tool_utils.py index 71018898c..fc8287eb6 100644 --- a/llama_stack/models/llama/llama3/tool_utils.py +++ b/llama_stack/models/llama/llama3/tool_utils.py @@ -16,7 +16,8 @@ import re from typing import Optional, Tuple from llama_stack.log import get_logger -from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat + +from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat logger = get_logger(name=__name__, category="inference") diff --git a/llama_stack/models/llama/llama3_2/__init__.py b/llama_stack/models/llama/llama3_2/__init__.py index 38ee47d66..756f351d8 100644 --- a/llama_stack/models/llama/llama3_2/__init__.py +++ b/llama_stack/models/llama/llama3_2/__init__.py @@ -3,10 +3,3 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. diff --git a/llama_stack/models/llama/llama3_2/prompts_text.py b/llama_stack/models/llama/llama3_2/prompts_text.py index 7bc7e3219..7a1f9887c 100644 --- a/llama_stack/models/llama/llama3_2/prompts_text.py +++ b/llama_stack/models/llama/llama3_2/prompts_text.py @@ -4,12 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. import json import textwrap diff --git a/llama_stack/models/llama/llama3_2/prompts_vision.py b/llama_stack/models/llama/llama3_2/prompts_vision.py index b1ede418b..b0f11cab6 100644 --- a/llama_stack/models/llama/llama3_2/prompts_vision.py +++ b/llama_stack/models/llama/llama3_2/prompts_vision.py @@ -4,13 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. - import textwrap from pathlib import Path diff --git a/llama_stack/providers/inline/inference/meta_reference/llama4/args.py b/llama_stack/models/llama/llama4/args.py similarity index 91% rename from llama_stack/providers/inline/inference/meta_reference/llama4/args.py rename to llama_stack/models/llama/llama4/args.py index 046448ef6..6d7c1d409 100644 --- a/llama_stack/providers/inline/inference/meta_reference/llama4/args.py +++ b/llama_stack/models/llama/llama4/args.py @@ -4,13 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. - from enum import Enum from typing import Optional diff --git a/llama_stack/models/llama/llama4/chat_format.py b/llama_stack/models/llama/llama4/chat_format.py index c873012d6..160bb00f8 100644 --- a/llama_stack/models/llama/llama4/chat_format.py +++ b/llama_stack/models/llama/llama4/chat_format.py @@ -13,7 +13,7 @@ import torch from PIL import Image as PIL_Image # TODO: either fork these or move them to the common package -from llama_stack.models.llama.datatypes import ( +from ..datatypes import ( BuiltinTool, RawContent, RawMediaItem, @@ -24,16 +24,10 @@ from llama_stack.models.llama.datatypes import ( ToolCall, ToolPromptFormat, ) -from llama_stack.models.llama.llama3.tool_utils import ToolUtils -from llama_stack.providers.inline.inference.meta_reference.llama4.args import VisionArgs -from llama_stack.providers.inline.inference.meta_reference.llama4.datatypes import ( - LLMInput, -) -from llama_stack.providers.inline.inference.meta_reference.llama4.preprocess import ( - ResizeNormalizeImageTransform, - VariableSizeImageTransform, -) - +from ..llama3.tool_utils import ToolUtils +from .args import VisionArgs +from .datatypes import LLMInput +from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform from .tokenizer import Tokenizer @@ -54,7 +48,7 @@ class TransformedImage: aspect_ratio: Tuple[int, int] -def convert_rgba_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image: +def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image: if image.mode == "RGBA": image.load() # for png.split() new_img = PIL_Image.new("RGB", image.size, bg) @@ -171,7 +165,7 @@ class ChatFormat: bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data image = PIL_Image.open(bytes_io) - image = convert_rgba_to_rgb(image) + image = convert_image_to_rgb(image) image_tiles, ar = self.dynamic_image_transform(image, max_num_chunks=self.max_num_chunks) if image_tiles.shape[0] > 1: diff --git a/llama_stack/providers/inline/inference/meta_reference/llama4/datatypes.py b/llama_stack/models/llama/llama4/datatypes.py similarity index 85% rename from llama_stack/providers/inline/inference/meta_reference/llama4/datatypes.py rename to llama_stack/models/llama/llama4/datatypes.py index bb1c19a12..27174db63 100644 --- a/llama_stack/providers/inline/inference/meta_reference/llama4/datatypes.py +++ b/llama_stack/models/llama/llama4/datatypes.py @@ -4,13 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. 
- from dataclasses import dataclass from typing import List, Optional, Union diff --git a/llama_stack/providers/inline/inference/meta_reference/llama4/ffn.py b/llama_stack/models/llama/llama4/ffn.py similarity index 100% rename from llama_stack/providers/inline/inference/meta_reference/llama4/ffn.py rename to llama_stack/models/llama/llama4/ffn.py diff --git a/llama_stack/providers/inline/inference/meta_reference/llama4/generation.py b/llama_stack/models/llama/llama4/generation.py similarity index 72% rename from llama_stack/providers/inline/inference/meta_reference/llama4/generation.py rename to llama_stack/models/llama/llama4/generation.py index de900ce8d..7a4087c8f 100644 --- a/llama_stack/providers/inline/inference/meta_reference/llama4/generation.py +++ b/llama_stack/models/llama/llama4/generation.py @@ -10,40 +10,28 @@ import json import os import sys import time -from enum import Enum from pathlib import Path from typing import Callable, Generator, List, Optional import torch import torch.nn.functional as F from fairscale.nn.model_parallel.initialize import ( - get_model_parallel_rank, initialize_model_parallel, model_parallel_is_initialized, ) from termcolor import cprint -from llama_stack.models.llama.llama4.chat_format import ( - ChatFormat, - RawContent, - RawMessage, -) -from llama_stack.models.llama.llama4.tokenizer import Tokenizer - -from ..common import TokenResult +from ..checkpoint import maybe_reshard_state_dict +from ..datatypes import GenerationResult, QuantizationMode from .args import ModelArgs +from .chat_format import ChatFormat, RawContent, RawMessage from .datatypes import LLMInput, MaskedEmbedding, TransformerInput from .model import Transformer +from .tokenizer import Tokenizer torch.serialization.add_safe_globals([io.BytesIO, codecs.encode]) -class QuantizationMode(str, Enum): - none = "none" - fp8_mixed = "fp8_mixed" - int4_mixed = "int4_mixed" - - class Llama4: @staticmethod def build( @@ -51,7 +39,7 @@ class Llama4: max_seq_len: int, max_batch_size: int, world_size: Optional[int] = None, - quantization_mode: Optional[str] = None, + quantization_mode: Optional[QuantizationMode] = None, seed: int = 1, ): if not torch.distributed.is_initialized(): @@ -72,11 +60,9 @@ class Llama4: start_time = time.time() - checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) - assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" - assert world_size == len(checkpoints), ( - f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}" - ) + ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth")) + assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}" + print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})") with open(Path(ckpt_dir) / "params.json", "r") as f: params = json.loads(f.read()) @@ -93,10 +79,11 @@ class Llama4: assert model_args.vocab_size == tokenizer.n_words, f"{model_args.vocab_size=} vs. 
{tokenizer.n_words=} mismatch" print("Model args:\n", model_args.model_dump_json(indent=2)) - ckpt_path = checkpoints[get_model_parallel_rank()] - print(f"Loading checkpoint from {ckpt_dir}...") - with open(ckpt_path, "rb") as f: - checkpoint = torch.load(f, map_location="cpu", weights_only=True) + state_dict = maybe_reshard_state_dict( + ckpt_paths, + n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads, + moe_num_experts=model_args.moe_args.num_experts, + ) print("Loaded checkpoint") if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed: from .quantization.loader import convert_to_quantized_model @@ -104,9 +91,9 @@ class Llama4: torch.set_default_tensor_type(torch.BFloat16Tensor) model = Transformer(model_args) print("Loading state dict...") - model.load_state_dict(checkpoint, strict=False) + model.load_state_dict(state_dict, strict=False) print("Done...") - model = convert_to_quantized_model(model, ckpt_dir) + model = convert_to_quantized_model(model, ckpt_dir, quantization_mode) else: if torch.cuda.is_bf16_supported(): torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) @@ -115,7 +102,7 @@ class Llama4: model = Transformer(model_args) print("Loading state dict...") - model.load_state_dict(checkpoint, strict=False) + model.load_state_dict(state_dict, strict=False) print("Done...") print(f"Loaded in {time.time() - start_time:.2f} seconds") @@ -130,7 +117,7 @@ class Llama4: @torch.inference_mode() def generate( self, - llm_input: LLMInput, + llm_inputs: List[LLMInput], temperature: float = 0.6, top_p: float = 0.9, max_gen_len: Optional[int] = None, @@ -138,22 +125,20 @@ class Llama4: echo: bool = False, print_model_input: bool = False, logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, - ) -> Generator: + ) -> Generator[List[GenerationResult], None, None]: if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.args.max_seq_len: max_gen_len = self.model.args.max_seq_len - 1 params = self.model.args print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1" - if print_model_input and get_model_parallel_rank() == 0: - tokens_to_print = list(llm_input.tokens) - cprint( - "Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n", - "red", - ) - prompt_tokens = [llm_input.tokens] + if print_model_input: + cprint("Input to model:\n", "yellow") + for inp in llm_inputs: + cprint(self.tokenizer.decode(inp.tokens), "grey") + prompt_tokens = [inp.tokens for inp in llm_inputs] - bsz = 1 + bsz = len(llm_inputs) assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) min_prompt_len = min(len(t) for t in prompt_tokens) @@ -176,24 +161,33 @@ class Llama4: input_text_mask = tokens != pad_id if echo: - for i, t in enumerate(llm_input.tokens): - yield TokenResult( - token=t, - text=self.tokenizer.decode([t]), - logprobs=(token_logprobs[0, i : i + 1].tolist() if logprobs else None), - ) + for i in range(max_prompt_len): + results = [] + for j, t in enumerate(tokens[:, i]): + results.append( + GenerationResult( + token=t.item(), + text=self.tokenizer.decode([t.item()]), + source="input", + logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None), + batch_idx=j, + finished=False, + ignore_token=t.item() == pad_id, + ) + ) + yield results stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda") prev_pos = 0 for cur_pos in range(min_prompt_len, total_len): image_embedding = None - if prev_pos == 0 and 
llm_input.images is not None and len(llm_input.images) > 0: + if prev_pos == 0 and any(inp.images is not None and len(inp.images) > 0 for inp in llm_inputs): image_mask = tokens[:, prev_pos:cur_pos] == self.tokenizer.special_tokens["<|patch|>"] image_mask = image_mask.unsqueeze(-1) h = self.model.tok_embeddings(tokens[:, prev_pos:cur_pos]) - image_batch = [llm_input.images] + image_batch = [inp.images if inp.images is not None else [] for inp in llm_inputs] image_embedding = MaskedEmbedding( embedding=self.model.vision_embeddings(image_batch, image_mask, h), mask=image_mask, @@ -229,11 +223,21 @@ class Llama4: ignore_index=pad_id, ) eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens)) - yield TokenResult( - token=next_token[0].item(), - text=self.tokenizer.decode(next_token.tolist()), - logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None), - ) + + results = [] + for idx, t in enumerate(next_token): + results.append( + GenerationResult( + token=t.item(), + text=self.tokenizer.decode([t.item()]), + source="output", + logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None), + batch_idx=idx, + finished=eos_reached[idx], + ignore_token=cur_pos < len(prompt_tokens[idx]), + ) + ) + yield results prev_pos = cur_pos if all(eos_reached): @@ -241,68 +245,47 @@ class Llama4: def completion( self, - content: RawContent, + contents: List[RawContent], temperature: float = 0.6, top_p: float = 0.9, max_gen_len: Optional[int] = None, logprobs: bool = False, echo: bool = False, - ) -> Generator: - llm_input = self.formatter.encode_content(content) + ) -> Generator[List[GenerationResult], None, None]: + llm_inputs = [self.formatter.encode_content(c) for c in contents] for result in self.generate( - llm_input=llm_input, + llm_inputs=llm_inputs, temperature=temperature, top_p=top_p, max_gen_len=max_gen_len, logprobs=logprobs, echo=echo, ): - if result.token in self.tokenizer.stop_tokens: - break yield result + if all(r.finished for r in result): + break def chat_completion( self, - messages: List[RawMessage], + messages_batch: List[List[RawMessage]], temperature: float = 0.6, top_p: float = 0.9, max_gen_len: Optional[int] = None, logprobs: bool = False, echo: bool = False, - ) -> Generator: - llm_input = self.formatter.encode_dialog_prompt(messages) + ) -> Generator[List[GenerationResult], None, None]: + llm_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch] for result in self.generate( - llm_input=llm_input, + llm_inputs=llm_inputs, temperature=temperature, top_p=top_p, max_gen_len=max_gen_len, logprobs=logprobs, echo=echo, ): - if result.token in self.tokenizer.stop_tokens: - break yield result - - def chat_completion_raw( - self, - messages: List[RawMessage], - temperature: float = 0.6, - top_p: float = 0.9, - max_gen_len: Optional[int] = None, - logprobs: bool = False, - ): - llm_input = self.formatter.encode_dialog_prompt(messages) - output_tokens = [] - for result in self.generate( - llm_input=llm_input, - temperature=temperature, - top_p=top_p, - max_gen_len=max_gen_len, - logprobs=logprobs, - ): - output_tokens.append(result.token) - - return llm_input.tokens, output_tokens + if all(r.finished for r in result): + break def sample_top_p(probs, p): diff --git a/llama_stack/providers/inline/inference/meta_reference/llama4/model.py b/llama_stack/models/llama/llama4/model.py similarity index 97% rename from llama_stack/providers/inline/inference/meta_reference/llama4/model.py 
rename to llama_stack/models/llama/llama4/model.py index a35d6857f..08fac7714 100644 --- a/llama_stack/providers/inline/inference/meta_reference/llama4/model.py +++ b/llama_stack/models/llama/llama4/model.py @@ -4,16 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. - import math from typing import Any, Dict, List, Optional, Tuple @@ -184,7 +174,6 @@ class Attention(nn.Module): self.head_dim, ) ).cuda() - self.qk_norm = None if self.use_qk_norm: self.qk_norm = L2Norm(args.norm_eps) diff --git a/llama_stack/providers/inline/inference/meta_reference/llama4/moe.py b/llama_stack/models/llama/llama4/moe.py similarity index 87% rename from llama_stack/providers/inline/inference/meta_reference/llama4/moe.py rename to llama_stack/models/llama/llama4/moe.py index 8cecab7dd..2ce49e915 100644 --- a/llama_stack/providers/inline/inference/meta_reference/llama4/moe.py +++ b/llama_stack/models/llama/llama4/moe.py @@ -100,31 +100,21 @@ class Experts(nn.Module): class MoE(torch.nn.Module): """ - This EC implementation is modified from the original EC module. - We refactored the token permutation and unpermutation logic and added support to tp and dp2ep sharding. - This module supports 3 sharding methods of the experts: - - tp: each TP rank has n_experts experts. Experts are sharded following the conventional row/column-parallel TP sharding. - - tp2ep: each TP rank has n_experts/tp experts. Experts are not sharded. - - dp2ep: each EP rank has n_experts/ep experts. Experts are sharded following the row/column-parallel TP sharding. Tensors used in this module are annotated with the suffixes that indicate the shape of the tensor. 
Several commonly used annotations include: - a: bsz*slen - E: number of experts - e: number of local experts per ep (n_experts/ep) - - et: number of local experts per tp (n_experts/tp) - D: hidden dimension - d: D/tp - F: model dimension - - f: F/tp (used in column/row-parallel linear) - G: number of tokens per expert (a * capacity_factor / E) - g: number of tokens per expert per TP rank (i.e., G/TP) - - GG: G*EP (number of tokens per expert received via inter-EP a2a when ag_along_first_dim=False) - - gg: g*EP (number of tokens per expert received via inter-EP a2a when ag_along_first_dim=True) Examples: x_aD [a, D] routed_in_etG_D [et*G, D] - x_eGGD: [e, GG, D] + x_eGD: [e, G, D] """ def __init__( @@ -207,13 +197,13 @@ class MoE(torch.nn.Module): routed_in_EG_D = routed_in_EG_D * router_scores.reshape(-1, 1) out_aD = self.shared_expert(x_aD) - routed_out_egg_D = self.experts(routed_in_EG_D.detach()) + routed_out_eg_D = self.experts(routed_in_EG_D.detach()) router_indices_EG_D = router_indices.reshape(-1, 1).expand(-1, D) out_aD.scatter_add_( dim=0, index=router_indices_EG_D, - src=routed_out_egg_D.view(-1, D), + src=routed_out_eg_D.view(-1, D), ) out_aD = reduce_from_model_parallel_region(out_aD) return out_aD.view(-1, slen, D) diff --git a/llama_stack/providers/inline/inference/meta_reference/llama4/preprocess.py b/llama_stack/models/llama/llama4/preprocess.py similarity index 100% rename from llama_stack/providers/inline/inference/meta_reference/llama4/preprocess.py rename to llama_stack/models/llama/llama4/preprocess.py diff --git a/llama_stack/models/llama/llama4/prompts.py b/llama_stack/models/llama/llama4/prompts.py index 97f573ef8..13b96359a 100644 --- a/llama_stack/models/llama/llama4/prompts.py +++ b/llama_stack/models/llama/llama4/prompts.py @@ -4,20 +4,13 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. - import textwrap from io import BytesIO from pathlib import Path from typing import List -from llama_stack.models.llama.datatypes import RawMediaItem, RawMessage, RawTextItem -from llama_stack.models.llama.prompt_format import ( +from ..datatypes import RawMediaItem, RawMessage, RawTextItem +from ..prompt_format import ( Llama4UseCase, TextCompletionContent, UseCase, diff --git a/llama_stack/models/llama/llama4/quantization/__init__.py b/llama_stack/models/llama/llama4/quantization/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/models/llama/llama4/quantization/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
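Note: with the loaders relocated under `llama_stack/models/llama/*/quantization/`, the quantization entry point no longer takes a meta-reference config object; it accepts the model, the checkpoint directory, a `QuantizationMode`, and an explicit device. The sketch below shows how a caller might drive the relocated llama3 loader. It is a minimal illustration, not the provider's actual call site: the import paths follow the renames above, while `quantize_for_inference`, the checkpoint directory, and the device choice are placeholders, and fairscale model parallelism is assumed to be initialized already (as it is inside the meta-reference provider).

```python
import torch

from llama_stack.models.llama.datatypes import QuantizationMode
from llama_stack.models.llama.llama3.model import Transformer
from llama_stack.models.llama.llama3.quantization.loader import convert_to_quantized_model


def quantize_for_inference(model: Transformer, checkpoint_dir: str) -> Transformer:
    # Placeholder device selection; a real caller would pass whatever device it
    # resolved for inference.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # fp8_mixed loads precomputed fp8_scales_<rank>.pt files when present and
    # otherwise quantizes the bf16 FFN weights on the fly; int4_mixed takes the
    # int4-weight / int8-dynamic-activation path instead.
    return convert_to_quantized_model(
        model,
        checkpoint_dir,
        quantization_mode=QuantizationMode.fp8_mixed,
        device=device,
    )
```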
diff --git a/llama_stack/providers/inline/inference/meta_reference/llama4/quantization/loader.py b/llama_stack/models/llama/llama4/quantization/loader.py similarity index 70% rename from llama_stack/providers/inline/inference/meta_reference/llama4/quantization/loader.py rename to llama_stack/models/llama/llama4/quantization/loader.py index 69aa309fa..b50432896 100644 --- a/llama_stack/providers/inline/inference/meta_reference/llama4/quantization/loader.py +++ b/llama_stack/models/llama/llama4/quantization/loader.py @@ -6,20 +6,29 @@ import logging import os -from typing import Optional +from typing import Callable, Optional import torch from fairscale.nn.model_parallel.initialize import get_model_parallel_rank -from torch import Tensor +from torch import Tensor, nn from torch.nn import functional as F -from ..generation import QuantizationMode +from ...datatypes import QuantizationMode from ..model import Transformer, TransformerBlock from ..moe import MoE log = logging.getLogger(__name__) +def swiglu_wrapper_no_reduce( + self, + x: Tensor, +): + from ...quantize_impls import ffn_swiglu + + return ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight) + + def experts_batched_swiglu_wrapper( self, x: Tensor, # (e, g, D) @@ -51,24 +60,30 @@ def convert_to_quantized_model( rank = get_model_parallel_rank() + def should_quantize_block(block: nn.Module) -> bool: + if not isinstance(block, TransformerBlock): + return False + + is_moe = isinstance(block.feed_forward, MoE) + if quantization_mode == QuantizationMode.fp8_mixed: + # skip quantization on first and last layers + return is_moe and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1)) + + return is_moe + use_rich_progress = use_rich_progress and rank == 0 - progress, log_status, update_status = logging_callbacks(use_rich_progress, rank, model) + progress, log_status, update_status = logging_callbacks(use_rich_progress, rank, model, should_quantize_block) if quantization_mode == QuantizationMode.int4_mixed: int4_scales_path = os.path.join(checkpoint_dir, f"int4_scales_{rank}.pt") - int4_zero_points_path = os.path.join(checkpoint_dir, f"int4_zero_points_{rank}.pt") if os.path.isfile(int4_scales_path): log_status(f"Rank {rank}: Loading int4 scales") int4_scales = torch.load(int4_scales_path, weights_only=True) - int4_zero_points = torch.load(int4_zero_points_path, weights_only=True) def apply_quantization(key, weight): scale = int4_scales[key] - zero_point = int4_zero_points[key] return load_int4( weight, scale, - zero_point, - fp8_activation_scale_ub, output_device=torch.device("cuda"), ) @@ -77,6 +92,7 @@ def convert_to_quantized_model( def apply_quantization(_, weight): return quantize_int4(weight, fp8_activation_scale_ub, output_device=torch.device("cuda")) + else: fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt") if os.path.isfile(fp8_scales_path): @@ -104,33 +120,38 @@ def convert_to_quantized_model( progress.start() for _, block in model.named_modules(): - if isinstance(block, TransformerBlock): - # Skip quantization on first and last layers - if block.layer_id == 0 or block.layer_id == (model.n_layers - 1): - continue + if not should_quantize_block(block): + continue - # Skip quantization on dense layers - if not isinstance(block.feed_forward, MoE): - continue + update_status(f"Rank {rank} - Layer {block.layer_id}") - update_status(f"Rank {rank} - Layer {block.layer_id}") + # Quantize only routed experts, not shared + prefix = f"layers.{block.layer_id}.feed_forward" + moe = 
block.feed_forward + moe.experts.batched_swiglu = experts_batched_swiglu_wrapper.__get__(moe.experts) - # Quantize only routed experts, not shared - prefix = f"layers.{block.layer_id}.feed_forward" - moe = block.feed_forward - moe.experts.batched_swiglu = experts_batched_swiglu_wrapper.__get__(moe.experts) + for key in ("w1", "w3", "w2"): + param = getattr(moe.experts, key) + update_status(f"Rank {rank} - Layer {block.layer_id} - MoE {key}") + setattr( + moe.experts, + key, + apply_quantization( + f"{prefix}.experts.{key}", + param.transpose(1, 2).contiguous(), + ), + ) + if quantization_mode == QuantizationMode.int4_mixed: + # Quantize shared experts + moe.shared_expert.forward = swiglu_wrapper_no_reduce.__get__(moe.shared_expert) for key in ("w1", "w3", "w2"): - param = getattr(moe.experts, key) - update_status(f"Rank {rank} - Layer {block.layer_id} - MoE {key}") - setattr( - moe.experts, - key, - apply_quantization(f"{prefix}.experts.{key}", param.transpose(1, 2).contiguous()), - ) + param = getattr(moe.shared_expert, key) + update_status(f"Rank {rank} - Layer {block.layer_id} - MoE shared expert {key}") + param.weight = apply_quantization(f"{prefix}.shared_expert.{key}", param.weight) - processed_blocks += 1 - update_status(message=None, completed=processed_blocks) + processed_blocks += 1 + update_status(message=None, completed=processed_blocks) update_status(f"Rank {rank} - Moving parameters to CUDA") @@ -149,7 +170,12 @@ def convert_to_quantized_model( # fp8/int4 loading can be very slow so we add progress bars to make life slightly better -def logging_callbacks(use_rich_progress: bool, rank: int, model: Transformer): +def logging_callbacks( + use_rich_progress: bool, + rank: int, + model: Transformer, + should_quantize_block: Callable[[nn.Module], bool], +): console = None if use_rich_progress: from rich.console import Console @@ -162,15 +188,7 @@ def logging_callbacks(use_rich_progress: bool, rank: int, model: Transformer): elif rank == 0: # Only log from rank 0 for non-rich logging log.info(message) - total_blocks = sum( - 1 - for _, block in model.named_modules() - if ( - isinstance(block, TransformerBlock) - and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1)) - and isinstance(block.feed_forward, MoE) - ) - ) + total_blocks = sum(1 for _, block in model.named_modules() if should_quantize_block(block)) progress = None if use_rich_progress: from rich.progress import ( diff --git a/llama_stack/models/llama/llama4/tokenizer.py b/llama_stack/models/llama/llama4/tokenizer.py index c1347daca..4d271e5fd 100644 --- a/llama_stack/models/llama/llama4/tokenizer.py +++ b/llama_stack/models/llama/llama4/tokenizer.py @@ -4,9 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. 
- import os from logging import getLogger from pathlib import Path @@ -59,8 +56,6 @@ LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS = [ "<|text_post_train_reserved_special_token_3|>", "<|text_post_train_reserved_special_token_4|>", "<|text_post_train_reserved_special_token_5|>", - "<|python_start|>", - "<|python_end|>", "<|finetune_right_pad|>", ] + get_reserved_special_tokens( "text_post_train", 61, 6 @@ -85,8 +80,23 @@ LLAMA4_VISION_SPECIAL_TOKENS = [ "vision", 1041, 7 ) # <|vision_reserved_special_token_7|>, ..., <|vision_reserved_special_token_1047|> +# 201134, ..., 201143 +LLAMA4_REASONING_SPECIAL_TOKENS = [ + "<|reasoning_reserved_special_token_0|>", + "<|reasoning_reserved_special_token_1|>", + "<|reasoning_reserved_special_token_2|>", + "<|reasoning_reserved_special_token_3|>", + "<|reasoning_reserved_special_token_4|>", + "<|reasoning_reserved_special_token_5|>", + "<|reasoning_reserved_special_token_6|>", + "<|reasoning_reserved_special_token_7|>", + "<|reasoning_thinking_start|>", + "<|reasoning_thinking_end|>", +] -LLAMA4_SPECIAL_TOKENS = LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS + LLAMA4_VISION_SPECIAL_TOKENS +LLAMA4_SPECIAL_TOKENS = ( + LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS + LLAMA4_VISION_SPECIAL_TOKENS + LLAMA4_REASONING_SPECIAL_TOKENS +) BASIC_SPECIAL_TOKENS = [ "<|begin_of_text|>", @@ -155,6 +165,9 @@ class Tokenizer: self.eot_id: int = self.special_tokens["<|eot|>"] self.eom_id: int = self.special_tokens["<|eom|>"] + self.thinking_start_id: int = self.special_tokens["<|reasoning_thinking_start|>"] + self.thinking_end_id: int = self.special_tokens["<|reasoning_thinking_end|>"] + self.stop_tokens = [ self.eos_id, self.special_tokens["<|eom|>"], diff --git a/llama_stack/providers/inline/inference/meta_reference/llama4/vision/embedding.py b/llama_stack/models/llama/llama4/vision/embedding.py similarity index 96% rename from llama_stack/providers/inline/inference/meta_reference/llama4/vision/embedding.py rename to llama_stack/models/llama/llama4/vision/embedding.py index 73b29cbef..ed7659a73 100644 --- a/llama_stack/providers/inline/inference/meta_reference/llama4/vision/embedding.py +++ b/llama_stack/models/llama/llama4/vision/embedding.py @@ -4,13 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. 
- import math from typing import Any, Callable, Dict, List diff --git a/llama_stack/providers/inline/inference/meta_reference/llama4/vision/encoder.py b/llama_stack/models/llama/llama4/vision/encoder.py similarity index 100% rename from llama_stack/providers/inline/inference/meta_reference/llama4/vision/encoder.py rename to llama_stack/models/llama/llama4/vision/encoder.py diff --git a/llama_stack/models/llama/prompt_format.py b/llama_stack/models/llama/prompt_format.py index 695c0bf74..edb34620c 100644 --- a/llama_stack/models/llama/prompt_format.py +++ b/llama_stack/models/llama/prompt_format.py @@ -28,9 +28,6 @@ from llama_stack.models.llama.datatypes import ( ToolPromptFormat, ) from llama_stack.models.llama.llama4.tokenizer import Tokenizer -from llama_stack.providers.inline.inference.meta_reference.llama4.datatypes import ( - LLMInput, -) from .llama3.interface import LLama31Interface from .llama3.template_data import ( @@ -76,21 +73,22 @@ class UseCase(BaseModel): text += dialog text += "\n\n" continue - - elif isinstance(dialog, TextCompletionContent): - input_tokens, output_tokens = generator.text_completion_raw( - dialog.content, - temperature=0.1, - top_p=0.95, - max_gen_len=64, - ) else: - input_tokens, output_tokens = generator.chat_completion_raw( - dialog, - temperature=0.0, - top_p=0.95, - max_gen_len=self.max_gen_len, + batch = [dialog] + method = ( + generator.completion if isinstance(dialog, TextCompletionContent) else generator.chat_completion ) + input_tokens = [] + output_tokens = [] + for token_results in method(batch, echo=True, temperature=0.1, top_p=0.95): + result = token_results[0] + if result.source == "input": + input_tokens.append(result.token) + else: + output_tokens.append(result.token) + + if result.finished: + break text += "##### Input Prompt Format\n" # FIXME: This is added to undo the hack in chat_formatter where @@ -126,27 +124,27 @@ class Llama4UseCase(UseCase): text = "" tokenizer = Tokenizer.get_instance() - temperature = 0.0 for dialog in self.dialogs: if isinstance(dialog, str): text += dialog text += "\n\n" continue - - elif isinstance(dialog, TextCompletionContent): - # TODO pass the raw input and do the encoding in the text completion function - input_tokens = tokenizer.encode(dialog.content, bos=True, eos=False) - llm_input = LLMInput(tokens=input_tokens) - output_tokens, decoded_tokens, token_logprobs = generator.text_completion_raw( - llm_input, temperature=temperature, max_gen_len=self.max_gen_len - ) - else: - input_tokens, output_tokens = generator.chat_completion_raw( - dialog, - temperature=temperature, - max_gen_len=self.max_gen_len, + batch = [dialog] + method = ( + generator.completion if isinstance(dialog, TextCompletionContent) else generator.chat_completion ) + input_tokens = [] + output_tokens = [] + for token_results in method(batch, echo=True, temperature=0.0): + result = token_results[0] + if result.source == "input": + input_tokens.append(result.token) + else: + output_tokens.append(result.token) + + if result.finished: + break text += "##### Input Prompt Format\n" text += _code_block(tokenizer.decode(input_tokens)) diff --git a/llama_stack/providers/inline/inference/meta_reference/quantize_impls.py b/llama_stack/models/llama/quantize_impls.py similarity index 100% rename from llama_stack/providers/inline/inference/meta_reference/quantize_impls.py rename to llama_stack/models/llama/quantize_impls.py diff --git a/llama_stack/models/llama/sku_list.py b/llama_stack/models/llama/sku_list.py index dd3144bb0..513481831 100644 --- 
a/llama_stack/models/llama/sku_list.py +++ b/llama_stack/models/llama/sku_list.py @@ -4,24 +4,15 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. - from dataclasses import dataclass from functools import lru_cache from typing import List, Optional -from .datatypes import ( +from .sku_types import ( CheckpointQuantizationFormat, CoreModelId, Model, ModelFamily, - SamplingParams, - TopPSamplingStrategy, ) LLAMA2_VOCAB_SIZE = 32000 @@ -47,15 +38,6 @@ def all_registered_models() -> List[Model]: ) -def recommended_sampling_params() -> SamplingParams: - return SamplingParams( - strategy=TopPSamplingStrategy( - temperature=1.0, - top_p=0.9, - ) - ) - - def llama2_family() -> List[Model]: return [ *llama2_base_models(), @@ -150,7 +132,6 @@ def llama2_base_models() -> List[Model]: core_model_id=CoreModelId.llama2_7b, description="Llama 2 7b model", huggingface_repo="meta-llama/Llama-2-7b", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 4096, "n_layers": 32, @@ -169,7 +150,6 @@ def llama2_base_models() -> List[Model]: core_model_id=CoreModelId.llama2_13b, description="Llama 2 13b model", huggingface_repo="meta-llama/Llama-2-13b", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 5120, "n_layers": 40, @@ -188,7 +168,6 @@ def llama2_base_models() -> List[Model]: core_model_id=CoreModelId.llama2_70b, description="Llama 2 70b model", huggingface_repo="meta-llama/Llama-2-70b", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 8192, "n_layers": 80, @@ -230,7 +209,6 @@ def llama3_base_models() -> List[Model]: core_model_id=CoreModelId.llama3_70b, description="Llama 3 70b model", huggingface_repo="meta-llama/Llama-3-70B", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 8192, "n_layers": 80, @@ -254,7 +232,6 @@ def llama3_1_base_models() -> List[Model]: core_model_id=CoreModelId.llama3_1_8b, description="Llama 3.1 8b model", huggingface_repo="meta-llama/Llama-3.1-8B", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 4096, "n_layers": 32, @@ -273,7 +250,6 @@ def llama3_1_base_models() -> List[Model]: core_model_id=CoreModelId.llama3_1_70b, description="Llama 3.1 70b model", huggingface_repo="meta-llama/Llama-3.1-70B", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 8192, "n_layers": 80, @@ -293,7 +269,6 @@ def llama3_1_base_models() -> List[Model]: variant="bf16-mp8", description="Llama 3.1 405b model (BF16 weights)", huggingface_repo="meta-llama/Llama-3.1-405B", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 16384, "n_layers": 126, @@ -313,7 +288,6 @@ def llama3_1_base_models() -> List[Model]: description="Llama 3.1 405b model (FP8 quantized)", huggingface_repo="meta-llama/Llama-3.1-405B-FP8", quantization_format=CheckpointQuantizationFormat.fp8_mixed, - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 16384, "n_layers": 126, @@ -333,7 +307,6 @@ def llama3_1_base_models() -> List[Model]: variant="bf16-mp16", description="Llama 3.1 405b model (BF16 weights for mp16)", huggingface_repo="meta-llama/Llama-3.1-405B", 
- recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 16384, "n_layers": 126, @@ -357,7 +330,6 @@ def llama3_2_base_models() -> List[Model]: core_model_id=CoreModelId.llama3_2_1b, description="Llama 3.2 1b model", huggingface_repo="meta-llama/Llama-3.2-1B", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 2048, "n_layers": 16, @@ -376,7 +348,6 @@ def llama3_2_base_models() -> List[Model]: core_model_id=CoreModelId.llama3_2_3b, description="Llama 3.2 3b model", huggingface_repo="meta-llama/Llama-3.2-3B", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 3072, "n_layers": 28, @@ -395,7 +366,6 @@ def llama3_2_base_models() -> List[Model]: core_model_id=CoreModelId.llama3_2_11b_vision, description="Llama 3.2 11b vision model", huggingface_repo="meta-llama/Llama-3.2-11B-Vision", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 4096, "n_layers": 32, @@ -417,7 +387,6 @@ def llama3_2_base_models() -> List[Model]: core_model_id=CoreModelId.llama3_2_90b_vision, description="Llama 3.2 90b vision model", huggingface_repo="meta-llama/Llama-3.2-90B-Vision", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 8192, "n_layers": 80, @@ -444,7 +413,6 @@ def llama2_instruct_models() -> List[Model]: core_model_id=CoreModelId.llama2_7b_chat, description="Llama 2 7b chat model", huggingface_repo="meta-llama/Llama-2-7b-chat", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 4096, "n_layers": 32, @@ -463,7 +431,6 @@ def llama2_instruct_models() -> List[Model]: core_model_id=CoreModelId.llama2_13b_chat, description="Llama 2 13b chat model", huggingface_repo="meta-llama/Llama-2-13b-chat", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 5120, "n_layers": 40, @@ -482,7 +449,6 @@ def llama2_instruct_models() -> List[Model]: core_model_id=CoreModelId.llama2_70b_chat, description="Llama 2 70b chat model", huggingface_repo="meta-llama/Llama-2-70b-chat", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 8192, "n_layers": 80, @@ -506,7 +472,6 @@ def llama3_instruct_models() -> List[Model]: core_model_id=CoreModelId.llama3_8b_instruct, description="Llama 3 8b instruct model", huggingface_repo="meta-llama/Llama-3-8B-Instruct", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 4096, "n_layers": 32, @@ -525,7 +490,6 @@ def llama3_instruct_models() -> List[Model]: core_model_id=CoreModelId.llama3_70b_instruct, description="Llama 3 70b instruct model", huggingface_repo="meta-llama/Llama-3-70B-Instruct", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 8192, "n_layers": 80, @@ -549,7 +513,6 @@ def llama3_1_instruct_models() -> List[Model]: core_model_id=CoreModelId.llama3_1_8b_instruct, description="Llama 3.1 8b instruct model", huggingface_repo="meta-llama/Llama-3.1-8B-Instruct", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 4096, "n_layers": 32, @@ -568,7 +531,6 @@ def llama3_1_instruct_models() -> List[Model]: core_model_id=CoreModelId.llama3_1_70b_instruct, description="Llama 3.1 70b instruct model", huggingface_repo="meta-llama/Llama-3.1-70B-Instruct", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 8192, "n_layers": 80, @@ -588,7 +550,6 @@ def llama3_1_instruct_models() -> List[Model]: variant="bf16-mp8", description="Llama 3.1 405b instruct model (BF16 weights)", 
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 16384, "n_layers": 126, @@ -608,7 +569,6 @@ def llama3_1_instruct_models() -> List[Model]: description="Llama 3.1 405b instruct model (FP8 quantized)", huggingface_repo="meta-llama/Llama-3.1-405B-Instruct-FP8", quantization_format=CheckpointQuantizationFormat.fp8_mixed, - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 16384, "n_layers": 126, @@ -628,7 +588,6 @@ def llama3_1_instruct_models() -> List[Model]: variant="bf16-mp16", description="Llama 3.1 405b instruct model (BF16 weights for mp16)", huggingface_repo="meta-llama/Llama-3.1-405B-Instruct", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 16384, "n_layers": 126, @@ -684,7 +643,6 @@ def llama3_2_quantized_models() -> List[Model]: quantization_format=CheckpointQuantizationFormat.int4, description="Llama 3.2 1b INT4 quantized LoRA", huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8", - recommended_sampling_params=recommended_sampling_params(), arch_args={ **arch_args_1b(), "quantization_args": { @@ -703,7 +661,6 @@ def llama3_2_quantized_models() -> List[Model]: quantization_format=CheckpointQuantizationFormat.int4, description="Llama 3.2 1b INT4 quantized SpinQuant", huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8", - recommended_sampling_params=recommended_sampling_params(), arch_args={ **arch_args_1b(), "quantization_args": { @@ -718,7 +675,6 @@ def llama3_2_quantized_models() -> List[Model]: quantization_format=CheckpointQuantizationFormat.int4, description="Llama 3.2 3b INT4 quantized LoRA", huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-QLORA_INT4_EO8", - recommended_sampling_params=recommended_sampling_params(), arch_args={ **arch_args_3b(), "quantization_args": { @@ -737,7 +693,6 @@ def llama3_2_quantized_models() -> List[Model]: quantization_format=CheckpointQuantizationFormat.int4, description="Llama 3.2 3b INT4 quantized SpinQuant", huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-SpinQuant_INT4_EO8", - recommended_sampling_params=recommended_sampling_params(), arch_args={ **arch_args_3b(), "quantization_args": { @@ -755,7 +710,6 @@ def llama3_2_instruct_models() -> List[Model]: core_model_id=CoreModelId.llama3_2_1b_instruct, description="Llama 3.2 1b instruct model", huggingface_repo="meta-llama/Llama-3.2-1B-Instruct", - recommended_sampling_params=recommended_sampling_params(), arch_args=arch_args_1b(), pth_file_count=1, ), @@ -763,7 +717,6 @@ def llama3_2_instruct_models() -> List[Model]: core_model_id=CoreModelId.llama3_2_3b_instruct, description="Llama 3.2 3b instruct model", huggingface_repo="meta-llama/Llama-3.2-3B-Instruct", - recommended_sampling_params=recommended_sampling_params(), arch_args=arch_args_3b(), pth_file_count=1, ), @@ -772,7 +725,6 @@ def llama3_2_instruct_models() -> List[Model]: core_model_id=CoreModelId.llama3_2_11b_vision_instruct, description="Llama 3.2 11b vision instruct model", huggingface_repo="meta-llama/Llama-3.2-11B-Vision-Instruct", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 4096, "n_layers": 32, @@ -794,7 +746,6 @@ def llama3_2_instruct_models() -> List[Model]: core_model_id=CoreModelId.llama3_2_90b_vision_instruct, description="Llama 3.2 90b vision instruct model", huggingface_repo="meta-llama/Llama-3.2-90B-Vision-Instruct", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 8192, 
"n_layers": 80, @@ -821,7 +772,6 @@ def llama3_3_instruct_models() -> List[Model]: core_model_id=CoreModelId.llama3_3_70b_instruct, description="Llama 3.3 70b instruct", huggingface_repo="meta-llama/Llama-3.3-70B-Instruct", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 8192, "n_layers": 80, @@ -846,7 +796,6 @@ def safety_models() -> List[Model]: core_model_id=CoreModelId.llama_guard_3_11b_vision, description="Llama Guard v3 11b vision system safety model", huggingface_repo="meta-llama/Llama-Guard-3-11B-Vision", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 4096, "n_layers": 32, @@ -870,7 +819,6 @@ def safety_models() -> List[Model]: description="Llama Guard v3 1b 'int4' quantized system safety model", huggingface_repo="meta-llama/Llama-Guard-3-1B-INT4", quantization_format=CheckpointQuantizationFormat.int4, - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 2048, "n_layers": 12, @@ -888,7 +836,6 @@ def safety_models() -> List[Model]: core_model_id=CoreModelId.llama_guard_3_1b, description="Llama Guard v3 1b system safety model", huggingface_repo="meta-llama/Llama-Guard-3-1B", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 2048, "n_layers": 16, diff --git a/llama_stack/models/llama/sku_types.py b/llama_stack/models/llama/sku_types.py new file mode 100644 index 000000000..88799b66d --- /dev/null +++ b/llama_stack/models/llama/sku_types.py @@ -0,0 +1,229 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from enum import Enum +from typing import Any, Dict, Optional + +from pydantic import BaseModel, ConfigDict, Field + + +class CheckpointQuantizationFormat(Enum): + # default format + bf16 = "bf16" + + # used for enabling fp8_rowwise inference, some weights are bf16 + fp8_mixed = "fp8-mixed" + + int8 = "int8" + + int4 = "int4" + + +class ModelFamily(Enum): + llama2 = "llama2" + llama3 = "llama3" + llama3_1 = "llama3_1" + llama3_2 = "llama3_2" + llama3_3 = "llama3_3" + llama4 = "llama4" + safety = "safety" + + +class CoreModelId(Enum): + """Each of these models is a unique "SKU". 
These root models can be served in various garbs (especially by quantizing them)""" + + # Llama 2 family + llama2_7b = "Llama-2-7b" + llama2_13b = "Llama-2-13b" + llama2_70b = "Llama-2-70b" + llama2_7b_chat = "Llama-2-7b-chat" + llama2_13b_chat = "Llama-2-13b-chat" + llama2_70b_chat = "Llama-2-70b-chat" + + # Llama 3 family + llama3_8b = "Llama-3-8B" + llama3_70b = "Llama-3-70B" + llama3_8b_instruct = "Llama-3-8B-Instruct" + llama3_70b_instruct = "Llama-3-70B-Instruct" + + # Llama 3.1 family + llama3_1_8b = "Llama3.1-8B" + llama3_1_70b = "Llama3.1-70B" + llama3_1_405b = "Llama3.1-405B" + llama3_1_8b_instruct = "Llama3.1-8B-Instruct" + llama3_1_70b_instruct = "Llama3.1-70B-Instruct" + llama3_1_405b_instruct = "Llama3.1-405B-Instruct" + + # Llama 3.2 family + llama3_2_1b = "Llama3.2-1B" + llama3_2_3b = "Llama3.2-3B" + llama3_2_1b_instruct = "Llama3.2-1B-Instruct" + llama3_2_3b_instruct = "Llama3.2-3B-Instruct" + llama3_2_11b_vision = "Llama3.2-11B-Vision" + llama3_2_90b_vision = "Llama3.2-90B-Vision" + llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct" + llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct" + + # Llama 3.3 family + llama3_3_70b_instruct = "Llama3.3-70B-Instruct" + + # Llama 4 family + llama4_scout_17b_16e = "Llama-4-Scout-17B-16E" + llama4_scout_17b_16e_instruct = "Llama-4-Scout-17B-16E-Instruct" + llama4_maverick_17b_128e = "Llama-4-Maverick-17B-128E" + llama4_maverick_17b_128e_instruct = "Llama-4-Maverick-17B-128E-Instruct" + + # Safety models + llama_guard_3_8b = "Llama-Guard-3-8B" + llama_guard_2_8b = "Llama-Guard-2-8B" + llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision" + llama_guard_3_1b = "Llama-Guard-3-1B" + + +def is_multimodal(model_id) -> bool: + if model_id in [ + CoreModelId.llama3_2_11b_vision, + CoreModelId.llama3_2_90b_vision, + CoreModelId.llama3_2_11b_vision_instruct, + CoreModelId.llama3_2_90b_vision_instruct, + ]: + return True + else: + return False + + +def model_family(model_id) -> ModelFamily: + if model_id in [ + CoreModelId.llama2_7b, + CoreModelId.llama2_13b, + CoreModelId.llama2_70b, + CoreModelId.llama2_7b_chat, + CoreModelId.llama2_13b_chat, + CoreModelId.llama2_70b_chat, + ]: + return ModelFamily.llama2 + elif model_id in [ + CoreModelId.llama3_8b, + CoreModelId.llama3_70b, + CoreModelId.llama3_8b_instruct, + CoreModelId.llama3_70b_instruct, + ]: + return ModelFamily.llama3 + elif model_id in [ + CoreModelId.llama3_1_8b, + CoreModelId.llama3_1_70b, + CoreModelId.llama3_1_405b, + CoreModelId.llama3_1_8b_instruct, + CoreModelId.llama3_1_70b_instruct, + CoreModelId.llama3_1_405b_instruct, + ]: + return ModelFamily.llama3_1 + elif model_id in [ + CoreModelId.llama3_2_1b, + CoreModelId.llama3_2_3b, + CoreModelId.llama3_2_1b_instruct, + CoreModelId.llama3_2_3b_instruct, + CoreModelId.llama3_2_11b_vision, + CoreModelId.llama3_2_90b_vision, + CoreModelId.llama3_2_11b_vision_instruct, + CoreModelId.llama3_2_90b_vision_instruct, + ]: + return ModelFamily.llama3_2 + elif model_id in [ + CoreModelId.llama3_3_70b_instruct, + ]: + return ModelFamily.llama3_3 + elif model_id in [ + CoreModelId.llama4_scout_17b_16e, + CoreModelId.llama4_scout_17b_16e_instruct, + CoreModelId.llama4_maverick_17b_128e, + CoreModelId.llama4_maverick_17b_128e_instruct, + ]: + return ModelFamily.llama4 + elif model_id in [ + CoreModelId.llama_guard_3_8b, + CoreModelId.llama_guard_2_8b, + CoreModelId.llama_guard_3_11b_vision, + CoreModelId.llama_guard_3_1b, + ]: + return ModelFamily.safety + else: + raise ValueError(f"Unknown model family for {model_id}") + + 
+class Model(BaseModel): + core_model_id: CoreModelId + description: str + huggingface_repo: Optional[str] = None + arch_args: Dict[str, Any] + variant: str = "" + + quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16 + pth_file_count: int + metadata: Dict[str, Any] = Field(default_factory=dict) + + # silence pydantic until we remove the `model_` fields + model_config = ConfigDict(protected_namespaces=()) + + @property + def model_family(self) -> ModelFamily: + return model_family(self.core_model_id) + + # The SKU is uniquely identified by (model_id, variant) combo + def descriptor(self, shorten_default_variant: bool = True) -> str: + if not self.variant: + return self.core_model_id.value + return f"{self.core_model_id.value}:{self.variant}" + + @property + def is_instruct_model(self) -> bool: + return "instruct" in self.core_model_id.value + + # Featured models are shown in the non-exhaustive model list + @property + def is_featured(self) -> bool: + return self.model_family in [ + ModelFamily.llama3_1, + ModelFamily.llama3_2, + ModelFamily.llama3_3, + ModelFamily.llama4, + ModelFamily.safety, + ] + + @property + def max_seq_length(self) -> int: + if self.model_family == ModelFamily.llama2: + return 4096 + elif self.core_model_id == CoreModelId.llama_guard_2_8b: + return 4096 + elif self.model_family == ModelFamily.llama3: + return 8192 + elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]: + return 131072 + elif self.model_family == ModelFamily.llama3_2: + if self.quantization_format == CheckpointQuantizationFormat.int4: + return 8192 + return 131072 + elif self.model_family == ModelFamily.llama4: + if self.core_model_id in { + CoreModelId.llama4_scout_17b_16e, + CoreModelId.llama4_maverick_17b_128e, + }: + return 262144 + if self.core_model_id == CoreModelId.llama4_scout_17b_16e_instruct: + return 10485760 + if self.core_model_id == CoreModelId.llama4_maverick_17b_128e_instruct: + return 1048576 + + raise AssertionError(f"Unexpected core model id: {self.core_model_id}") + elif self.core_model_id in [ + CoreModelId.llama_guard_3_8b, + CoreModelId.llama_guard_3_11b_vision, + CoreModelId.llama_guard_3_1b, + ]: + return 131072 + else: + raise ValueError(f"Unknown max_seq_len for {self.core_model_id}") diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index e1af4ab71..6840da89f 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -52,6 +52,7 @@ from llama_stack.apis.inference import ( StopReason, SystemMessage, ToolDefinition, + ToolParamDefinition, ToolResponse, ToolResponseMessage, UserMessage, @@ -63,7 +64,6 @@ from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ( BuiltinTool, ToolCall, - ToolParamDefinition, ) from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.telemetry import tracing diff --git a/llama_stack/providers/inline/inference/meta_reference/__init__.py b/llama_stack/providers/inline/inference/meta_reference/__init__.py index 3ef7cfd45..3710766e2 100644 --- a/llama_stack/providers/inline/inference/meta_reference/__init__.py +++ b/llama_stack/providers/inline/inference/meta_reference/__init__.py @@ -4,13 +4,13 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import Any, Dict, Union +from typing import Any, Dict -from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig +from .config import MetaReferenceInferenceConfig async def get_provider_impl( - config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig], + config: MetaReferenceInferenceConfig, _deps: Dict[str, Any], ): from .inference import MetaReferenceInferenceImpl diff --git a/llama_stack/providers/inline/inference/meta_reference/common.py b/llama_stack/providers/inline/inference/meta_reference/common.py index 3dc5e89f9..beb0d39d4 100644 --- a/llama_stack/providers/inline/inference/meta_reference/common.py +++ b/llama_stack/providers/inline/inference/meta_reference/common.py @@ -5,19 +5,10 @@ # the root directory of this source tree. from pathlib import Path -from typing import List, Optional - -from pydantic import BaseModel from llama_stack.distribution.utils.model_utils import model_local_dir -class TokenResult(BaseModel): - token: int - text: str - logprobs: Optional[List[float]] = None - - def model_checkpoint_dir(model_id) -> str: checkpoint_dir = Path(model_local_dir(model_id)) diff --git a/llama_stack/providers/inline/inference/meta_reference/config.py b/llama_stack/providers/inline/inference/meta_reference/config.py index 9e5f7747e..315667506 100644 --- a/llama_stack/providers/inline/inference/meta_reference/config.py +++ b/llama_stack/providers/inline/inference/meta_reference/config.py @@ -21,6 +21,7 @@ class MetaReferenceInferenceConfig(BaseModel): torch_seed: Optional[int] = None max_seq_len: int = 4096 max_batch_size: int = 1 + model_parallel_size: Optional[int] = None # when this is False, we assume that the distributed process group is setup by someone # outside of this code (e.g., when run inside `torchrun`). 
that is useful for clients @@ -31,6 +32,8 @@ class MetaReferenceInferenceConfig(BaseModel): # can override by specifying the directory explicitly checkpoint_dir: Optional[str] = None + quantization: Optional[QuantizationConfig] = None + @field_validator("model") @classmethod def validate_model(cls, model: str) -> str: @@ -47,27 +50,16 @@ class MetaReferenceInferenceConfig(BaseModel): cls, model: str = "Llama3.2-3B-Instruct", checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}", + quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}", + model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:0}", **kwargs, ) -> Dict[str, Any]: return { "model": model, "max_seq_len": 4096, "checkpoint_dir": checkpoint_dir, + "quantization": { + "type": quantization_type, + }, + "model_parallel_size": model_parallel_size, } - - -class MetaReferenceQuantizedInferenceConfig(MetaReferenceInferenceConfig): - quantization: QuantizationConfig - - @classmethod - def sample_run_config( - cls, - model: str = "Llama3.2-3B-Instruct", - checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}", - **kwargs, - ) -> Dict[str, Any]: - config = super().sample_run_config(model, checkpoint_dir, **kwargs) - config["quantization"] = { - "type": "fp8", - } - return config diff --git a/llama_stack/providers/inline/inference/meta_reference/generators.py b/llama_stack/providers/inline/inference/meta_reference/generators.py index 4b0ed7ecd..65bed4d8c 100644 --- a/llama_stack/providers/inline/inference/meta_reference/generators.py +++ b/llama_stack/providers/inline/inference/meta_reference/generators.py @@ -11,19 +11,18 @@ import torch from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData from llama_stack.apis.inference import ( - Fp8QuantizationConfig, - Int4QuantizationConfig, + GreedySamplingStrategy, JsonSchemaResponseFormat, ResponseFormat, -) -from llama_stack.models.llama.datatypes import ( - GreedySamplingStrategy, - Model, SamplingParams, TopPSamplingStrategy, ) +from llama_stack.models.llama.datatypes import QuantizationMode +from llama_stack.models.llama.llama3.generation import Llama3 from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer +from llama_stack.models.llama.llama4.generation import Llama4 from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer +from llama_stack.models.llama.sku_types import Model from llama_stack.providers.utils.inference.prompt_adapter import ( ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent, @@ -31,10 +30,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( ) from .common import model_checkpoint_dir -from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig +from .config import MetaReferenceInferenceConfig from .inference import resolve_model -from .llama3.generation import Llama3 -from .llama4.generation import Llama4 Tokenizer = Llama4Tokenizer | Llama3Tokenizer @@ -116,10 +113,11 @@ def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent): return get_default_tool_prompt_format(request.model) +# TODO: combine Llama3 and Llama4 generators since they are almost identical now class Llama4Generator: def __init__( self, - config: MetaReferenceInferenceConfig | MetaReferenceQuantizedInferenceConfig, + config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model, ): @@ -134,11 +132,13 @@ class Llama4Generator: # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value 
ckpt_dir = model_checkpoint_dir(resolved_model.descriptor()) - if isinstance(config, MetaReferenceQuantizedInferenceConfig): - if isinstance(config.quantization, Fp8QuantizationConfig): - quantization_mode = "fp8_mixed" - elif isinstance(config.quantization, Int4QuantizationConfig): - quantization_mode = "int4_mixed" + if config.quantization: + if config.quantization.type == "fp8_mixed": + quantization_mode = QuantizationMode.fp8_mixed + elif config.quantization.type == "int4_mixed": + quantization_mode = QuantizationMode.int4_mixed + elif config.quantization.type == "bf16": + quantization_mode = None else: raise ValueError(f"Unsupported quantization mode {config.quantization}") else: @@ -148,7 +148,7 @@ class Llama4Generator: ckpt_dir=ckpt_dir, max_seq_len=config.max_seq_len, max_batch_size=config.max_batch_size, - world_size=llama_model.pth_file_count, + world_size=config.model_parallel_size or llama_model.pth_file_count, quantization_mode=quantization_mode, ) @@ -166,8 +166,8 @@ class Llama4Generator: max_gen_len = self.args.max_seq_len - 1 temperature, top_p = _infer_sampling_params(sampling_params) - yield from self.inner_generator.generate( - llm_input=self.formatter.encode_content(request.content), + for result in self.inner_generator.generate( + llm_inputs=[self.formatter.encode_content(request.content)], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, @@ -178,7 +178,8 @@ class Llama4Generator: self.args.vocab_size, request.response_format, ), - ) + ): + yield result[0] def chat_completion( self, @@ -190,8 +191,8 @@ class Llama4Generator: max_gen_len = self.args.max_seq_len - 1 temperature, top_p = _infer_sampling_params(sampling_params) - yield from self.inner_generator.generate( - llm_input=self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)), + for result in self.inner_generator.generate( + llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, @@ -202,20 +203,46 @@ class Llama4Generator: self.args.vocab_size, request.response_format, ), - ) + ): + yield result[0] class Llama3Generator: def __init__( self, - config: MetaReferenceInferenceConfig | MetaReferenceQuantizedInferenceConfig, + config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model, ): + if config.checkpoint_dir and config.checkpoint_dir != "null": + ckpt_dir = config.checkpoint_dir + else: + resolved_model = resolve_model(model_id) + if resolved_model is None: + # if the model is not a native llama model, get the default checkpoint_dir based on model id + ckpt_dir = model_checkpoint_dir(model_id) + else: + # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value + ckpt_dir = model_checkpoint_dir(resolved_model.descriptor()) + + if config.quantization: + if config.quantization.type == "fp8_mixed": + quantization_mode = QuantizationMode.fp8_mixed + elif config.quantization.type == "int4_mixed": + quantization_mode = QuantizationMode.int4_mixed + elif config.quantization.type == "bf16": + quantization_mode = None + else: + raise ValueError(f"Unsupported quantization mode {config.quantization}") + else: + quantization_mode = None + self.inner_generator = Llama3.build( - config=config, - model_id=model_id, - llama_model=llama_model, + ckpt_dir=ckpt_dir, + max_seq_len=config.max_seq_len, + max_batch_size=config.max_batch_size, + world_size=config.model_parallel_size or llama_model.pth_file_count, + 
quantization_mode=quantization_mode, ) self.tokenizer = self.inner_generator.tokenizer self.args = self.inner_generator.args @@ -231,8 +258,8 @@ class Llama3Generator: max_gen_len = self.args.max_seq_len - 1 temperature, top_p = _infer_sampling_params(sampling_params) - yield from self.inner_generator.generate( - model_input=self.formatter.encode_content(request.content), + for result in self.inner_generator.generate( + llm_inputs=[self.formatter.encode_content(request.content)], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, @@ -243,7 +270,8 @@ class Llama3Generator: self.args.vocab_size, request.response_format, ), - ) + ): + yield result[0] def chat_completion( self, @@ -255,8 +283,8 @@ class Llama3Generator: max_gen_len = self.args.max_seq_len - 1 temperature, top_p = _infer_sampling_params(sampling_params) - yield from self.inner_generator.generate( - model_input=self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)), + for result in self.inner_generator.generate( + llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, @@ -267,4 +295,5 @@ class Llama3Generator: self.args.vocab_size, request.response_format, ), - ) + ): + yield result[0] diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index da217728b..5f81d6421 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -31,23 +31,21 @@ from llama_stack.apis.inference import ( LogProbConfig, Message, ResponseFormat, + SamplingParams, + StopReason, TokenLogProbs, ToolChoice, ToolConfig, -) -from llama_stack.apis.models import Model, ModelType -from llama_stack.models.llama.datatypes import ( - ModelFamily, - SamplingParams, - StopReason, ToolDefinition, ToolPromptFormat, ) +from llama_stack.apis.models import Model, ModelType from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer from llama_stack.models.llama.sku_list import resolve_model +from llama_stack.models.llama.sku_types import ModelFamily from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.embedding_mixin import ( SentenceTransformerEmbeddingMixin, @@ -151,7 +149,7 @@ class MetaReferenceInferenceImpl( if self.config.create_distributed_process_group: self.generator = LlamaModelParallelGenerator( - model_parallel_size=llama_model.pth_file_count, + model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count, builder_fn=builder_fn, builder_params=builder_params, formatter=( diff --git a/llama_stack/providers/inline/inference/meta_reference/llama3/generation.py b/llama_stack/providers/inline/inference/meta_reference/llama3/generation.py deleted file mode 100644 index 3805e4310..000000000 --- a/llama_stack/providers/inline/inference/meta_reference/llama3/generation.py +++ /dev/null @@ -1,346 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - - -import json -import os -import sys -import time -from pathlib import Path -from typing import Callable, Generator, Optional, Union - -import torch -import torch.nn.functional as F -from fairscale.nn.model_parallel.initialize import ( - get_model_parallel_rank, - initialize_model_parallel, - model_parallel_is_initialized, -) - -from llama_stack.apis.inference import ( - Fp8QuantizationConfig, - Int4QuantizationConfig, -) -from llama_stack.log import get_logger -from llama_stack.models.llama.datatypes import Model -from llama_stack.models.llama.llama3.chat_format import ChatFormat, LLMInput -from llama_stack.models.llama.llama3.tokenizer import Tokenizer -from llama_stack.models.llama.sku_list import resolve_model - -from ..common import TokenResult, model_checkpoint_dir -from ..config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig -from .args import ModelArgs -from .model import Transformer -from .multimodal.model import CrossAttentionTransformer - -log = get_logger(__name__, category="inference") - - -class Llama3: - @staticmethod - def build( - config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig], - model_id: str, - llama_model: Model, - ): - """ - Build a Llama instance by initializing and loading a model checkpoint. - - Note: - This method initializes the distributed process group, sets the device to CUDA, - and loads the pre-trained model and tokenizer. - """ - if "DEVICE" in os.environ: - device = os.environ.get("DEVICE") - if device == "cuda": - assert torch.cuda.is_available(), "PyTorch CUDA backend not available" - if device == "xpu": - assert torch.xpu.is_available(), "PyTorch XPU backend not available" - else: - if torch.cuda.is_available(): - device = "cuda" - elif torch.xpu.is_available(): - device = "xpu" - else: - device = "cpu" - log.info(f"Using {device} device") - - llama_model_id = llama_model.core_model_id.value - if not torch.distributed.is_initialized(): - if device == "cuda": - torch.distributed.init_process_group("nccl") - else: - torch.distributed.init_process_group("gloo") - - model_parallel_size = llama_model.pth_file_count - - if not model_parallel_is_initialized(): - initialize_model_parallel(model_parallel_size) - - local_rank = int(os.environ.get("LOCAL_RANK", 0)) - if device == "cuda": - torch.cuda.set_device(local_rank) - elif device == "xpu": - torch.xpu.set_device(local_rank) - - # seed must be the same in all processes - if config.torch_seed is not None: - torch.manual_seed(config.torch_seed) - - if local_rank > 0: - sys.stdout = open(os.devnull, "w") - - start_time = time.time() - if config.checkpoint_dir and config.checkpoint_dir != "null": - ckpt_dir = config.checkpoint_dir - else: - resolved_model = resolve_model(model_id) - if resolved_model is None: - # if the model is not a native llama model, get the default checkpoint_dir based on model id - ckpt_dir = model_checkpoint_dir(model_id) - else: - # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value - ckpt_dir = model_checkpoint_dir(resolved_model.descriptor()) - - checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) - assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" - assert model_parallel_size == len(checkpoints), ( - f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}" - ) - ckpt_path = 
checkpoints[get_model_parallel_rank()] - state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True) - with open(Path(ckpt_dir) / "params.json", "r") as f: - params = json.loads(f.read()) - - if "model" in params: - params = params["model"] - - model_args: ModelArgs = ModelArgs( - max_seq_len=config.max_seq_len, - max_batch_size=config.max_batch_size, - **params, - ) - - tokenizer = Tokenizer.get_instance() - assert model_args.vocab_size == tokenizer.n_words, ( - f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}" - ) - - if isinstance(config, MetaReferenceQuantizedInferenceConfig): - if isinstance(config.quantization, Fp8QuantizationConfig): - from .quantization.loader import convert_to_fp8_quantized_model - - # load on CPU in bf16 so that fp8 conversion does not find an - # unexpected (fp32, e.g.) datatype - torch.set_default_tensor_type(torch.BFloat16Tensor) - if model_args.vision_chunk_size > 0: - model = CrossAttentionTransformer(model_args) - model.setup_cache(model_args.max_batch_size, torch.bfloat16) - else: - model = Transformer(model_args) - model.load_state_dict(state_dict, strict=False) - model = convert_to_fp8_quantized_model(model, config, ckpt_dir) - elif isinstance(config.quantization, Int4QuantizationConfig): - from .quantization.loader import convert_to_int4_quantized_model - - model = Transformer(model_args) - model = convert_to_int4_quantized_model(model, model_args, config) - model.load_state_dict(state_dict, strict=True) - - if model_args.quantization_args is not None and model_args.quantization_args.spinquant: - # Add a wrapper for adding hadamard transform for spinquant. - # This needs to be done after loading the state dict otherwise an error will be raised while - # loading the state dict. 
- from ..hadamard_utils import ( - add_hadamard_transform_for_spinquant, - ) - - add_hadamard_transform_for_spinquant(model) - else: - raise NotImplementedError("Currently int4 and fp8 are the only supported quantization methods.") - else: - if device == "cuda": - if torch.cuda.is_bf16_supported(): - torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) - else: - torch.set_default_tensor_type(torch.cuda.HalfTensor) - else: - torch.set_default_device(device) - if device == "xpu" and torch.xpu.is_bf16_supported(): - torch.set_default_dtype(torch.bfloat16) - else: - torch.set_default_dtype(torch.half) - if model_args.vision_chunk_size > 0: - model = CrossAttentionTransformer(model_args) - model.setup_cache(model_args.max_batch_size, torch.bfloat16) - else: - model = Transformer(model_args) - model.load_state_dict(state_dict, strict=False) - - model.to(device) - - log.info(f"Loaded in {time.time() - start_time:.2f} seconds") - return Llama3(model, tokenizer, model_args, llama_model_id) - - def __init__( - self, - model: Transformer, - tokenizer: Tokenizer, - args: ModelArgs, - llama_model: str, - ): - self.args = args - self.model = model - self.tokenizer = tokenizer - self.formatter = ChatFormat(tokenizer) - self.llama_model = llama_model - - @torch.inference_mode() - def generate( - self, - model_input: LLMInput, - max_gen_len: int, - temperature: float = 0.6, - top_p: float = 0.9, - logprobs: bool = False, - echo: bool = False, - print_input_tokens: bool = False, - logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, - ) -> Generator: - params = self.model.params - - if print_input_tokens: - input_tokens = [self.formatter.vision_token if t == 128256 else t for t in model_input.tokens] - log.info("Input to model -> " + self.tokenizer.decode(input_tokens)) - prompt_tokens = [model_input.tokens] - - bsz = 1 - assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) - - min_prompt_len = min(len(t) for t in prompt_tokens) - max_prompt_len = max(len(t) for t in prompt_tokens) - - if max_prompt_len >= params.max_seq_len: - log.error(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}") - return - - total_len = min(max_gen_len + max_prompt_len, params.max_seq_len) - - is_vision = isinstance(self.model, CrossAttentionTransformer) - if is_vision: - images = model_input.vision.images if model_input.vision is not None else [] - mask = model_input.vision.mask if model_input.vision is not None else [] - - # the method works for bsz > 1 so add a batch dimension - xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks( - batch_images=[images], - batch_masks=[mask], - total_len=total_len, - ) - - pad_id = self.tokenizer.pad_id - tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long) - for k, t in enumerate(prompt_tokens): - tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long) - if logprobs: - token_logprobs = torch.zeros_like(tokens) - - prev_pos = 0 - eos_reached = torch.tensor([False] * bsz) - input_text_mask = tokens != pad_id - if min_prompt_len == total_len: - # TODO(ashwin): unify this branch with the one below and figure out multimodal crap - logits = self.model.forward(tokens, prev_pos) - token_logprobs = -F.cross_entropy( - input=logits.transpose(1, 2), - target=tokens, - reduction="none", - ignore_index=pad_id, - ) - - stop_tokens = torch.tensor(self.tokenizer.stop_tokens) - for cur_pos in range(min_prompt_len, total_len): - if is_vision: - position_ids = 
torch.arange(prev_pos, cur_pos, dtype=torch.long) - logits = self.model.forward( - position_ids, - tokens, - cross_attention_masks, - full_text_row_masked_out_mask, - xattn_caches, - ) - else: - logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos) - - if logits_processor is not None: - logits = logits_processor(tokens[:, :cur_pos], logits) - - if temperature > 0: - probs = torch.softmax(logits[:, -1] / temperature, dim=-1) - next_token = sample_top_p(probs, top_p) - else: - next_token = torch.argmax(logits[:, -1], dim=-1) - - next_token = next_token.reshape(-1) - # only replace token if prompt has already been generated - next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token) - tokens[:, cur_pos] = next_token - - target = tokens[:, prev_pos + 1 : cur_pos + 1] - if is_vision: - # the logits space (num_classes) is designed to never contain a media_token - # however our input token stream does contain them. we need to nuke them here - # or else the CUDA kernels will crash with an illegal memory access - vision_tokens = [self.tokenizer.special_tokens["<|image|>"], 128256] - masks = [target.eq(t) for t in vision_tokens] - if len(masks) > 1: - mask = torch.logical_or(*masks) - else: - mask = masks[0] - target[mask] = 0 - - if logprobs: - token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy( - input=logits.transpose(1, 2), - target=tokens[:, prev_pos + 1 : cur_pos + 1], - reduction="none", - ignore_index=pad_id, - ) - eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens)) - yield TokenResult( - token=next_token[0].item(), - text=self.tokenizer.decode(next_token.tolist()), - logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None), - ) - - prev_pos = cur_pos - if all(eos_reached): - break - - -def sample_top_p(probs, p): - """ - Perform top-p (nucleus) sampling on a probability distribution. - - Args: - probs (torch.Tensor): Probability distribution tensor. - p (float): Probability threshold for top-p sampling. - - Returns: - torch.Tensor: Sampled token indices. - - Note: - Top-p sampling selects the smallest set of tokens whose cumulative probability mass - exceeds the threshold p. The distribution is renormalized based on the selected tokens. 
- """ - probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) - probs_sum = torch.cumsum(probs_sort, dim=-1) - mask = probs_sum - probs_sort > p - probs_sort[mask] = 0.0 - probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) - next_token = torch.multinomial(probs_sort, num_samples=1) - next_token = torch.gather(probs_idx, -1, next_token) - return next_token diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py index e8767c2ff..74fc49d5e 100644 --- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py @@ -32,13 +32,12 @@ from pydantic import BaseModel, Field from torch.distributed.launcher.api import LaunchConfig, elastic_launch from typing_extensions import Annotated +from llama_stack.models.llama.datatypes import GenerationResult from llama_stack.providers.utils.inference.prompt_adapter import ( ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent, ) -from .common import TokenResult - log = logging.getLogger(__name__) @@ -75,7 +74,7 @@ class TaskRequest(BaseModel): class TaskResponse(BaseModel): type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response - result: TokenResult + result: GenerationResult class ExceptionResponse(BaseModel): diff --git a/llama_stack/providers/inline/inference/vllm/openai_utils.py b/llama_stack/providers/inline/inference/vllm/openai_utils.py index 90b5398f9..d34f5ad5f 100644 --- a/llama_stack/providers/inline/inference/vllm/openai_utils.py +++ b/llama_stack/providers/inline/inference/vllm/openai_utils.py @@ -14,9 +14,10 @@ from llama_stack.apis.inference import ( JsonSchemaResponseFormat, Message, ToolChoice, + ToolDefinition, UserMessage, ) -from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition +from llama_stack.models.llama.datatypes import BuiltinTool from llama_stack.providers.utils.inference.openai_compat import ( convert_message_to_openai_dict, get_sampling_options, diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py index 256e0f821..ea2643b7a 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -46,6 +46,8 @@ from llama_stack.apis.inference import ( TokenLogProbs, ToolChoice, ToolConfig, + TopKSamplingStrategy, + TopPSamplingStrategy, ) from llama_stack.apis.models import Model from llama_stack.log import get_logger @@ -55,8 +57,6 @@ from llama_stack.models.llama.datatypes import ( ToolCall, ToolDefinition, ToolPromptFormat, - TopKSamplingStrategy, - TopPSamplingStrategy, ) from llama_stack.models.llama.llama3.chat_format import ChatFormat from llama_stack.models.llama.llama3.tokenizer import Tokenizer diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py index f8a1c0436..a040ca1b0 100644 --- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py +++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py @@ -22,8 +22,8 @@ from torchtune.models.llama3_2 import lora_llama3_2_3b from torchtune.modules.transforms import Transform from llama_stack.apis.post_training import DatasetFormat -from llama_stack.models.llama.datatypes import Model from llama_stack.models.llama.sku_list import resolve_model +from 
llama_stack.models.llama.sku_types import Model BuildLoraModelCallable = Callable[..., torch.nn.Module] BuildTokenizerCallable = Callable[..., Llama3Tokenizer] diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py index e514e3781..d95c40976 100644 --- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -23,7 +23,8 @@ from llama_stack.apis.safety import ( ) from llama_stack.apis.shields import Shield from llama_stack.distribution.datatypes import Api -from llama_stack.models.llama.datatypes import CoreModelId, Role +from llama_stack.models.llama.datatypes import Role +from llama_stack.models.llama.sku_types import CoreModelId from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 51ea4cbef..5f9ae421f 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -24,6 +24,8 @@ META_REFERENCE_DEPS = [ "zmq", "lm-format-enforcer", "sentence-transformers", + "torchao==0.5.0", + "fbgemm-gpu-genai==1.1.2", ] @@ -36,13 +38,6 @@ def available_providers() -> List[ProviderSpec]: module="llama_stack.providers.inline.inference.meta_reference", config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceInferenceConfig", ), - InlineProviderSpec( - api=Api.inference, - provider_type="inline::meta-reference-quantized", - pip_packages=META_REFERENCE_DEPS + ["fbgemm-gpu", "torchao==0.5.0"], - module="llama_stack.providers.inline.inference.meta_reference", - config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceQuantizedInferenceConfig", - ), InlineProviderSpec( api=Api.inference, provider_type="inline::vllm", diff --git a/llama_stack/providers/remote/inference/bedrock/models.py b/llama_stack/providers/remote/inference/bedrock/models.py index c5079799f..ec8120049 100644 --- a/llama_stack/providers/remote/inference/bedrock/models.py +++ b/llama_stack/providers/remote/inference/bedrock/models.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from llama_stack.models.llama.datatypes import CoreModelId +from llama_stack.models.llama.sku_types import CoreModelId from llama_stack.providers.utils.inference.model_registry import ( build_hf_repo_model_entry, ) diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py index a53e6e5a5..43d986b86 100644 --- a/llama_stack/providers/remote/inference/cerebras/cerebras.py +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -28,8 +28,8 @@ from llama_stack.apis.inference import ( ToolConfig, ToolDefinition, ToolPromptFormat, + TopKSamplingStrategy, ) -from llama_stack.models.llama.datatypes import TopKSamplingStrategy from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, ) diff --git a/llama_stack/providers/remote/inference/cerebras/models.py b/llama_stack/providers/remote/inference/cerebras/models.py index 37419bf4c..38301b32a 100644 --- a/llama_stack/providers/remote/inference/cerebras/models.py +++ b/llama_stack/providers/remote/inference/cerebras/models.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.models.llama.datatypes import CoreModelId +from llama_stack.models.llama.sku_types import CoreModelId from llama_stack.providers.utils.inference.model_registry import ( build_hf_repo_model_entry, ) diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index 53a9c04f4..0eaf0135b 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -28,7 +28,7 @@ from llama_stack.apis.inference import ( ToolDefinition, ToolPromptFormat, ) -from llama_stack.models.llama.datatypes import CoreModelId +from llama_stack.models.llama.sku_types import CoreModelId from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, build_hf_repo_model_entry, diff --git a/llama_stack/providers/remote/inference/fireworks/models.py b/llama_stack/providers/remote/inference/fireworks/models.py index a0dc11768..4975d061f 100644 --- a/llama_stack/providers/remote/inference/fireworks/models.py +++ b/llama_stack/providers/remote/inference/fireworks/models.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from llama_stack.apis.models.models import ModelType -from llama_stack.models.llama.datatypes import CoreModelId +from llama_stack.models.llama.sku_types import CoreModelId from llama_stack.providers.utils.inference.model_registry import ( ProviderModelEntry, build_hf_repo_model_entry, diff --git a/llama_stack/providers/remote/inference/nvidia/models.py b/llama_stack/providers/remote/inference/nvidia/models.py index 879855003..964125148 100644 --- a/llama_stack/providers/remote/inference/nvidia/models.py +++ b/llama_stack/providers/remote/inference/nvidia/models.py @@ -5,7 +5,7 @@ # the root directory of this source tree. 
from llama_stack.apis.models import ModelType -from llama_stack.models.llama.datatypes import CoreModelId +from llama_stack.models.llama.sku_types import CoreModelId from llama_stack.providers.utils.inference.model_registry import ( ProviderModelEntry, build_hf_repo_model_entry, diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 5caf19fda..e1f5d7a6a 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -29,15 +29,13 @@ from llama_stack.apis.inference import ( LogProbConfig, Message, ResponseFormat, + SamplingParams, TextTruncation, ToolChoice, ToolConfig, -) -from llama_stack.models.llama.datatypes import ( - SamplingParams, ToolDefinition, - ToolPromptFormat, ) +from llama_stack.models.llama.datatypes import ToolPromptFormat from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, ) diff --git a/llama_stack/providers/remote/inference/nvidia/openai_utils.py b/llama_stack/providers/remote/inference/nvidia/openai_utils.py index 0582cb816..3f2769b26 100644 --- a/llama_stack/providers/remote/inference/nvidia/openai_utils.py +++ b/llama_stack/providers/remote/inference/nvidia/openai_utils.py @@ -19,11 +19,9 @@ from llama_stack.apis.inference import ( CompletionRequest, CompletionResponse, CompletionResponseStreamChunk, + GreedySamplingStrategy, JsonSchemaResponseFormat, TokenLogProbs, -) -from llama_stack.models.llama.datatypes import ( - GreedySamplingStrategy, TopKSamplingStrategy, TopPSamplingStrategy, ) diff --git a/llama_stack/providers/remote/inference/ollama/models.py b/llama_stack/providers/remote/inference/ollama/models.py index be556762c..42e364105 100644 --- a/llama_stack/providers/remote/inference/ollama/models.py +++ b/llama_stack/providers/remote/inference/ollama/models.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from llama_stack.apis.models.models import ModelType -from llama_stack.models.llama.datatypes import CoreModelId +from llama_stack.models.llama.sku_types import CoreModelId from llama_stack.providers.utils.inference.model_registry import ( ProviderModelEntry, build_hf_repo_model_entry, diff --git a/llama_stack/providers/remote/inference/sambanova/models.py b/llama_stack/providers/remote/inference/sambanova/models.py index 2231be22d..9589ea268 100644 --- a/llama_stack/providers/remote/inference/sambanova/models.py +++ b/llama_stack/providers/remote/inference/sambanova/models.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from llama_stack.models.llama.datatypes import CoreModelId +from llama_stack.models.llama.sku_types import CoreModelId from llama_stack.providers.utils.inference.model_registry import ( build_hf_repo_model_entry, ) diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index 635a42d38..a3badd468 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -21,6 +21,7 @@ from llama_stack.apis.inference import ( CompletionMessage, EmbeddingsResponse, EmbeddingTaskType, + GreedySamplingStrategy, Inference, LogProbConfig, Message, @@ -35,12 +36,9 @@ from llama_stack.apis.inference import ( ToolDefinition, ToolPromptFormat, ToolResponseMessage, - UserMessage, -) -from llama_stack.models.llama.datatypes import ( - GreedySamplingStrategy, TopKSamplingStrategy, TopPSamplingStrategy, + UserMessage, ) from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.openai_compat import ( diff --git a/llama_stack/providers/remote/inference/together/models.py b/llama_stack/providers/remote/inference/together/models.py index 63d3d94b5..f014c03f0 100644 --- a/llama_stack/providers/remote/inference/together/models.py +++ b/llama_stack/providers/remote/inference/together/models.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from llama_stack.apis.models.models import ModelType -from llama_stack.models.llama.datatypes import CoreModelId +from llama_stack.models.llama.sku_types import CoreModelId from llama_stack.providers.utils.inference.model_registry import ( ProviderModelEntry, build_hf_repo_model_entry, diff --git a/llama_stack/providers/remote/post_training/nvidia/models.py b/llama_stack/providers/remote/post_training/nvidia/models.py index 04a9af38c..7c696ac20 100644 --- a/llama_stack/providers/remote/post_training/nvidia/models.py +++ b/llama_stack/providers/remote/post_training/nvidia/models.py @@ -6,7 +6,7 @@ from typing import List -from llama_stack.models.llama.datatypes import CoreModelId +from llama_stack.models.llama.sku_types import CoreModelId from llama_stack.providers.utils.inference.model_registry import ( ProviderModelEntry, build_hf_repo_model_entry, diff --git a/llama_stack/providers/tests/report.py b/llama_stack/providers/tests/report.py index c9a7f69a8..bc29534be 100644 --- a/llama_stack/providers/tests/report.py +++ b/llama_stack/providers/tests/report.py @@ -12,8 +12,8 @@ import pytest from pytest import ExitCode from pytest_html.basereport import _process_outcome -from llama_stack.models.llama.datatypes import CoreModelId from llama_stack.models.llama.sku_list import all_registered_models +from llama_stack.models.llama.sku_types import CoreModelId INFERENCE_APIS = ["chat_completion"] FUNCTIONALITIES = ["streaming", "structured_output", "tool_calling"] diff --git a/llama_stack/providers/utils/inference/__init__.py b/llama_stack/providers/utils/inference/__init__.py index a885da235..e36be9404 100644 --- a/llama_stack/providers/utils/inference/__init__.py +++ b/llama_stack/providers/utils/inference/__init__.py @@ -6,8 +6,8 @@ from typing import List -from llama_stack.models.llama.datatypes import * # noqa: F403 from llama_stack.models.llama.sku_list import all_registered_models +from llama_stack.models.llama.sku_types import * # noqa: F403 def is_supported_safety_model(model: Model) -> bool: diff --git 
a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index e475d77b6..44a89dfb0 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -73,21 +73,21 @@ from llama_stack.apis.inference import ( CompletionMessage, CompletionResponse, CompletionResponseStreamChunk, + GreedySamplingStrategy, Message, + SamplingParams, SystemMessage, TokenLogProbs, ToolResponseMessage, + TopKSamplingStrategy, + TopPSamplingStrategy, UserMessage, ) from llama_stack.models.llama.datatypes import ( BuiltinTool, - GreedySamplingStrategy, - SamplingParams, StopReason, ToolCall, ToolDefinition, - TopKSamplingStrategy, - TopPSamplingStrategy, ) from llama_stack.providers.utils.inference.prompt_adapter import ( convert_image_content_to_url, diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 0231312cc..4f9c4927a 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -34,7 +34,6 @@ from llama_stack.apis.inference import ( ) from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ( - ModelFamily, RawContent, RawContentItem, RawMediaItem, @@ -43,7 +42,6 @@ from llama_stack.models.llama.datatypes import ( Role, StopReason, ToolPromptFormat, - is_multimodal, ) from llama_stack.models.llama.llama3.chat_format import ChatFormat from llama_stack.models.llama.llama3.prompt_templates import ( @@ -55,6 +53,7 @@ from llama_stack.models.llama.llama3.prompt_templates import ( ) from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.models.llama.sku_list import resolve_model +from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal from llama_stack.providers.utils.inference import supported_inference_models log = get_logger(name=__name__, category="inference") diff --git a/llama_stack/templates/dependencies.json b/llama_stack/templates/dependencies.json index 931240d37..b8f475cea 100644 --- a/llama_stack/templates/dependencies.json +++ b/llama_stack/templates/dependencies.json @@ -356,50 +356,7 @@ "fairscale", "faiss-cpu", "fastapi", - "fire", - "httpx", - "langdetect", - "lm-format-enforcer", - "matplotlib", - "mcp", - "nltk", - "numpy", - "openai", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pymongo", - "pypdf", - "pythainlp", - "redis", - "requests", - "scikit-learn", - "scipy", - "sentence-transformers", - "sentencepiece", - "torch", - "torchvision", - "tqdm", - "transformers", - "tree_sitter", - "uvicorn", - "zmq" - ], - "meta-reference-quantized-gpu": [ - "accelerate", - "aiosqlite", - "autoevals", - "blobfile", - "chardet", - "chromadb-client", - "datasets", - "emoji", - "fairscale", - "faiss-cpu", - "fastapi", - "fbgemm-gpu", + "fbgemm-gpu-genai==1.1.2", "fire", "httpx", "langdetect", diff --git a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml index 2cf49cc36..9f97158f8 100644 --- a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml +++ b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml @@ -18,6 +18,9 @@ providers: model: ${env.INFERENCE_MODEL} max_seq_len: 4096 checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null} + quantization: + type: ${env.QUANTIZATION_TYPE:bf16} + model_parallel_size: 
${env.MODEL_PARALLEL_SIZE:0} - provider_id: sentence-transformers provider_type: inline::sentence-transformers config: {} @@ -27,6 +30,9 @@ providers: model: ${env.SAFETY_MODEL} max_seq_len: 4096 checkpoint_dir: ${env.SAFETY_CHECKPOINT_DIR:null} + quantization: + type: ${env.QUANTIZATION_TYPE:bf16} + model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0} vector_io: - provider_id: faiss provider_type: inline::faiss diff --git a/llama_stack/templates/meta-reference-gpu/run.yaml b/llama_stack/templates/meta-reference-gpu/run.yaml index 964dfafeb..eda332123 100644 --- a/llama_stack/templates/meta-reference-gpu/run.yaml +++ b/llama_stack/templates/meta-reference-gpu/run.yaml @@ -18,6 +18,9 @@ providers: model: ${env.INFERENCE_MODEL} max_seq_len: 4096 checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null} + quantization: + type: ${env.QUANTIZATION_TYPE:bf16} + model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0} - provider_id: sentence-transformers provider_type: inline::sentence-transformers config: {} diff --git a/llama_stack/templates/meta-reference-quantized-gpu/build.yaml b/llama_stack/templates/meta-reference-quantized-gpu/build.yaml deleted file mode 100644 index 7bbcfe5f2..000000000 --- a/llama_stack/templates/meta-reference-quantized-gpu/build.yaml +++ /dev/null @@ -1,32 +0,0 @@ -version: '2' -distribution_spec: - description: Use Meta Reference with fp8, int4 quantization for running LLM inference - providers: - inference: - - inline::meta-reference-quantized - vector_io: - - inline::faiss - - remote::chromadb - - remote::pgvector - safety: - - inline::llama-guard - agents: - - inline::meta-reference - telemetry: - - inline::meta-reference - eval: - - inline::meta-reference - datasetio: - - remote::huggingface - - inline::localfs - scoring: - - inline::basic - - inline::llm-as-judge - - inline::braintrust - tool_runtime: - - remote::brave-search - - remote::tavily-search - - inline::code-interpreter - - inline::rag-runtime - - remote::model-context-protocol -image_type: conda diff --git a/llama_stack/templates/meta-reference-quantized-gpu/doc_template.md b/llama_stack/templates/meta-reference-quantized-gpu/doc_template.md deleted file mode 100644 index 1855da6c9..000000000 --- a/llama_stack/templates/meta-reference-quantized-gpu/doc_template.md +++ /dev/null @@ -1,113 +0,0 @@ ---- -orphan: true ---- -# Meta Reference Quantized Distribution - -```{toctree} -:maxdepth: 2 -:hidden: - -self -``` - -The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations: - -{{ providers_table }} - -The only difference vs. the `meta-reference-gpu` distribution is that it has support for more efficient inference -- with fp8, int4 quantization, etc. - -Note that you need access to nvidia GPUs to run this distribution. This distribution is not compatible with CPU-only machines or machines with AMD GPUs. - -{% if run_config_env_vars %} -### Environment Variables - -The following environment variables can be configured: - -{% for var, (default_value, description) in run_config_env_vars.items() %} -- `{{ var }}`: {{ description }} (default: `{{ default_value }}`) -{% endfor %} -{% endif %} - - -## Prerequisite: Downloading Models - -Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. 
Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints. - -``` -$ llama model list --downloaded -┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓ -┃ Model ┃ Size ┃ Modified Time ┃ -┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩ -│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │ -├─────────────────────────────────────────┼──────────┼─────────────────────┤ -│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │ -├─────────────────────────────────────────┼──────────┼─────────────────────┤ -│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │ -├─────────────────────────────────────────┼──────────┼─────────────────────┤ -│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │ -├─────────────────────────────────────────┼──────────┼─────────────────────┤ -│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │ -├─────────────────────────────────────────┼──────────┼─────────────────────┤ -│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │ -├─────────────────────────────────────────┼──────────┼─────────────────────┤ -│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │ -├─────────────────────────────────────────┼──────────┼─────────────────────┤ -│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │ -├─────────────────────────────────────────┼──────────┼─────────────────────┤ -│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │ -└─────────────────────────────────────────┴──────────┴─────────────────────┘ -``` - -## Running the Distribution - -You can do this via Conda (build code) or Docker which has a pre-built image. - -### Via Docker - -This method allows you to get started quickly without having to build the distribution code. - -```bash -LLAMA_STACK_PORT=8321 -docker run \ - -it \ - --pull always \ - -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ - -v ~/.llama:/root/.llama \ - llamastack/distribution-{{ name }} \ - --port $LLAMA_STACK_PORT \ - --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct -``` - -If you are using Llama Stack Safety / Shield APIs, use: - -```bash -docker run \ - -it \ - --pull always \ - -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ - -v ~/.llama:/root/.llama \ - llamastack/distribution-{{ name }} \ - --port $LLAMA_STACK_PORT \ - --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \ - --env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B -``` - -### Via Conda - -Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available. - -```bash -llama stack build --template {{ name }} --image-type conda -llama stack run distributions/{{ name }}/run.yaml \ - --port $LLAMA_STACK_PORT \ - --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct -``` - -If you are using Llama Stack Safety / Shield APIs, use: - -```bash -llama stack run distributions/{{ name }}/run-with-safety.yaml \ - --port $LLAMA_STACK_PORT \ - --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \ - --env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B -``` diff --git a/llama_stack/templates/meta-reference-quantized-gpu/meta_reference.py b/llama_stack/templates/meta-reference-quantized-gpu/meta_reference.py deleted file mode 100644 index c46ea8bc6..000000000 --- a/llama_stack/templates/meta-reference-quantized-gpu/meta_reference.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from pathlib import Path
-
-from llama_stack.apis.models.models import ModelType
-from llama_stack.distribution.datatypes import ModelInput, Provider, ToolGroupInput
-from llama_stack.providers.inline.inference.meta_reference import (
-    MetaReferenceQuantizedInferenceConfig,
-)
-from llama_stack.providers.inline.inference.sentence_transformers import (
-    SentenceTransformersInferenceConfig,
-)
-from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
-from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
-
-
-def get_distribution_template() -> DistributionTemplate:
-    providers = {
-        "inference": ["inline::meta-reference-quantized"],
-        "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
-        "safety": ["inline::llama-guard"],
-        "agents": ["inline::meta-reference"],
-        "telemetry": ["inline::meta-reference"],
-        "eval": ["inline::meta-reference"],
-        "datasetio": ["remote::huggingface", "inline::localfs"],
-        "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
-        "tool_runtime": [
-            "remote::brave-search",
-            "remote::tavily-search",
-            "inline::code-interpreter",
-            "inline::rag-runtime",
-            "remote::model-context-protocol",
-        ],
-    }
-    default_tool_groups = [
-        ToolGroupInput(
-            toolgroup_id="builtin::websearch",
-            provider_id="tavily-search",
-        ),
-        ToolGroupInput(
-            toolgroup_id="builtin::rag",
-            provider_id="rag-runtime",
-        ),
-        ToolGroupInput(
-            toolgroup_id="builtin::code_interpreter",
-            provider_id="code-interpreter",
-        ),
-    ]
-    name = "meta-reference-quantized-gpu"
-    inference_provider = Provider(
-        provider_id="meta-reference-inference",
-        provider_type="inline::meta-reference-quantized",
-        config=MetaReferenceQuantizedInferenceConfig.sample_run_config(
-            model="${env.INFERENCE_MODEL}",
-            checkpoint_dir="${env.INFERENCE_CHECKPOINT_DIR:null}",
-        ),
-    )
-    embedding_provider = Provider(
-        provider_id="sentence-transformers",
-        provider_type="inline::sentence-transformers",
-        config=SentenceTransformersInferenceConfig.sample_run_config(),
-    )
-    vector_io_provider = Provider(
-        provider_id="faiss",
-        provider_type="inline::faiss",
-        config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
-    )
-
-    inference_model = ModelInput(
-        model_id="${env.INFERENCE_MODEL}",
-        provider_id="meta-reference-inference",
-    )
-    embedding_model = ModelInput(
-        model_id="all-MiniLM-L6-v2",
-        provider_id="sentence-transformers",
-        model_type=ModelType.embedding,
-        metadata={
-            "embedding_dimension": 384,
-        },
-    )
-    return DistributionTemplate(
-        name=name,
-        distro_type="self_hosted",
-        description="Use Meta Reference with fp8, int4 quantization for running LLM inference",
-        template_path=Path(__file__).parent / "doc_template.md",
-        providers=providers,
-        run_configs={
-            "run.yaml": RunConfigSettings(
-                provider_overrides={
-                    "inference": [inference_provider, embedding_provider],
-                    "vector_io": [vector_io_provider],
-                },
-                default_models=[inference_model, embedding_model],
-                default_tool_groups=default_tool_groups,
-            ),
-        },
-        run_config_env_vars={
-            "LLAMA_STACK_PORT": (
-                "8321",
-                "Port for the Llama Stack distribution server",
-            ),
-            "INFERENCE_MODEL": (
-                "meta-llama/Llama-3.2-3B-Instruct",
-                "Inference model loaded into the Meta Reference server",
-            ),
-            "INFERENCE_CHECKPOINT_DIR": (
-                "null",
-                "Directory containing the Meta Reference model checkpoint",
-            ),
-        },
-    )
diff --git a/llama_stack/templates/meta-reference-quantized-gpu/run.yaml b/llama_stack/templates/meta-reference-quantized-gpu/run.yaml
deleted file mode 100644
index f934ecfbb..000000000
--- a/llama_stack/templates/meta-reference-quantized-gpu/run.yaml
+++ /dev/null
@@ -1,134 +0,0 @@
-version: '2'
-image_name: meta-reference-quantized-gpu
-apis:
-- agents
-- datasetio
-- eval
-- inference
-- safety
-- scoring
-- telemetry
-- tool_runtime
-- vector_io
-providers:
-  inference:
-  - provider_id: meta-reference-inference
-    provider_type: inline::meta-reference-quantized
-    config:
-      model: ${env.INFERENCE_MODEL}
-      max_seq_len: 4096
-      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
-      quantization:
-        type: fp8
-  - provider_id: sentence-transformers
-    provider_type: inline::sentence-transformers
-    config: {}
-  vector_io:
-  - provider_id: faiss
-    provider_type: inline::faiss
-    config:
-      kvstore:
-        type: sqlite
-        namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-quantized-gpu}/faiss_store.db
-  safety:
-  - provider_id: llama-guard
-    provider_type: inline::llama-guard
-    config:
-      excluded_categories: []
-  agents:
-  - provider_id: meta-reference
-    provider_type: inline::meta-reference
-    config:
-      persistence_store:
-        type: sqlite
-        namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-quantized-gpu}/agents_store.db
-  telemetry:
-  - provider_id: meta-reference
-    provider_type: inline::meta-reference
-    config:
-      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
-      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
-      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/meta-reference-quantized-gpu/trace_store.db}
-  eval:
-  - provider_id: meta-reference
-    provider_type: inline::meta-reference
-    config:
-      kvstore:
-        type: sqlite
-        namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-quantized-gpu}/meta_reference_eval.db
-  datasetio:
-  - provider_id: huggingface
-    provider_type: remote::huggingface
-    config:
-      kvstore:
-        type: sqlite
-        namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-quantized-gpu}/huggingface_datasetio.db
-  - provider_id: localfs
-    provider_type: inline::localfs
-    config:
-      kvstore:
-        type: sqlite
-        namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-quantized-gpu}/localfs_datasetio.db
-  scoring:
-  - provider_id: basic
-    provider_type: inline::basic
-    config: {}
-  - provider_id: llm-as-judge
-    provider_type: inline::llm-as-judge
-    config: {}
-  - provider_id: braintrust
-    provider_type: inline::braintrust
-    config:
-      openai_api_key: ${env.OPENAI_API_KEY:}
-  tool_runtime:
-  - provider_id: brave-search
-    provider_type: remote::brave-search
-    config:
-      api_key: ${env.BRAVE_SEARCH_API_KEY:}
-      max_results: 3
-  - provider_id: tavily-search
-    provider_type: remote::tavily-search
-    config:
-      api_key: ${env.TAVILY_SEARCH_API_KEY:}
-      max_results: 3
-  - provider_id: code-interpreter
-    provider_type: inline::code-interpreter
-    config: {}
-  - provider_id: rag-runtime
-    provider_type: inline::rag-runtime
-    config: {}
-  - provider_id: model-context-protocol
-    provider_type: remote::model-context-protocol
-    config: {}
-metadata_store:
-  type: sqlite
-  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-quantized-gpu}/registry.db
-models:
-- metadata: {}
-  model_id: ${env.INFERENCE_MODEL}
-  provider_id: meta-reference-inference
-  model_type: llm
-- metadata:
-    embedding_dimension: 384
-  model_id: all-MiniLM-L6-v2
-  provider_id: sentence-transformers
-  model_type: embedding
-shields: []
-vector_dbs: []
-datasets: []
-scoring_fns: []
-benchmarks: []
-tool_groups:
-- toolgroup_id: builtin::websearch
-  provider_id: tavily-search
-- toolgroup_id: builtin::rag
-  provider_id: rag-runtime
-- toolgroup_id: builtin::code_interpreter
-  provider_id: code-interpreter
-server:
-  port: 8321
diff --git a/pyproject.toml b/pyproject.toml
index 8d8ff4338..8ae7ddbb6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -224,9 +224,9 @@ exclude = [
     "^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
     "^llama_stack/providers/inline/inference/meta_reference/config\\.py$",
     "^llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
-    "^llama_stack/providers/inline/inference/meta_reference/llama3/generation\\.py$",
-    "^llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/model\\.py$",
-    "^llama_stack/providers/inline/inference/meta_reference/llama4/",
+    "^llama_stack/models/llama/llama3/generation\\.py$",
+    "^llama_stack/models/llama/llama3/multimodal/model\\.py$",
+    "^llama_stack/models/llama/llama4/",
     "^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$",
     "^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$",
     "^llama_stack/providers/inline/inference/meta_reference/quantization/loader\\.py$",
diff --git a/scripts/generate_prompt_format.py b/scripts/generate_prompt_format.py
index 08c5bea22..5598e35f6 100755
--- a/scripts/generate_prompt_format.py
+++ b/scripts/generate_prompt_format.py
@@ -5,13 +5,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
-
 # Run this script:
 # torchrun --nproc_per_node=8 scripts/generate_prompt_format.py meta-llama/Llama-4-17B-Omni-Instruct-BF16-16E ~/.llama/checkpoints/Llama-4-17B-Omni-Instruct-BF16-16E/ llama_stack.models.llama.llama4.prompts llama_stack/models/llama/llama4/prompt_format.md
 
@@ -22,16 +15,9 @@ from pathlib import Path
 
 import fire
 
+from llama_stack.models.llama.llama3.generation import Llama3
+from llama_stack.models.llama.llama4.generation import Llama4
 from llama_stack.models.llama.sku_list import resolve_model
-from llama_stack.providers.inline.inference.meta_reference.config import (
-    MetaReferenceInferenceConfig,
-)
-from llama_stack.providers.inline.inference.meta_reference.llama3.generation import (
-    Llama3,
-)
-from llama_stack.providers.inline.inference.meta_reference.llama4.generation import (
-    Llama4,
-)
 
 THIS_DIR = Path(__file__).parent.resolve()
 
@@ -50,24 +36,12 @@ def run_main(
     if not llama_model:
         raise ValueError(f"Model {model_id} not found")
 
-    if not llama4:
-        config = MetaReferenceInferenceConfig(
-            model=model_id,
-            max_seq_len=4096,
-            max_batch_size=1,
-            checkpoint_dir=checkpoint_dir,
-        )
-        generator = Llama3.build(
-            config=config,
-            model_id=model_id,
-            llama_model=llama_model,
-        )
-    else:
-        generator = Llama4.build(
-            ckpt_dir=checkpoint_dir,
-            max_seq_len=4096,
-            max_batch_size=1,
-        )
+    cls = Llama4 if llama4 else Llama3
+    generator = cls.build(
+        ckpt_dir=checkpoint_dir,
+        max_seq_len=4096,
+        max_batch_size=1,
+    )
 
     use_cases = module.usecases()
     text = ""
diff --git a/tests/integration/report.py b/tests/integration/report.py
index c07338ce6..a50f51d3f 100644
--- a/tests/integration/report.py
+++ b/tests/integration/report.py
@@ -11,7 +11,6 @@ import pytest
 from pytest import CollectReport
 from termcolor import cprint
 
-from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.models.llama.sku_list import (
     all_registered_models,
     llama3_1_instruct_models,
@@ -20,6 +19,7 @@ from llama_stack.models.llama.sku_list import (
     llama3_instruct_models,
     safety_models,
 )
+from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.datatypes import Api
 
 from .metadata import API_MAPS
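
For downstream callers, the relevant takeaway from the `scripts/generate_prompt_format.py` hunk above is that `Llama3` and `Llama4` now live under `llama_stack.models.llama` and share a common `build()` entry point, so a bare generator can be stood up without constructing a `MetaReferenceInferenceConfig`. A minimal sketch of that usage (the `build_generator` wrapper, checkpoint path, and `use_llama4` flag are illustrative placeholders, not part of this patch):

```python
# Sketch only: mirrors the import paths and build() kwargs used by the
# updated scripts/generate_prompt_format.py above. ckpt_dir and use_llama4
# are hypothetical inputs supplied by the caller.
from llama_stack.models.llama.llama3.generation import Llama3
from llama_stack.models.llama.llama4.generation import Llama4


def build_generator(ckpt_dir: str, use_llama4: bool = True):
    # Both classes expose the same build() signature after the move,
    # so the caller only has to pick which class to use.
    cls = Llama4 if use_llama4 else Llama3
    return cls.build(
        ckpt_dir=ckpt_dir,
        max_seq_len=4096,
        max_batch_size=1,
    )
```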