diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 567110829..f94f2a578 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -4163,80 +4163,70 @@ ] }, "arguments": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "integer" - }, - { - "type": "number" - }, - { - "type": "boolean" - }, - { - "type": "null" - }, - { - "type": "array", - "items": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "integer" - }, - { - "type": "number" - }, - { - "type": "boolean" - }, - { - "type": "null" - } - ] + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "integer" + }, + { + "type": "number" + }, + { + "type": "boolean" + }, + { + "type": "null" + }, + { + "type": "array", + "items": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "integer" + }, + { + "type": "number" + }, + { + "type": "boolean" + }, + { + "type": "null" } - }, - { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "integer" - }, - { - "type": "number" - }, - { - "type": "boolean" - }, - { - "type": "null" - } - ] + ] + } + }, + { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "integer" + }, + { + "type": "number" + }, + { + "type": "boolean" + }, + { + "type": "null" } - } - ] + ] + } } - } - ] - }, - "arguments_json": { - "type": "string" + ] + } } }, "additionalProperties": false, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 1dfd17f55..238f8dcd0 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -2890,34 +2890,30 @@ components: title: BuiltinTool - type: string arguments: - oneOf: - - type: string - - type: object - additionalProperties: - oneOf: - - type: string - - type: integer - - type: number - - type: boolean - - type: 'null' - - type: array - items: - oneOf: - - type: string - - type: integer - - type: number - - type: boolean - - type: 'null' - - type: object - additionalProperties: - oneOf: - - type: string - - type: integer - - type: number - - type: boolean - - type: 'null' - arguments_json: - type: string + type: object + additionalProperties: + oneOf: + - type: string + - type: integer + - type: number + - type: boolean + - type: 'null' + - type: array + items: + oneOf: + - type: string + - type: integer + - type: number + - type: boolean + - type: 'null' + - type: object + additionalProperties: + oneOf: + - type: string + - type: integer + - type: number + - type: boolean + - type: 'null' additionalProperties: false required: - call_id diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 1d4012c19..216935ede 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -25,15 +25,64 @@ from llama_stack.apis.models import Model from llama_stack.apis.telemetry.telemetry import MetricResponseMixin from llama_stack.models.llama.datatypes import ( BuiltinTool, - SamplingParams, StopReason, ToolCall, ToolDefinition, + ToolParamDefinition, ToolPromptFormat, ) from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, register_schema, webmethod +register_schema(ToolCall) +register_schema(ToolParamDefinition) 
+register_schema(ToolDefinition) + + +@json_schema_type +class GreedySamplingStrategy(BaseModel): + type: Literal["greedy"] = "greedy" + + +@json_schema_type +class TopPSamplingStrategy(BaseModel): + type: Literal["top_p"] = "top_p" + temperature: Optional[float] = Field(..., gt=0.0) + top_p: Optional[float] = 0.95 + + +@json_schema_type +class TopKSamplingStrategy(BaseModel): + type: Literal["top_k"] = "top_k" + top_k: int = Field(..., ge=1) + + +SamplingStrategy = Annotated[ + Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy], + Field(discriminator="type"), +] +register_schema(SamplingStrategy, name="SamplingStrategy") + + +@json_schema_type +class SamplingParams(BaseModel): + """Sampling parameters. + + :param strategy: The sampling strategy. + :param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of + your prompt plus max_tokens cannot exceed the model's context length. + :param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens + based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. + :param stop: Up to 4 sequences where the API will stop generating further tokens. + The returned text will not contain the stop sequence. + """ + + strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy) + + max_tokens: Optional[int] = 0 + repetition_penalty: Optional[float] = 1.0 + stop: Optional[List[str]] = None + class LogProbConfig(BaseModel): """ diff --git a/llama_stack/cli/download.py b/llama_stack/cli/download.py index fc3e7008f..9694bf22d 100644 --- a/llama_stack/cli/download.py +++ b/llama_stack/cli/download.py @@ -29,8 +29,8 @@ from rich.progress import ( from termcolor import cprint from llama_stack.cli.subcommand import Subcommand -from llama_stack.models.llama.datatypes import Model from llama_stack.models.llama.sku_list import LlamaDownloadInfo +from llama_stack.models.llama.sku_types import Model class Download(Subcommand): diff --git a/llama_stack/cli/model/describe.py b/llama_stack/cli/model/describe.py index f347bdf8d..62dde36e8 100644 --- a/llama_stack/cli/model/describe.py +++ b/llama_stack/cli/model/describe.py @@ -63,17 +63,6 @@ class ModelDescribe(Subcommand): ("Model params.json", json.dumps(model.arch_args, indent=4)), ] - if model.recommended_sampling_params is not None: - sampling_params = model.recommended_sampling_params.model_dump() - for k in ("max_tokens", "repetition_penalty"): - del sampling_params[k] - rows.append( - ( - "Recommended sampling params", - json.dumps(sampling_params, indent=4), - ) - ) - print_table( rows, headers, diff --git a/llama_stack/cli/model/prompt_format.py b/llama_stack/cli/model/prompt_format.py index 3ce77655b..673487812 100644 --- a/llama_stack/cli/model/prompt_format.py +++ b/llama_stack/cli/model/prompt_format.py @@ -11,7 +11,7 @@ from pathlib import Path from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.table import print_table -from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family +from llama_stack.models.llama.sku_types import CoreModelId, ModelFamily, is_multimodal, model_family ROOT_DIR = Path(__file__).parent.parent.parent diff --git a/llama_stack/cli/model/safety_models.py b/llama_stack/cli/model/safety_models.py index c81783f60..131d055aa 100644 --- a/llama_stack/cli/model/safety_models.py +++ b/llama_stack/cli/model/safety_models.py @@ -4,12 +4,12 @@ # This source code is licensed 
under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, Optional +from typing import Any, Dict from pydantic import BaseModel, ConfigDict, Field -from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat, SamplingParams from llama_stack.models.llama.sku_list import LlamaDownloadInfo +from llama_stack.models.llama.sku_types import CheckpointQuantizationFormat class PromptGuardModel(BaseModel): @@ -23,7 +23,6 @@ class PromptGuardModel(BaseModel): is_instruct_model: bool = False quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16 arch_args: Dict[str, Any] = Field(default_factory=dict) - recommended_sampling_params: Optional[SamplingParams] = None def descriptor(self) -> str: return self.model_id diff --git a/llama_stack/models/llama/checkpoint.py b/llama_stack/models/llama/checkpoint.py new file mode 100644 index 000000000..2bae08a69 --- /dev/null +++ b/llama_stack/models/llama/checkpoint.py @@ -0,0 +1,164 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import concurrent.futures +import re +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import torch +from fairscale.nn.model_parallel.initialize import get_model_parallel_rank, get_model_parallel_world_size + + +def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[int]: + """Map a new MP rank to a list of old MP ranks given a change in MP size.""" + if new_mp_size % old_mp_size == 0: + # Read old MP shard and split it into smaller ones + return [new_mp_rank * old_mp_size // new_mp_size] + elif old_mp_size % new_mp_size == 0: + # Merge old MP shards into a single one + mp_factor = old_mp_size // new_mp_size + return list(range(new_mp_rank * mp_factor, (new_mp_rank + 1) * mp_factor)) + else: + raise ValueError( + f"Either old MP size or new MP size should be a multiple of the other: " + f"{old_mp_size} % {new_mp_size} != 0 and {new_mp_size} % {old_mp_size} != 0" + ) + + +def maybe_reshard_state_dict( + ckpt_paths: List[Path], + n_kv_heads: int, + moe_num_experts: Optional[int] = None, + map_location: Union[str, torch.device] = "cpu", + mmap: bool = True, +) -> Dict[str, torch.Tensor]: + if str(map_location) == "cpu": + torch.set_default_tensor_type(torch.BFloat16Tensor) + else: + torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) + + ckpt_paths = np.array(sorted(ckpt_paths)) + + new_mp_size, new_mp_rank = get_model_parallel_world_size(), get_model_parallel_rank() + old_mp_size = len(ckpt_paths) + old_mp_ranks = map_mp_rank(old_mp_size, new_mp_size, new_mp_rank) + + print(f"Loading checkpoint shards:\n{str(ckpt_paths[old_mp_ranks])}") # type: ignore + paths = ckpt_paths[old_mp_ranks] # type: ignore + state_dicts = [torch.load(str(p), map_location=map_location, mmap=mmap) for p in paths] + + if new_mp_size == old_mp_size: + return state_dicts[0] # type: ignore + + if moe_num_experts is not None: + state_dicts = [convert_moe_weights(d, moe_num_experts) for d in state_dicts] + + print(f"Resharding {len(state_dicts)} state dicts from MP size {old_mp_size} to MP size {new_mp_size}") + return reshard_mp( + state_dicts, + size=max(new_mp_size // old_mp_size, 1), + rank=new_mp_rank % max(new_mp_size // old_mp_size, 1), + repeat_qk_qv=max(new_mp_size // n_kv_heads, 1), + ) + + +_WEIGHT_ROW_KEY = { 
+ "feed_forward.w2", + "feed_forward.mlp.fc2", + "attention.wo", + "feed_forward.mlp.fc2_weight", + "feed_forward.w_out_shared_DF.weight", + "attn.wo.weight", + "mlp.c_proj.weight", +} +_MOE_WEIGHT_ROW_KEY = {"feed_forward.experts.(moe_w_in_eD_F|moe_w_swiglu_eD_F)"} + +_WEIGHT_COLUMN_KEY = { + "output", + "feed_forward.(w1|w3)", + "feed_forward.mlp.(fc1|fc3)", + "feed_forward.mlp.fc1_weight", + "attention.(wk|wq|wv|wqkv).weight", + "feed_forward.(w_in_shared_FD|w_swiglu_FD)", + "attn.(wk|wq|wv).weight", + "attn.(wk|wq|wv).bias", + "mlp.c_fc.weight", + "mlp.c_fc.bias", + "conv1._linear.weight", + "tok_embeddings.weight", + "vision_projection.weight", +} +_MOE_WEIGHT_COLUMN_KEY = {"feed_forward.experts.moe_w_out_eF_D"} + + +def reshard_mp( + state_dicts: List[Dict[str, torch.Tensor]], + size: int, + rank: int, + repeat_qk_qv: int = 1, +) -> Dict[str, torch.Tensor]: + """ + Reshard a list of state dicts into a single state dict given a change in MP size. + If the list has more than one state dict, we concatenate the values of the same + key across all state dicts. Otherwise, we just slice it for the current MP rank. + """ + + def concat_or_chunk(tensors: List[torch.Tensor], dim: int) -> torch.Tensor: + if len(tensors) > 1: + return torch.cat(tensors, dim=dim) + return tensors[0].chunk(size, dim=dim)[rank].clone() + + def process_key(key: str) -> torch.Tensor: + if row_regex.search(key): + return concat_or_chunk([s[key] for s in state_dicts], dim=-1) + elif column_regex.search(key): + if "w13" in key or "fc1_weight" in key: + dims = state_dicts[0][key].size() + values = [s[key].view(2, dims[0] // 2, *dims[1:]) for s in state_dicts] + return concat_or_chunk(values, dim=1).flatten(0, 1) + elif "qkv" in key: + q_dim = state_dicts[0][key.replace("qkv", "o")].size(1) + kv_dim = (state_dicts[0][key].size(0) - q_dim) // 2 + values = [s[key].split((q_dim, kv_dim, kv_dim)) for s in state_dicts] + return torch.cat([concat_or_chunk(x, dim=0) for x in zip(*values, strict=False)]) # type: ignore + elif "wk.weight" in key or "wv.weight" in key: + # Support MP > #kv_head + return concat_or_chunk([s[key].repeat(repeat_qk_qv, 1) for s in state_dicts], dim=0) + elif key == "output.bias" or key == "fc.weight": + return concat_or_chunk([s[key] for s in state_dicts], dim=0) + elif "w_" in key: + return concat_or_chunk([s[key] for s in state_dicts], dim=-2) + else: + return concat_or_chunk([s[key] for s in state_dicts], dim=0) + else: + return state_dicts[0][key].clone() + + row_keys = _WEIGHT_ROW_KEY | _MOE_WEIGHT_ROW_KEY + column_keys = _WEIGHT_COLUMN_KEY | _MOE_WEIGHT_COLUMN_KEY + + column_regex = re.compile("|".join(column_keys)) + row_regex = re.compile("|".join(row_keys)) + + output: Dict[str, torch.Tensor] = {} + with concurrent.futures.ThreadPoolExecutor() as executor: + # Note: only processes keys in the first state dict. + # Assumes keys are the same across all state dicts. 
+ mappings = {executor.submit(process_key, key): key for key in state_dicts[0]} + for future in concurrent.futures.as_completed(mappings): + output[mappings[future]] = future.result() + return output + + +def convert_moe_weights(state_dict: Dict[str, Any], num_experts: int) -> Dict[str, Any]: + routed_keys = _MOE_WEIGHT_ROW_KEY | _MOE_WEIGHT_COLUMN_KEY + routed_regex = re.compile("|".join(routed_keys)) + keys = list(state_dict.keys()) + for key in keys: + if routed_regex.search(key): + state_dict[key] = state_dict.pop(key).unflatten(0, (num_experts, -1)).squeeze(dim=0) + return state_dict diff --git a/llama_stack/models/llama/datatypes.py b/llama_stack/models/llama/datatypes.py index ef791da8f..106875bb2 100644 --- a/llama_stack/models/llama/datatypes.py +++ b/llama_stack/models/llama/datatypes.py @@ -4,13 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. - import base64 from enum import Enum from io import BytesIO @@ -19,8 +12,6 @@ from typing import Any, Dict, List, Literal, Optional, Union from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator from typing_extensions import Annotated -from llama_stack.schema_utils import json_schema_type, register_schema - # The goal is that these set of types are relevant for all Llama models. # That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to # the llama3 series of models. @@ -47,14 +38,7 @@ RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]] class ToolCall(BaseModel): call_id: str tool_name: Union[BuiltinTool, str] - # Plan is to deprecate the Dict in favor of a JSON string - # that is parsed on the client side instead of trying to manage - # the recursive type here. - # Making this a union so that client side can start prepping for this change. - # Eventually, we will remove both the Dict and arguments_json field, - # and arguments will just be a str - arguments: Union[str, Dict[str, RecursiveType]] - arguments_json: Optional[str] = None + arguments: Dict[str, RecursiveType] @field_validator("tool_name", mode="before") @classmethod @@ -98,6 +82,29 @@ class StopReason(Enum): out_of_tokens = "out_of_tokens" +class ToolParamDefinition(BaseModel): + param_type: str + description: Optional[str] = None + required: Optional[bool] = True + default: Optional[Any] = None + + +class ToolDefinition(BaseModel): + tool_name: Union[BuiltinTool, str] + description: Optional[str] = None + parameters: Optional[Dict[str, ToolParamDefinition]] = None + + @field_validator("tool_name", mode="before") + @classmethod + def validate_field(cls, v): + if isinstance(v, str): + try: + return BuiltinTool(v) + except ValueError: + return v + return v + + class RawMediaItem(BaseModel): type: Literal["image"] = "image" data: bytes | BytesIO @@ -140,292 +147,25 @@ class RawMessage(BaseModel): tool_calls: List[ToolCall] = Field(default_factory=list) -register_schema(ToolCall) +class GenerationResult(BaseModel): + token: int + text: str + logprobs: Optional[List[float]] = None + + source: Literal["input"] | Literal["output"] + + # index within the batch + batch_idx: int + # whether generation for this item is already finished. 
note that tokens can + # get returned even afterwards since other items in the batch can still be generating tokens + finished: bool + # because a batch is parallel processed, useful decoding for one item can correspond to processing + # pad tokens or tokens beyond EOS for other items. we could have decided to return None for this case + # but it's more convenient to return a list of GenerationResult and filter out the ignored tokens + ignore_token: bool -@json_schema_type -class ToolParamDefinition(BaseModel): - param_type: str - description: Optional[str] = None - required: Optional[bool] = True - default: Optional[Any] = None - - -@json_schema_type -class ToolDefinition(BaseModel): - tool_name: Union[BuiltinTool, str] - description: Optional[str] = None - parameters: Optional[Dict[str, ToolParamDefinition]] = None - - @field_validator("tool_name", mode="before") - @classmethod - def validate_field(cls, v): - if isinstance(v, str): - try: - return BuiltinTool(v) - except ValueError: - return v - return v - - -@json_schema_type -class GreedySamplingStrategy(BaseModel): - type: Literal["greedy"] = "greedy" - - -@json_schema_type -class TopPSamplingStrategy(BaseModel): - type: Literal["top_p"] = "top_p" - temperature: Optional[float] = Field(..., gt=0.0) - top_p: Optional[float] = 0.95 - - -@json_schema_type -class TopKSamplingStrategy(BaseModel): - type: Literal["top_k"] = "top_k" - top_k: int = Field(..., ge=1) - - -SamplingStrategy = Annotated[ - Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy], - Field(discriminator="type"), -] -register_schema(SamplingStrategy, name="SamplingStrategy") - - -@json_schema_type -class SamplingParams(BaseModel): - """Sampling parameters. - - :param strategy: The sampling strategy. - :param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of - your prompt plus max_tokens cannot exceed the model's context length. - :param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens - based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. - :param stop: Up to 4 sequences where the API will stop generating further tokens. - The returned text will not contain the stop sequence. - """ - - strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy) - - max_tokens: Optional[int] = 0 - repetition_penalty: Optional[float] = 1.0 - stop: Optional[List[str]] = None - - -class CheckpointQuantizationFormat(Enum): - # default format - bf16 = "bf16" - - # used for enabling fp8_rowwise inference, some weights are bf16 - fp8_mixed = "fp8-mixed" - - int8 = "int8" - - int4 = "int4" - - -class ModelFamily(Enum): - llama2 = "llama2" - llama3 = "llama3" - llama3_1 = "llama3_1" - llama3_2 = "llama3_2" - llama3_3 = "llama3_3" - llama4 = "llama4" - safety = "safety" - - -class CoreModelId(Enum): - """Each of these models is a unique "SKU". 
These root models can be served in various garbs (especially by quantizing them)""" - - # Llama 2 family - llama2_7b = "Llama-2-7b" - llama2_13b = "Llama-2-13b" - llama2_70b = "Llama-2-70b" - llama2_7b_chat = "Llama-2-7b-chat" - llama2_13b_chat = "Llama-2-13b-chat" - llama2_70b_chat = "Llama-2-70b-chat" - - # Llama 3 family - llama3_8b = "Llama-3-8B" - llama3_70b = "Llama-3-70B" - llama3_8b_instruct = "Llama-3-8B-Instruct" - llama3_70b_instruct = "Llama-3-70B-Instruct" - - # Llama 3.1 family - llama3_1_8b = "Llama3.1-8B" - llama3_1_70b = "Llama3.1-70B" - llama3_1_405b = "Llama3.1-405B" - llama3_1_8b_instruct = "Llama3.1-8B-Instruct" - llama3_1_70b_instruct = "Llama3.1-70B-Instruct" - llama3_1_405b_instruct = "Llama3.1-405B-Instruct" - - # Llama 3.2 family - llama3_2_1b = "Llama3.2-1B" - llama3_2_3b = "Llama3.2-3B" - llama3_2_1b_instruct = "Llama3.2-1B-Instruct" - llama3_2_3b_instruct = "Llama3.2-3B-Instruct" - llama3_2_11b_vision = "Llama3.2-11B-Vision" - llama3_2_90b_vision = "Llama3.2-90B-Vision" - llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct" - llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct" - - # Llama 3.3 family - llama3_3_70b_instruct = "Llama3.3-70B-Instruct" - - # Llama 4 family - llama4_scout_17b_16e = "Llama-4-Scout-17B-16E" - llama4_scout_17b_16e_instruct = "Llama-4-Scout-17B-16E-Instruct" - llama4_maverick_17b_128e = "Llama-4-Maverick-17B-128E" - llama4_maverick_17b_128e_instruct = "Llama-4-Maverick-17B-128E-Instruct" - - # Safety models - llama_guard_3_8b = "Llama-Guard-3-8B" - llama_guard_2_8b = "Llama-Guard-2-8B" - llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision" - llama_guard_3_1b = "Llama-Guard-3-1B" - - -def is_multimodal(model_id) -> bool: - if model_id in [ - CoreModelId.llama3_2_11b_vision, - CoreModelId.llama3_2_90b_vision, - CoreModelId.llama3_2_11b_vision_instruct, - CoreModelId.llama3_2_90b_vision_instruct, - ]: - return True - else: - return False - - -def model_family(model_id) -> ModelFamily: - if model_id in [ - CoreModelId.llama2_7b, - CoreModelId.llama2_13b, - CoreModelId.llama2_70b, - CoreModelId.llama2_7b_chat, - CoreModelId.llama2_13b_chat, - CoreModelId.llama2_70b_chat, - ]: - return ModelFamily.llama2 - elif model_id in [ - CoreModelId.llama3_8b, - CoreModelId.llama3_70b, - CoreModelId.llama3_8b_instruct, - CoreModelId.llama3_70b_instruct, - ]: - return ModelFamily.llama3 - elif model_id in [ - CoreModelId.llama3_1_8b, - CoreModelId.llama3_1_70b, - CoreModelId.llama3_1_405b, - CoreModelId.llama3_1_8b_instruct, - CoreModelId.llama3_1_70b_instruct, - CoreModelId.llama3_1_405b_instruct, - ]: - return ModelFamily.llama3_1 - elif model_id in [ - CoreModelId.llama3_2_1b, - CoreModelId.llama3_2_3b, - CoreModelId.llama3_2_1b_instruct, - CoreModelId.llama3_2_3b_instruct, - CoreModelId.llama3_2_11b_vision, - CoreModelId.llama3_2_90b_vision, - CoreModelId.llama3_2_11b_vision_instruct, - CoreModelId.llama3_2_90b_vision_instruct, - ]: - return ModelFamily.llama3_2 - elif model_id in [ - CoreModelId.llama3_3_70b_instruct, - ]: - return ModelFamily.llama3_3 - elif model_id in [ - CoreModelId.llama4_scout_17b_16e, - CoreModelId.llama4_scout_17b_16e_instruct, - CoreModelId.llama4_maverick_17b_128e, - CoreModelId.llama4_maverick_17b_128e_instruct, - ]: - return ModelFamily.llama4 - elif model_id in [ - CoreModelId.llama_guard_3_8b, - CoreModelId.llama_guard_2_8b, - CoreModelId.llama_guard_3_11b_vision, - CoreModelId.llama_guard_3_1b, - ]: - return ModelFamily.safety - else: - raise ValueError(f"Unknown model family for {model_id}") - - 
-class Model(BaseModel): - core_model_id: CoreModelId - description: str - huggingface_repo: Optional[str] = None - recommended_sampling_params: Optional[SamplingParams] = None - arch_args: Dict[str, Any] - variant: str = "" - - quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16 - pth_file_count: int - metadata: Optional[Dict[str, Any]] = Field(default_factory=dict) - - # silence pydantic until we remove the `model_` fields - model_config = ConfigDict(protected_namespaces=()) - - @property - def model_family(self) -> ModelFamily: - return model_family(self.core_model_id) - - # The SKU is uniquely identified by (model_id, variant) combo - def descriptor(self, shorten_default_variant: bool = True) -> str: - if not self.variant: - return self.core_model_id.value - return f"{self.core_model_id.value}:{self.variant}" - - @property - def is_instruct_model(self) -> bool: - return "instruct" in self.id.name - - # Featured models are shown in the non-exhaustive model list - @property - def is_featured(self) -> bool: - return self.model_family in [ - ModelFamily.llama3_1, - ModelFamily.llama3_2, - ModelFamily.llama3_3, - ModelFamily.llama4, - ModelFamily.safety, - ] - - @property - def max_seq_length(self) -> int: - if self.model_family == ModelFamily.llama2: - return 4096 - elif self.core_model_id == CoreModelId.llama_guard_2_8b: - return 4096 - elif self.model_family == ModelFamily.llama3: - return 8192 - elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]: - return 131072 - elif self.model_family == ModelFamily.llama3_2: - if self.quantization_format == CheckpointQuantizationFormat.int4: - return 8192 - return 131072 - elif self.model_family == ModelFamily.llama4: - if self.core_model_id in { - CoreModelId.llama4_scout_17b_16e, - CoreModelId.llama4_maverick_17b_128e, - }: - return 262144 - if self.core_model_id == CoreModelId.llama4_scout_17b_16e_instruct: - return 10485760 - if self.core_model_id == CoreModelId.llama4_maverick_17b_128e_instruct: - return 1048576 - elif self.core_model_id in [ - CoreModelId.llama_guard_3_8b, - CoreModelId.llama_guard_3_11b_vision, - CoreModelId.llama_guard_3_1b, - ]: - return 131072 - else: - raise ValueError(f"Unknown max_seq_len for {self.core_model_id}") +class QuantizationMode(str, Enum): + none = "none" + fp8_mixed = "fp8_mixed" + int4_mixed = "int4_mixed" diff --git a/llama_stack/models/llama/llama3/args.py b/llama_stack/models/llama/llama3/args.py index e96eaca61..f7e4b4557 100644 --- a/llama_stack/models/llama/llama3/args.py +++ b/llama_stack/models/llama/llama3/args.py @@ -4,13 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. - from dataclasses import dataclass from enum import Enum from typing import Optional diff --git a/llama_stack/models/llama/llama3/chat_format.py b/llama_stack/models/llama/llama3/chat_format.py index 8ae911fc3..f55cd5e1c 100644 --- a/llama_stack/models/llama/llama3/chat_format.py +++ b/llama_stack/models/llama/llama3/chat_format.py @@ -4,13 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. 
-# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. - import io import json import uuid diff --git a/llama_stack/models/llama/llama3/generation.py b/llama_stack/models/llama/llama3/generation.py index b4e0d39b9..ee99a07ba 100644 --- a/llama_stack/models/llama/llama3/generation.py +++ b/llama_stack/models/llama/llama3/generation.py @@ -4,59 +4,37 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - # Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. import json import os import sys import time -from dataclasses import dataclass from pathlib import Path from typing import Callable, Generator, List, Optional import torch import torch.nn.functional as F from fairscale.nn.model_parallel.initialize import ( - get_model_parallel_rank, initialize_model_parallel, model_parallel_is_initialized, ) from termcolor import cprint -from ..datatypes import RawContent, RawMessage, StopReason, ToolPromptFormat +from ..checkpoint import maybe_reshard_state_dict +from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat from .args import ModelArgs from .chat_format import ChatFormat, LLMInput from .model import Transformer +from .multimodal.model import CrossAttentionTransformer from .tokenizer import Tokenizer -@dataclass -class CompletionPrediction: - generation: str - decoded_tokens: Optional[List[str]] = None - logprobs: Optional[List[List[float]]] = None - - -@dataclass -class ChatPrediction: - generation: RawMessage - decoded_tokens: Optional[List[str]] = None - logprobs: Optional[List[List[float]]] = None - - -@dataclass -class TokenResult: - token: int - text: str - logprobs: Optional[List[float]] = None - - -# TODO: make this completely parallel to the llama4 generation.py file and share common code -# from llama-models also class Llama3: @staticmethod def build( @@ -64,7 +42,7 @@ class Llama3: max_seq_len: int, max_batch_size: int, world_size: Optional[int] = None, - tokenizer_path: Optional[str] = None, + quantization_mode: Optional[QuantizationMode] = None, seed: int = 1, device: str = "cuda", ): @@ -101,13 +79,9 @@ class Llama3: start_time = time.time() - checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) - assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" - assert world_size == len(checkpoints), ( - f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}" - ) - ckpt_path = checkpoints[get_model_parallel_rank()] - checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True) + ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth")) + assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}" + print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})") with open(Path(ckpt_dir) / "params.json", "r") as f: params = json.loads(f.read()) @@ -116,40 +90,58 @@ 
class Llama3: max_batch_size=max_batch_size, **params, ) - if tokenizer_path: - tokenizer = Tokenizer(model_path=tokenizer_path) - else: - tokenizer = Tokenizer.get_instance() + tokenizer = Tokenizer.get_instance() + + state_dict = maybe_reshard_state_dict( + ckpt_paths, + n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads, + ) assert model_args.vocab_size == tokenizer.n_words - torch.set_default_device(device) - if device.type == "cuda": - if torch.cuda.is_bf16_supported(): - torch.set_default_dtype(torch.bfloat16) - else: - torch.set_default_dtype(torch.half) - elif device.type == "xpu": - if torch.xpu.is_bf16_supported(): - torch.set_default_dtype(torch.bfloat16) - else: - torch.set_default_dtype(torch.half) - else: - torch.set_default_dtype(torch.half) - if model_args.vision_chunk_size > 0: - from .multimodal.model import CrossAttentionTransformer + def build_model(): + if model_args.vision_chunk_size > 0: + model = CrossAttentionTransformer(model_args) + model.setup_cache(model_args.max_batch_size, device=device, dtype=torch.get_default_dtype()) + else: + model = Transformer(model_args) + return model - model = CrossAttentionTransformer(model_args) - model.setup_cache(model_args.max_batch_size, torch.get_default_dtype()) + if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed: + from .quantization.loader import convert_to_quantized_model + + torch.set_default_tensor_type(torch.BFloat16Tensor) + model = build_model() + print("Loading state dict...") + model.load_state_dict(state_dict, strict=False) + print("Done...") + model = convert_to_quantized_model(model, ckpt_dir, quantization_mode, device=device) + torch.set_default_device(device) else: - model = Transformer(model_args) - model.load_state_dict(checkpoint, strict=True) - model.to(device) + print(f"Setting default device to {device}") + torch.set_default_device(device) + if device.type == "cuda": + if torch.cuda.is_bf16_supported(): + torch.set_default_dtype(torch.bfloat16) + else: + torch.set_default_dtype(torch.half) + elif device.type == "xpu": + if torch.xpu.is_bf16_supported(): + torch.set_default_dtype(torch.bfloat16) + else: + torch.set_default_dtype(torch.half) + + model = build_model() + print("Loading state dict...") + model.load_state_dict(state_dict, strict=True) + model.to(device) + print("Done...") + print(f"Loaded in {time.time() - start_time:.2f} seconds") - return Llama(model, tokenizer, model_args) + return Llama3(model, tokenizer, model_args) - def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs): + def __init__(self, model: Transformer | CrossAttentionTransformer, tokenizer: Tokenizer, args: ModelArgs): self.args = args self.model = model self.tokenizer = tokenizer @@ -158,26 +150,30 @@ class Llama3: @torch.inference_mode() def generate( self, - model_input: LLMInput, - max_gen_len: int, + model_inputs: List[LLMInput], temperature: float = 0.6, top_p: float = 0.9, + max_gen_len: Optional[int] = None, logprobs: bool = False, echo: bool = False, print_model_input: bool = False, logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, - ) -> Generator: + ) -> Generator[List[GenerationResult], None, None]: + if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len: + max_gen_len = self.args.max_seq_len - 1 params = self.model.params + print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1" if print_model_input: - 
tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in model_input.tokens] - cprint( - "Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n", - "red", - ) - prompt_tokens = [model_input.tokens] + for inp in model_inputs: + tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens] + cprint( + "Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n", + "red", + ) + prompt_tokens = [inp.tokens for inp in model_inputs] - bsz = 1 + bsz = len(model_inputs) assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) min_prompt_len = min(len(t) for t in prompt_tokens) @@ -189,18 +185,6 @@ class Llama3: total_len = min(max_gen_len + max_prompt_len, params.max_seq_len) - is_vision = not isinstance(self.model, Transformer) - if is_vision: - images = model_input.vision.images if model_input.vision is not None else [] - mask = model_input.vision.mask if model_input.vision is not None else [] - - # the method works for bsz > 1 so add a batch dimension - xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks( - batch_images=[images], - batch_masks=[mask], - total_len=total_len, - ) - pad_id = self.tokenizer.pad_id tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long) for k, t in enumerate(prompt_tokens): @@ -208,23 +192,45 @@ class Llama3: if logprobs: token_logprobs = torch.zeros_like(tokens, dtype=torch.float) - prev_pos = 0 + is_vision = not isinstance(self.model, Transformer) + if is_vision: + images = [inp.vision.images if inp.vision is not None else [] for inp in model_inputs] + mask = [inp.vision.mask if inp.vision is not None else [] for inp in model_inputs] + + xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks( + batch_images=images, + batch_masks=mask, + total_len=total_len, + device=tokens.device, + ) + eos_reached = torch.tensor([False] * bsz) input_text_mask = tokens != pad_id if echo: - for i, t in enumerate(model_input.tokens): - yield TokenResult( - token=t, - text=self.tokenizer.decode([t]), - logprobs=(token_logprobs[0, i : i + 1].tolist() if logprobs else None), - ) + for i in range(max_prompt_len): + results = [] + for j, t in enumerate(tokens[:, i]): + results.append( + GenerationResult( + token=t.item(), + text=self.tokenizer.decode([t.item()]), + source="input", + logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None), + batch_idx=j, + finished=False, + ignore_token=t.item() == pad_id, + ) + ) + yield results stop_tokens = torch.tensor(self.tokenizer.stop_tokens) + + prev_pos = 0 for cur_pos in range(min_prompt_len, total_len): if is_vision: position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long) - text_only_inference = model_input.vision is None + text_only_inference = all(inp.vision is None for inp in model_inputs) logits = self.model.forward( position_ids, tokens, @@ -271,155 +277,69 @@ class Llama3: ignore_index=pad_id, ) eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens)) - yield TokenResult( - token=next_token[0].item(), - text=self.tokenizer.decode(next_token.tolist()), - logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None), - ) + results = [] + for idx, t in enumerate(next_token): + results.append( + GenerationResult( + token=t.item(), + text=self.tokenizer.decode([t.item()]), + source="output", + logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else 
None), + batch_idx=idx, + finished=eos_reached[idx], + ignore_token=cur_pos < len(prompt_tokens[idx]), + ) + ) + yield results prev_pos = cur_pos if all(eos_reached): break - def text_completion( + def completion( self, - content: RawContent, + contents: List[RawContent], temperature: float = 0.6, top_p: float = 0.9, max_gen_len: Optional[int] = None, logprobs: bool = False, echo: bool = False, - ) -> CompletionPrediction: - if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len: - max_gen_len = self.model.params.max_seq_len - 1 - - model_input = self.formatter.encode_content(content) - - tokens = [] - token_logprobs = [] - decoded_tokens = [] + ) -> Generator[List[GenerationResult], None, None]: + model_inputs = [self.formatter.encode_content(c) for c in contents] for result in self.generate( - model_input=model_input, - max_gen_len=max_gen_len, + model_inputs=model_inputs, temperature=temperature, top_p=top_p, + max_gen_len=max_gen_len, logprobs=logprobs, echo=echo, ): - tokens.append(result.token) - if logprobs: - decoded_tokens.append(result.text) - token_logprobs.append(result.logprobs) - - generation = self.tokenizer.decode(tokens) - if logprobs: - return CompletionPrediction( - generation=generation, - logprobs=token_logprobs, - decoded_tokens=decoded_tokens, - ) - - return CompletionPrediction(generation=generation) + yield result + if all(r.finished for r in result): + break def chat_completion( self, - messages: List[RawMessage], + messages_batch: List[List[RawMessage]], temperature: float = 0.6, top_p: float = 0.9, max_gen_len: Optional[int] = None, logprobs: bool = False, tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json, echo: bool = False, - ) -> ChatPrediction: - if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len: - max_gen_len = self.model.params.max_seq_len - 1 - - tokens = [] - token_logprobs = [] - decoded_tokens = [] - - stop_reason = None + ) -> Generator[List[GenerationResult], None, None]: + model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch] for result in self.generate( - model_input=self.formatter.encode_dialog_prompt(messages, tool_prompt_format), - max_gen_len=max_gen_len, + model_inputs=model_inputs, temperature=temperature, top_p=top_p, + max_gen_len=max_gen_len, logprobs=logprobs, echo=echo, ): - tokens.append(result.token) - if result.text == "<|eot_id|>": - stop_reason = StopReason.end_of_turn - elif result.text == "<|eom_id|>": - stop_reason = StopReason.end_of_message - - if logprobs: - decoded_tokens.append(result.text) - token_logprobs.append(result.logprobs) - - if stop_reason is None: - stop_reason = StopReason.out_of_tokens - - message = self.formatter.decode_assistant_message(tokens, stop_reason) - - if logprobs: - return ChatPrediction( - generation=message, - logprobs=token_logprobs, - decoded_tokens=decoded_tokens, - ) - - return ChatPrediction(generation=message) - - def chat_completion_raw( - self, - messages: List[RawMessage], - temperature: float = 0.6, - top_p: float = 0.9, - max_gen_len: Optional[int] = None, - tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json, - ) -> List[int]: - if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len: - max_gen_len = self.model.params.max_seq_len - 1 - - output_tokens = [] - model_input = self.formatter.encode_dialog_prompt(messages, tool_prompt_format) - input_tokens = model_input.tokens - for result in self.generate( - 
model_input=model_input, - max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - logprobs=False, - ): - output_tokens.append(result.token) - - return input_tokens, output_tokens - - def text_completion_raw( - self, - content: RawContent, - temperature: float = 0.6, - top_p: float = 0.9, - max_gen_len: Optional[int] = None, - ): - if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len: - max_gen_len = self.model.params.max_seq_len - 1 - - model_input = self.formatter.encode_content(content) - input_tokens = model_input.tokens - - output_tokens = [] - for result in self.generate( - model_input=model_input, - max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - logprobs=False, - ): - output_tokens.append(result.token) - - return input_tokens, output_tokens + yield result + if all(r.finished for r in result): + break def sample_top_p(probs, p): diff --git a/llama_stack/models/llama/llama3/model.py b/llama_stack/models/llama/llama3/model.py index a49167980..2562673e2 100644 --- a/llama_stack/models/llama/llama3/model.py +++ b/llama_stack/models/llama/llama3/model.py @@ -4,16 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. - import math from typing import Optional, Tuple @@ -29,6 +19,10 @@ from torch import nn from .args import ModelArgs +# **NOTE**: This code is not runnable without installing `torch` and `fairscale` +# dependencies. These dependencies are not part of the default dependencies +# (requirements.txt) of the `llama-models` package. + class RMSNorm(torch.nn.Module): def __init__(self, dim: int, eps: float = 1e-6): @@ -111,9 +105,9 @@ class Attention(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads - model_parallel_size = fs_init.get_model_parallel_world_size() - self.n_local_heads = args.n_heads // model_parallel_size - self.n_local_kv_heads = self.n_kv_heads // model_parallel_size + world_size = fs_init.get_model_parallel_world_size() + self.n_local_heads = args.n_heads // world_size + self.n_local_kv_heads = self.n_kv_heads // world_size self.n_rep = self.n_local_heads // self.n_local_kv_heads self.head_dim = args.dim // args.n_heads diff --git a/llama_stack/models/llama/llama3/multimodal/model.py b/llama_stack/models/llama/llama3/multimodal/model.py index 3d0d77c87..0cb18b948 100644 --- a/llama_stack/models/llama/llama3/multimodal/model.py +++ b/llama_stack/models/llama/llama3/multimodal/model.py @@ -4,16 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. - -# Copyright (c) Meta Platforms, Inc. and affiliates. 
-# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - import logging import math from functools import partial @@ -180,14 +170,14 @@ class ImageAttention(nn.Module): n_heads, ): super().__init__() - model_parallel_size = fs_init.get_model_parallel_world_size() + world_size = fs_init.get_model_parallel_world_size() qkvo_replication = 1 - if model_parallel_size > 16: - qkvo_replication = model_parallel_size // 8 + if world_size > 16: + qkvo_replication = world_size // 8 self.n_kv_heads = n_heads - self.n_local_heads = n_heads * qkvo_replication // model_parallel_size - self.n_local_kv_heads = self.n_kv_heads * qkvo_replication // model_parallel_size + self.n_local_heads = n_heads * qkvo_replication // world_size + self.n_local_kv_heads = self.n_kv_heads * qkvo_replication // world_size self.n_rep = self.n_local_heads // self.n_local_kv_heads self.head_dim = dim // n_heads @@ -536,16 +526,16 @@ class Attention(nn.Module): cache_v (torch.Tensor): Cached values for attention. """ super().__init__() - model_parallel_size = fs_init.get_model_parallel_world_size() + world_size = fs_init.get_model_parallel_world_size() replication_factor = 1 - if model_parallel_size > 8: - replication_factor = model_parallel_size // MP_SCALE + if world_size > 8: + replication_factor = world_size // MP_SCALE self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads self.n_kv_heads *= replication_factor - self.n_local_heads = args.n_heads // model_parallel_size - self.n_local_kv_heads = self.n_kv_heads // model_parallel_size + self.n_local_heads = args.n_heads // world_size + self.n_local_kv_heads = self.n_kv_heads // world_size self.n_rep = self.n_local_heads // self.n_local_kv_heads self.head_dim = args.dim // args.n_heads self.max_seq_len = args.max_seq_len @@ -587,13 +577,11 @@ class Attention(nn.Module): self.n_local_kv_heads, self.head_dim, ) - device = next(self.parameters()).device self.register_buffer( "key_cache", torch.zeros( cache_shape, dtype=dtype, - device=device, ), persistent=False, ) @@ -602,7 +590,6 @@ class Attention(nn.Module): torch.zeros( cache_shape, dtype=dtype, - device=device, ), persistent=False, ) @@ -614,6 +601,9 @@ class Attention(nn.Module): freqs_cis: torch.Tensor, position_ids: torch.LongTensor, ): + self.key_cache = self.key_cache.to(x.device) + self.value_cache = self.value_cache.to(x.device) + xq, xk, xv = [F.linear(x, w) for w in [self.wq.weight, self.wk.weight, self.wv.weight]] bs, slen, _ = xq.shape @@ -832,10 +822,10 @@ class CrossAttention(torch.nn.Module): norm_eps: float, ): super().__init__() - self.model_parallel_size = fs_init.get_model_parallel_world_size() + self.world_size = fs_init.get_model_parallel_world_size() replication_factor = 1 - if self.model_parallel_size > 8: - replication_factor = self.model_parallel_size // MP_SCALE + if self.world_size > 8: + replication_factor = self.world_size // MP_SCALE n_kv_heads *= replication_factor assert n_heads % n_kv_heads == 0 @@ -889,10 +879,10 @@ class CrossAttention(torch.nn.Module): # trunk LLM (i.e., group query attention) -- @dubeya # local heads assert self.n_heads % self.n_kv_heads == 0 - assert self.n_heads % self.model_parallel_size == 0 - assert self.n_kv_heads % self.model_parallel_size == 0 - self.n_local_heads = self.n_heads // self.model_parallel_size - self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size + assert self.n_heads % self.world_size == 0 + assert self.n_kv_heads % self.world_size == 0 + 
self.n_local_heads = self.n_heads // self.world_size + self.n_local_kv_heads = self.n_kv_heads // self.world_size self.n_rep = self.n_local_heads // self.n_local_kv_heads def _compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor: @@ -1041,7 +1031,7 @@ class CrossAttentionTransformerVision(torch.nn.Module): self.image_res = args.vision_chunk_size self.max_num_chunks = args.vision_max_num_chunks if return_intermediate is not None: - return_intermediate = [int(level) for level in return_intermediate.split(",")] + return_intermediate = [int(layer) for layer in return_intermediate.split(",")] self.vision_input_dim = (len(return_intermediate) + 1) * self.vision_input_dim self.patch_size = 14 self.vision_encoder = VisionEncoder( @@ -1076,15 +1066,15 @@ class CrossAttentionTransformerText(torch.nn.Module): def __init__(self, args: ModelArgs) -> None: super().__init__() - self.model_parallel_size = fs_init.get_model_parallel_world_size() + self.world_size = fs_init.get_model_parallel_world_size() assert args.vocab_size > 0 self.vocab_size = args.vocab_size self.n_layers = args.n_layers self.dim = args.dim self.head_dim = args.dim // args.n_heads self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads - self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size - assert self.vocab_size % self.model_parallel_size == 0 + self.n_local_kv_heads = self.n_kv_heads // self.world_size + assert self.vocab_size % self.world_size == 0 self.tok_embeddings = VocabParallelEmbedding(args.vocab_size, args.dim, init_method=lambda x: x) self.pos_embeddings = None # final norm layer (not necessary for post-norm) @@ -1184,6 +1174,8 @@ class CrossAttentionTransformerText(torch.nn.Module): text_only_inference: bool = False, ): assert self.cache_is_setup, "Please set up cache before calling forward" + self.mask_cache = self.mask_cache.to(h.device) + self.freqs_cis = self.freqs_cis.to(h.device) mask = self.mask_cache.index_select(2, position_ids) freqs_cis = self.freqs_cis.index_select(0, position_ids) @@ -1212,9 +1204,8 @@ class CrossAttentionTransformerText(torch.nn.Module): output = gather_from_tensor_model_parallel_region(output) return output.float() - def setup_cache(self, max_batch_size: int, dtype=torch.bfloat16): + def setup_cache(self, max_batch_size: int, device: torch.device, dtype=torch.bfloat16): # Set up the text kv caches - device = next(self.parameters()).device ones = torch.ones( (self.max_seq_len, self.max_seq_len), dtype=torch.bool, @@ -1265,7 +1256,7 @@ class CrossAttentionTransformerText(torch.nn.Module): return ( cross_attention_masks.to(device=text_device, dtype=text_dtype), - full_text_row_masked_out_mask, + full_text_row_masked_out_mask.to(device=text_device), ) @@ -1284,14 +1275,15 @@ class CrossAttentionTransformer(torch.nn.Module): max_num_chunks=args.vision_max_num_chunks, ) - def setup_cache(self, max_batch_size: int, dtype: torch.dtype): - self.text_model.setup_cache(max_batch_size, dtype) + def setup_cache(self, max_batch_size: int, device: torch.device, dtype: torch.dtype): + self.text_model.setup_cache(max_batch_size, device, dtype) def compute_vision_tokens_masks( self, batch_images: List[List[PIL_Image.Image]], batch_masks: List[List[List[int]]], total_len: int, + device: torch.device, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: skip_vision_encoder = False @@ -1318,6 +1310,7 @@ class CrossAttentionTransformer(torch.nn.Module): image_res=self.params.vision_chunk_size, max_num_images=max_num_images, ) + stacked_images = 
stacked_images.to(device=device) if skip_vision_encoder: vision_tokens = torch.zeros( @@ -1330,7 +1323,7 @@ class CrossAttentionTransformer(torch.nn.Module): ), ) else: - vision_tokens = self.vision_model(stacked_images, aspect_ratios) + vision_tokens = self.vision_model(stacked_images, aspect_ratios).to(device=device) bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape) xattn_caches = torch.stack( diff --git a/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py b/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py index e03fcfc93..d4e825a22 100644 --- a/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py +++ b/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py @@ -15,7 +15,7 @@ import textwrap from datetime import datetime from typing import Any, List, Optional -from llama_stack.models.llama.datatypes import ( +from llama_stack.apis.inference import ( BuiltinTool, ToolDefinition, ToolParamDefinition, diff --git a/llama_stack/models/llama/llama3/quantization/loader.py b/llama_stack/models/llama/llama3/quantization/loader.py index f4d94c382..771fd02be 100644 --- a/llama_stack/models/llama/llama3/quantization/loader.py +++ b/llama_stack/models/llama/llama3/quantization/loader.py @@ -4,9 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. - # type: ignore import os from typing import Any, Dict, List, Optional, cast @@ -18,22 +15,15 @@ from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_regi from torch import Tensor, nn from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear -from llama_stack.apis.inference import QuantizationType -from llama_stack.log import get_logger -from llama_stack.models.llama.sku_list import resolve_model - -from ...config import MetaReferenceQuantizedInferenceConfig -from ...datatypes import CheckpointQuantizationFormat +from ...datatypes import QuantizationMode from ...quantize_impls import ( Fp8ScaledWeights, ffn_swiglu, load_fp8, quantize_fp8, ) -from ..args import ModelArgs from ..model import Transformer, TransformerBlock - -log = get_logger(__name__, category="quantization") +from ..multimodal.model import CrossAttentionTransformer def swiglu_wrapper( @@ -44,30 +34,34 @@ def swiglu_wrapper( return reduce_from_model_parallel_region(out) +def convert_to_quantized_model( + model: Transformer | CrossAttentionTransformer, + checkpoint_dir: str, + quantization_mode: Optional[str] = None, + fp8_activation_scale_ub: Optional[float] = 1200.0, + device: Optional[torch.device] = None, +) -> Transformer | CrossAttentionTransformer: + if quantization_mode == QuantizationMode.fp8_mixed: + return convert_to_fp8_quantized_model(model, checkpoint_dir, fp8_activation_scale_ub, device) + elif quantization_mode == QuantizationMode.int4_mixed: + return convert_to_int4_quantized_model(model, checkpoint_dir, device) + else: + raise ValueError(f"Unsupported quantization mode: {quantization_mode}") + + def convert_to_fp8_quantized_model( model: Transformer, - config: MetaReferenceQuantizedInferenceConfig, checkpoint_dir: str, fp8_activation_scale_ub: Optional[float] = 1200.0, + device: Optional[torch.device] = None, ) -> Transformer: - if config.quantization.type == QuantizationType.bf16.value: - return model - - elif 
config.quantization.type != QuantizationType.fp8.value: - raise ValueError("Only FP8 quantization is supported") - - assert config.model is not None, "Model must be specified for quantized inference" - llama_model = resolve_model(config.model) - assert llama_model is not None, f"Model {config.model} not found" - # Move weights to GPU with quantization - if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value: - log.info("Loading fp8 scales...") - fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt") - assert os.path.isfile(fp8_scales_path), f"fp8_scales_path not found for rank {get_model_parallel_rank()}" + fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt") + if os.path.isfile(fp8_scales_path): + print("Loading fp8 scales...") fp8_scales = torch.load(fp8_scales_path, weights_only=True) - for block in model.layers: + for _, block in model.named_modules(): if isinstance(block, TransformerBlock): if block.layer_id == 0 or block.layer_id == (model.n_layers - 1): continue @@ -81,8 +75,8 @@ def convert_to_fp8_quantized_model( fp8_activation_scale_ub, ) else: - log.info("Quantizing fp8 weights from bf16...") - for block in model.layers: + print("Quantizing fp8 weights from bf16...") + for _, block in model.named_modules(): if isinstance(block, TransformerBlock): if block.layer_id == 0 or block.layer_id == (model.n_layers - 1): continue @@ -92,12 +86,12 @@ def convert_to_fp8_quantized_model( param.weight = quantize_fp8( param.weight, fp8_activation_scale_ub, - output_device=torch.device("cuda"), + output_device=device, ) for _, parameter in model.named_parameters(): if not isinstance(parameter, Fp8ScaledWeights): - parameter.data = parameter.to(device="cuda") + parameter.data = parameter.to(device=device) return model @@ -290,11 +284,12 @@ def _prepare_model_int4_weight_int8_dynamic_activation( def convert_to_int4_quantized_model( - model: Transformer, - model_args: ModelArgs, -) -> Transformer: + model: Transformer | CrossAttentionTransformer, + checkpoint_dir: str, + device: Optional[torch.device] = None, +) -> Transformer | CrossAttentionTransformer: """Convert the model to int4 quantized model.""" - + model_args = model.params assert model_args.quantization_args is not None, "Quantization args must be specified." quantization_args = model_args.quantization_args if quantization_args.scheme is None: @@ -318,5 +313,4 @@ def convert_to_int4_quantized_model( lora_scale = model_args.lora_args.scale _prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale) - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - return cast(Transformer, model.to(device)) + return cast(Transformer | CrossAttentionTransformer, model.to(device=device)) diff --git a/llama_stack/models/llama/llama3/tokenizer.py b/llama_stack/models/llama/llama3/tokenizer.py index b240fa246..d3cc4fc07 100644 --- a/llama_stack/models/llama/llama3/tokenizer.py +++ b/llama_stack/models/llama/llama3/tokenizer.py @@ -4,16 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. - -# Copyright (c) Meta Platforms, Inc. 
and affiliates. -# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. - import os from logging import getLogger from pathlib import Path diff --git a/llama_stack/models/llama/llama3_2/__init__.py b/llama_stack/models/llama/llama3_2/__init__.py index 38ee47d66..756f351d8 100644 --- a/llama_stack/models/llama/llama3_2/__init__.py +++ b/llama_stack/models/llama/llama3_2/__init__.py @@ -3,10 +3,3 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. diff --git a/llama_stack/models/llama/llama3_2/prompts_text.py b/llama_stack/models/llama/llama3_2/prompts_text.py index 7bc7e3219..7a1f9887c 100644 --- a/llama_stack/models/llama/llama3_2/prompts_text.py +++ b/llama_stack/models/llama/llama3_2/prompts_text.py @@ -4,12 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. import json import textwrap diff --git a/llama_stack/models/llama/llama3_2/prompts_vision.py b/llama_stack/models/llama/llama3_2/prompts_vision.py index b1ede418b..b0f11cab6 100644 --- a/llama_stack/models/llama/llama3_2/prompts_vision.py +++ b/llama_stack/models/llama/llama3_2/prompts_vision.py @@ -4,13 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. - import textwrap from pathlib import Path diff --git a/llama_stack/models/llama/llama4/args.py b/llama_stack/models/llama/llama4/args.py index 046448ef6..6d7c1d409 100644 --- a/llama_stack/models/llama/llama4/args.py +++ b/llama_stack/models/llama/llama4/args.py @@ -4,13 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. 
- from enum import Enum from typing import Optional diff --git a/llama_stack/models/llama/llama4/chat_format.py b/llama_stack/models/llama/llama4/chat_format.py index ebae2b8e5..8f08b3a9e 100644 --- a/llama_stack/models/llama/llama4/chat_format.py +++ b/llama_stack/models/llama/llama4/chat_format.py @@ -12,6 +12,7 @@ from typing import Dict, List, Optional, Tuple import torch from PIL import Image as PIL_Image +# TODO: either fork these or move them to the common package from ..datatypes import ( BuiltinTool, RawContent, @@ -26,10 +27,7 @@ from ..datatypes import ( from ..llama3.tool_utils import ToolUtils from .args import VisionArgs from .datatypes import LLMInput -from .preprocess import ( - ResizeNormalizeImageTransform, - VariableSizeImageTransform, -) +from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform from .tokenizer import Tokenizer @@ -50,7 +48,7 @@ class TransformedImage: aspect_ratio: Tuple[int, int] -def convert_rgba_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image: +def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image: if image.mode == "RGBA": image.load() # for png.split() new_img = PIL_Image.new("RGB", image.size, bg) @@ -167,7 +165,7 @@ class ChatFormat: bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data image = PIL_Image.open(bytes_io) - image = convert_rgba_to_rgb(image) + image = convert_image_to_rgb(image) image_tiles, ar = self.dynamic_image_transform(image, max_num_chunks=self.max_num_chunks) if image_tiles.shape[0] > 1: @@ -212,12 +210,9 @@ class ChatFormat: content = ToolUtils.encode_tool_call(t, tool_prompt_format) _process_content(content) - # Tool calls and Tool Response messages should be eom eom = False if message.role == "assistant": - eom = message.stop_reason == StopReason.end_of_message or message.tool_calls - elif message.role == "tool": - eom = True + eom = message.stop_reason == StopReason.end_of_message tokens.append(self.tokenizer.special_tokens["<|eom|>" if eom else "<|eot|>"]) return tokens, images @@ -252,11 +247,6 @@ class ChatFormat: if content.startswith(header_str): content = content[len(header_str) :] - ipython = content.startswith("<|python_start|>") - if ipython: - content = content[len("<|python_start|>") :] - content = content.replace("<|python_end|>", "") - if content.endswith("<|eot|>"): content = content[: -len("<|eot|>")] stop_reason = StopReason.end_of_turn @@ -287,11 +277,6 @@ class ChatFormat: } if tool_name in BuiltinTool.__members__: tool_name = BuiltinTool[tool_name] - elif ipython: - tool_name = BuiltinTool.code_interpreter - tool_arguments = { - "code": content, - } tool_calls = [] if tool_name is not None and tool_arguments is not None: diff --git a/llama_stack/models/llama/llama4/datatypes.py b/llama_stack/models/llama/llama4/datatypes.py index bb1c19a12..27174db63 100644 --- a/llama_stack/models/llama/llama4/datatypes.py +++ b/llama_stack/models/llama/llama4/datatypes.py @@ -4,13 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. 
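The chat-format helper above was renamed from convert_rgba_to_rgb to convert_image_to_rgb; RGBA inputs are still composited onto a background color. A small hedged usage sketch; the file name is hypothetical and the behavior for non-RGBA modes follows the unchanged body not shown in this hunk:

from PIL import Image as PIL_Image

from llama_stack.models.llama.llama4.chat_format import convert_image_to_rgb

img = PIL_Image.open("screenshot.png")   # hypothetical RGBA image
img = convert_image_to_rgb(img)          # RGBA is flattened onto white (255, 255, 255); pass bg=... to change it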
- from dataclasses import dataclass from typing import List, Optional, Union diff --git a/llama_stack/models/llama/llama4/generation.py b/llama_stack/models/llama/llama4/generation.py index 9c516d967..8971835aa 100644 --- a/llama_stack/models/llama/llama4/generation.py +++ b/llama_stack/models/llama/llama4/generation.py @@ -4,32 +4,43 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. + import codecs import io import json import os import sys import time -from enum import Enum from pathlib import Path from typing import Callable, Generator, List, Optional import torch import torch.nn.functional as F from fairscale.nn.model_parallel.initialize import ( - get_model_parallel_rank, initialize_model_parallel, model_parallel_is_initialized, ) from termcolor import cprint -from ..common import TokenResult +from ..checkpoint import maybe_reshard_state_dict +from ..datatypes import GenerationResult, QuantizationMode from .args import ModelArgs -from .chat_format import ( - ChatFormat, - RawContent, - RawMessage, -) +from .chat_format import ChatFormat, RawContent, RawMessage from .datatypes import LLMInput, MaskedEmbedding, TransformerInput from .model import Transformer from .tokenizer import Tokenizer @@ -37,12 +48,6 @@ from .tokenizer import Tokenizer torch.serialization.add_safe_globals([io.BytesIO, codecs.encode]) -class QuantizationMode(str, Enum): - none = "none" - fp8_mixed = "fp8_mixed" - int4_mixed = "int4_mixed" - - class Llama4: @staticmethod def build( @@ -50,7 +55,7 @@ class Llama4: max_seq_len: int, max_batch_size: int, world_size: Optional[int] = None, - quantization_mode: Optional[str] = None, + quantization_mode: Optional[QuantizationMode] = None, seed: int = 1, ): if not torch.distributed.is_initialized(): @@ -71,11 +76,9 @@ class Llama4: start_time = time.time() - checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) - assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" - assert world_size == len(checkpoints), ( - f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}" - ) + ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth")) + assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}" + print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})") with open(Path(ckpt_dir) / "params.json", "r") as f: params = json.loads(f.read()) @@ -92,10 +95,11 @@ class Llama4: assert model_args.vocab_size == tokenizer.n_words, f"{model_args.vocab_size=} vs. 
{tokenizer.n_words=} mismatch" print("Model args:\n", model_args.model_dump_json(indent=2)) - ckpt_path = checkpoints[get_model_parallel_rank()] - print(f"Loading checkpoint from {ckpt_dir}...") - with open(ckpt_path, "rb") as f: - checkpoint = torch.load(f, map_location="cpu", weights_only=True) + state_dict = maybe_reshard_state_dict( + ckpt_paths, + n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads, + moe_num_experts=model_args.moe_args.num_experts, + ) print("Loaded checkpoint") if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed: from .quantization.loader import convert_to_quantized_model @@ -103,9 +107,9 @@ class Llama4: torch.set_default_tensor_type(torch.BFloat16Tensor) model = Transformer(model_args) print("Loading state dict...") - model.load_state_dict(checkpoint, strict=False) + model.load_state_dict(state_dict, strict=False) print("Done...") - model = convert_to_quantized_model(model, ckpt_dir) + model = convert_to_quantized_model(model, ckpt_dir, quantization_mode) else: if torch.cuda.is_bf16_supported(): torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) @@ -114,7 +118,7 @@ class Llama4: model = Transformer(model_args) print("Loading state dict...") - model.load_state_dict(checkpoint, strict=False) + model.load_state_dict(state_dict, strict=False) print("Done...") print(f"Loaded in {time.time() - start_time:.2f} seconds") @@ -129,7 +133,7 @@ class Llama4: @torch.inference_mode() def generate( self, - llm_input: LLMInput, + llm_inputs: List[LLMInput], temperature: float = 0.6, top_p: float = 0.9, max_gen_len: Optional[int] = None, @@ -137,22 +141,20 @@ class Llama4: echo: bool = False, print_model_input: bool = False, logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, - ) -> Generator: + ) -> Generator[List[GenerationResult], None, None]: if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.args.max_seq_len: max_gen_len = self.model.args.max_seq_len - 1 params = self.model.args print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1" - if print_model_input and get_model_parallel_rank() == 0: - tokens_to_print = list(llm_input.tokens) - cprint( - "Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n", - "red", - ) - prompt_tokens = [llm_input.tokens] + if print_model_input: + cprint("Input to model:\n", "yellow") + for inp in llm_inputs: + cprint(self.tokenizer.decode(inp.tokens.tolist()), "grey") + prompt_tokens = [inp.tokens for inp in llm_inputs] - bsz = 1 + bsz = len(llm_inputs) assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) min_prompt_len = min(len(t) for t in prompt_tokens) @@ -175,24 +177,33 @@ class Llama4: input_text_mask = tokens != pad_id if echo: - for i, t in enumerate(llm_input.tokens): - yield TokenResult( - token=t, - text=self.tokenizer.decode([t]), - logprobs=(token_logprobs[0, i : i + 1].tolist() if logprobs else None), - ) + for i in range(max_prompt_len): + results = [] + for j, t in enumerate(tokens[:, i]): + results.append( + GenerationResult( + token=t.item(), + text=self.tokenizer.decode([t.item()]), + source="input", + logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None), + batch_idx=j, + finished=False, + ignore_token=t.item() == pad_id, + ) + ) + yield results stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda") prev_pos = 0 for cur_pos in range(min_prompt_len, total_len): image_embedding = None - if 
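With maybe_reshard_state_dict above, the number of checkpoint shards no longer has to match the model-parallel world size, and the quantization mode is threaded through build(). A hedged sketch of the new build call, assuming ckpt_dir remains the first parameter (it sits in unchanged context above this hunk) and a hypothetical checkpoint path:

from llama_stack.models.llama.datatypes import QuantizationMode
from llama_stack.models.llama.llama4.generation import Llama4

llama = Llama4.build(
    ckpt_dir="/path/to/llama4/checkpoint",         # hypothetical path; shards are resharded to the current MP size
    max_seq_len=8192,
    max_batch_size=4,
    quantization_mode=QuantizationMode.fp8_mixed,  # optional; omit for plain bf16
)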
prev_pos == 0 and llm_input.images is not None and len(llm_input.images) > 0: + if prev_pos == 0 and any(inp.images is not None and len(inp.images) > 0 for inp in llm_inputs): image_mask = tokens[:, prev_pos:cur_pos] == self.tokenizer.special_tokens["<|patch|>"] image_mask = image_mask.unsqueeze(-1) h = self.model.tok_embeddings(tokens[:, prev_pos:cur_pos]) - image_batch = [llm_input.images] + image_batch = [inp.images if inp.images is not None else [] for inp in llm_inputs] image_embedding = MaskedEmbedding( embedding=self.model.vision_embeddings(image_batch, image_mask, h), mask=image_mask, @@ -228,11 +239,21 @@ class Llama4: ignore_index=pad_id, ) eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens)) - yield TokenResult( - token=next_token[0].item(), - text=self.tokenizer.decode(next_token.tolist()), - logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None), - ) + + results = [] + for idx, t in enumerate(next_token): + results.append( + GenerationResult( + token=t.item(), + text=self.tokenizer.decode([t.item()]), + source="output", + logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None), + batch_idx=idx, + finished=eos_reached[idx], + ignore_token=cur_pos < len(prompt_tokens[idx]), + ) + ) + yield results prev_pos = cur_pos if all(eos_reached): @@ -240,68 +261,47 @@ class Llama4: def completion( self, - content: RawContent, + contents: List[RawContent], temperature: float = 0.6, top_p: float = 0.9, max_gen_len: Optional[int] = None, logprobs: bool = False, echo: bool = False, - ) -> Generator: - llm_input = self.formatter.encode_content(content) + ) -> Generator[List[GenerationResult], None, None]: + llm_inputs = [self.formatter.encode_content(c) for c in contents] for result in self.generate( - llm_input=llm_input, + llm_inputs=llm_inputs, temperature=temperature, top_p=top_p, max_gen_len=max_gen_len, logprobs=logprobs, echo=echo, ): - if result.token in self.tokenizer.stop_tokens: - break yield result + if all(r.finished for r in result): + break def chat_completion( self, - messages: List[RawMessage], + messages_batch: List[List[RawMessage]], temperature: float = 0.6, top_p: float = 0.9, max_gen_len: Optional[int] = None, logprobs: bool = False, echo: bool = False, - ) -> Generator: - llm_input = self.formatter.encode_dialog_prompt(messages) + ) -> Generator[List[GenerationResult], None, None]: + llm_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch] for result in self.generate( - llm_input=llm_input, + llm_inputs=llm_inputs, temperature=temperature, top_p=top_p, max_gen_len=max_gen_len, logprobs=logprobs, echo=echo, ): - if result.token in self.tokenizer.stop_tokens: - break yield result - - def chat_completion_raw( - self, - messages: List[RawMessage], - temperature: float = 0.6, - top_p: float = 0.9, - max_gen_len: Optional[int] = None, - logprobs: bool = False, - ): - llm_input = self.formatter.encode_dialog_prompt(messages) - output_tokens = [] - for result in self.generate( - llm_input=llm_input, - temperature=temperature, - top_p=top_p, - max_gen_len=max_gen_len, - logprobs=logprobs, - ): - output_tokens.append(result.token) - - return llm_input.tokens, output_tokens + if all(r.finished for r in result): + break def sample_top_p(probs, p): diff --git a/llama_stack/models/llama/llama4/model.py b/llama_stack/models/llama/llama4/model.py index a35d6857f..08fac7714 100644 --- a/llama_stack/models/llama/llama4/model.py +++ 
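After the changes above, generate(), completion(), and chat_completion() operate on batches and yield one GenerationResult per batch element per step, with batch_idx, finished, and ignore_token flags replacing the old single-stream TokenResult. A hedged consumer sketch, assuming `llama` was built as above and that RawMessage accepts role and content keywords as elsewhere in this package:

from llama_stack.models.llama.datatypes import RawMessage

dialogs = [
    [RawMessage(role="user", content="Write a haiku about tensors.")],
    [RawMessage(role="user", content="Name three prime numbers.")],
]

outputs = ["" for _ in dialogs]
for results in llama.chat_completion(dialogs, temperature=0.6, top_p=0.9):
    for r in results:
        # skip padding/prompt positions and anything at or past the stop token
        if r.ignore_token or r.finished:
            continue
        outputs[r.batch_idx] += r.text

print(outputs)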
b/llama_stack/models/llama/llama4/model.py @@ -4,16 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. - import math from typing import Any, Dict, List, Optional, Tuple @@ -184,7 +174,6 @@ class Attention(nn.Module): self.head_dim, ) ).cuda() - self.qk_norm = None if self.use_qk_norm: self.qk_norm = L2Norm(args.norm_eps) diff --git a/llama_stack/models/llama/llama4/moe.py b/llama_stack/models/llama/llama4/moe.py index 8cecab7dd..2ce49e915 100644 --- a/llama_stack/models/llama/llama4/moe.py +++ b/llama_stack/models/llama/llama4/moe.py @@ -100,31 +100,21 @@ class Experts(nn.Module): class MoE(torch.nn.Module): """ - This EC implementation is modified from the original EC module. - We refactored the token permutation and unpermutation logic and added support to tp and dp2ep sharding. - This module supports 3 sharding methods of the experts: - - tp: each TP rank has n_experts experts. Experts are sharded following the conventional row/column-parallel TP sharding. - - tp2ep: each TP rank has n_experts/tp experts. Experts are not sharded. - - dp2ep: each EP rank has n_experts/ep experts. Experts are sharded following the row/column-parallel TP sharding. Tensors used in this module are annotated with the suffixes that indicate the shape of the tensor. Several commonly used annotations include: - a: bsz*slen - E: number of experts - e: number of local experts per ep (n_experts/ep) - - et: number of local experts per tp (n_experts/tp) - D: hidden dimension - d: D/tp - F: model dimension - - f: F/tp (used in column/row-parallel linear) - G: number of tokens per expert (a * capacity_factor / E) - g: number of tokens per expert per TP rank (i.e., G/TP) - - GG: G*EP (number of tokens per expert received via inter-EP a2a when ag_along_first_dim=False) - - gg: g*EP (number of tokens per expert received via inter-EP a2a when ag_along_first_dim=True) Examples: x_aD [a, D] routed_in_etG_D [et*G, D] - x_eGGD: [e, GG, D] + x_eGD: [e, G, D] """ def __init__( @@ -207,13 +197,13 @@ class MoE(torch.nn.Module): routed_in_EG_D = routed_in_EG_D * router_scores.reshape(-1, 1) out_aD = self.shared_expert(x_aD) - routed_out_egg_D = self.experts(routed_in_EG_D.detach()) + routed_out_eg_D = self.experts(routed_in_EG_D.detach()) router_indices_EG_D = router_indices.reshape(-1, 1).expand(-1, D) out_aD.scatter_add_( dim=0, index=router_indices_EG_D, - src=routed_out_egg_D.view(-1, D), + src=routed_out_eg_D.view(-1, D), ) out_aD = reduce_from_model_parallel_region(out_aD) return out_aD.view(-1, slen, D) diff --git a/llama_stack/models/llama/llama4/prompts.py b/llama_stack/models/llama/llama4/prompts.py index d4e48e80a..13b96359a 100644 --- a/llama_stack/models/llama/llama4/prompts.py +++ b/llama_stack/models/llama/llama4/prompts.py @@ -4,13 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. - import textwrap from io import BytesIO from pathlib import Path diff --git a/llama_stack/models/llama/llama4/quantization/loader.py b/llama_stack/models/llama/llama4/quantization/loader.py index 69aa309fa..f11d83c60 100644 --- a/llama_stack/models/llama/llama4/quantization/loader.py +++ b/llama_stack/models/llama/llama4/quantization/loader.py @@ -6,20 +6,29 @@ import logging import os -from typing import Optional +from typing import Callable, Optional import torch from fairscale.nn.model_parallel.initialize import get_model_parallel_rank -from torch import Tensor +from torch import Tensor, nn from torch.nn import functional as F -from ..generation import QuantizationMode +from ...datatypes import QuantizationMode from ..model import Transformer, TransformerBlock from ..moe import MoE log = logging.getLogger(__name__) +def swiglu_wrapper_no_reduce( + self, + x: Tensor, +): + from ...quantize_impls import ffn_swiglu + + return ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight) + + def experts_batched_swiglu_wrapper( self, x: Tensor, # (e, g, D) @@ -51,24 +60,30 @@ def convert_to_quantized_model( rank = get_model_parallel_rank() + def should_quantize_block(block: nn.Module) -> bool: + if not isinstance(block, TransformerBlock): + return False + + is_moe = isinstance(block.feed_forward, MoE) + if quantization_mode == QuantizationMode.fp8_mixed: + # skip quantization on first and last layers + return is_moe and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1)) + + return is_moe + use_rich_progress = use_rich_progress and rank == 0 - progress, log_status, update_status = logging_callbacks(use_rich_progress, rank, model) + progress, log_status, update_status = logging_callbacks(use_rich_progress, rank, model, should_quantize_block) if quantization_mode == QuantizationMode.int4_mixed: int4_scales_path = os.path.join(checkpoint_dir, f"int4_scales_{rank}.pt") - int4_zero_points_path = os.path.join(checkpoint_dir, f"int4_zero_points_{rank}.pt") if os.path.isfile(int4_scales_path): log_status(f"Rank {rank}: Loading int4 scales") int4_scales = torch.load(int4_scales_path, weights_only=True) - int4_zero_points = torch.load(int4_zero_points_path, weights_only=True) def apply_quantization(key, weight): scale = int4_scales[key] - zero_point = int4_zero_points[key] return load_int4( weight, scale, - zero_point, - fp8_activation_scale_ub, output_device=torch.device("cuda"), ) @@ -76,7 +91,8 @@ def convert_to_quantized_model( log_status(f"Rank {rank}: Quantizing int4 weights from bf16") def apply_quantization(_, weight): - return quantize_int4(weight, fp8_activation_scale_ub, output_device=torch.device("cuda")) + return quantize_int4(weight, output_device=torch.device("cuda")) + else: fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt") if os.path.isfile(fp8_scales_path): @@ -104,33 +120,38 @@ def convert_to_quantized_model( progress.start() for _, block in model.named_modules(): - if isinstance(block, TransformerBlock): - # Skip quantization on first and last layers - if block.layer_id == 0 or block.layer_id == (model.n_layers - 1): - continue + if not should_quantize_block(block): + continue - # Skip quantization on dense layers - if not isinstance(block.feed_forward, MoE): - continue + update_status(f"Rank {rank} - Layer {block.layer_id}") - 
update_status(f"Rank {rank} - Layer {block.layer_id}") + # Quantize only routed experts, not shared + prefix = f"layers.{block.layer_id}.feed_forward" + moe = block.feed_forward + moe.experts.batched_swiglu = experts_batched_swiglu_wrapper.__get__(moe.experts) - # Quantize only routed experts, not shared - prefix = f"layers.{block.layer_id}.feed_forward" - moe = block.feed_forward - moe.experts.batched_swiglu = experts_batched_swiglu_wrapper.__get__(moe.experts) + for key in ("w1", "w3", "w2"): + param = getattr(moe.experts, key) + update_status(f"Rank {rank} - Layer {block.layer_id} - MoE {key}") + setattr( + moe.experts, + key, + apply_quantization( + f"{prefix}.experts.{key}", + param.transpose(1, 2).contiguous(), + ), + ) + if quantization_mode == QuantizationMode.int4_mixed: + # Quantize shared experts + moe.shared_expert.forward = swiglu_wrapper_no_reduce.__get__(moe.shared_expert) for key in ("w1", "w3", "w2"): - param = getattr(moe.experts, key) - update_status(f"Rank {rank} - Layer {block.layer_id} - MoE {key}") - setattr( - moe.experts, - key, - apply_quantization(f"{prefix}.experts.{key}", param.transpose(1, 2).contiguous()), - ) + param = getattr(moe.shared_expert, key) + update_status(f"Rank {rank} - Layer {block.layer_id} - MoE shared expert {key}") + param.weight = apply_quantization(f"{prefix}.shared_expert.{key}", param.weight) - processed_blocks += 1 - update_status(message=None, completed=processed_blocks) + processed_blocks += 1 + update_status(message=None, completed=processed_blocks) update_status(f"Rank {rank} - Moving parameters to CUDA") @@ -149,7 +170,12 @@ def convert_to_quantized_model( # fp8/int4 loading can be very slow so we add progress bars to make life slightly better -def logging_callbacks(use_rich_progress: bool, rank: int, model: Transformer): +def logging_callbacks( + use_rich_progress: bool, + rank: int, + model: Transformer, + should_quantize_block: Callable[[nn.Module], bool], +): console = None if use_rich_progress: from rich.console import Console @@ -162,15 +188,7 @@ def logging_callbacks(use_rich_progress: bool, rank: int, model: Transformer): elif rank == 0: # Only log from rank 0 for non-rich logging log.info(message) - total_blocks = sum( - 1 - for _, block in model.named_modules() - if ( - isinstance(block, TransformerBlock) - and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1)) - and isinstance(block.feed_forward, MoE) - ) - ) + total_blocks = sum(1 for _, block in model.named_modules() if should_quantize_block(block)) progress = None if use_rich_progress: from rich.progress import ( diff --git a/llama_stack/models/llama/llama4/tokenizer.py b/llama_stack/models/llama/llama4/tokenizer.py index c1347daca..14250f681 100644 --- a/llama_stack/models/llama/llama4/tokenizer.py +++ b/llama_stack/models/llama/llama4/tokenizer.py @@ -4,6 +4,13 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + # Copyright (c) Meta Platforms, Inc. and affiliates. # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. 
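The llama4 loader above swaps quantized forward paths onto existing modules by binding plain functions as instance methods with __get__ (experts_batched_swiglu_wrapper and swiglu_wrapper_no_reduce). A minimal standalone illustration of that Python pattern, unrelated to any llama-stack API:

class Box:
    def __init__(self, value):
        self.value = value

def doubled(self):
    # a free function; `self` is supplied once the function is bound to an instance
    return self.value * 2

box = Box(21)
box.describe = doubled.__get__(box)  # descriptor protocol: bind `doubled` to this instance
assert box.describe() == 42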
@@ -59,8 +66,6 @@ LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS = [ "<|text_post_train_reserved_special_token_3|>", "<|text_post_train_reserved_special_token_4|>", "<|text_post_train_reserved_special_token_5|>", - "<|python_start|>", - "<|python_end|>", "<|finetune_right_pad|>", ] + get_reserved_special_tokens( "text_post_train", 61, 6 @@ -85,8 +90,23 @@ LLAMA4_VISION_SPECIAL_TOKENS = [ "vision", 1041, 7 ) # <|vision_reserved_special_token_7|>, ..., <|vision_reserved_special_token_1047|> +# 201134, ..., 201143 +LLAMA4_REASONING_SPECIAL_TOKENS = [ + "<|reasoning_reserved_special_token_0|>", + "<|reasoning_reserved_special_token_1|>", + "<|reasoning_reserved_special_token_2|>", + "<|reasoning_reserved_special_token_3|>", + "<|reasoning_reserved_special_token_4|>", + "<|reasoning_reserved_special_token_5|>", + "<|reasoning_reserved_special_token_6|>", + "<|reasoning_reserved_special_token_7|>", + "<|reasoning_thinking_start|>", + "<|reasoning_thinking_end|>", +] -LLAMA4_SPECIAL_TOKENS = LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS + LLAMA4_VISION_SPECIAL_TOKENS +LLAMA4_SPECIAL_TOKENS = ( + LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS + LLAMA4_VISION_SPECIAL_TOKENS + LLAMA4_REASONING_SPECIAL_TOKENS +) BASIC_SPECIAL_TOKENS = [ "<|begin_of_text|>", @@ -155,6 +175,9 @@ class Tokenizer: self.eot_id: int = self.special_tokens["<|eot|>"] self.eom_id: int = self.special_tokens["<|eom|>"] + self.thinking_start_id: int = self.special_tokens["<|reasoning_thinking_start|>"] + self.thinking_end_id: int = self.special_tokens["<|reasoning_thinking_end|>"] + self.stop_tokens = [ self.eos_id, self.special_tokens["<|eom|>"], diff --git a/llama_stack/models/llama/llama4/vision/embedding.py b/llama_stack/models/llama/llama4/vision/embedding.py index 73b29cbef..ed7659a73 100644 --- a/llama_stack/models/llama/llama4/vision/embedding.py +++ b/llama_stack/models/llama/llama4/vision/embedding.py @@ -4,13 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. - import math from typing import Any, Callable, Dict, List diff --git a/llama_stack/models/llama/prompt_format.py b/llama_stack/models/llama/prompt_format.py index 695c0bf74..6756aebfe 100644 --- a/llama_stack/models/llama/prompt_format.py +++ b/llama_stack/models/llama/prompt_format.py @@ -28,9 +28,6 @@ from llama_stack.models.llama.datatypes import ( ToolPromptFormat, ) from llama_stack.models.llama.llama4.tokenizer import Tokenizer -from llama_stack.providers.inline.inference.meta_reference.llama4.datatypes import ( - LLMInput, -) from .llama3.interface import LLama31Interface from .llama3.template_data import ( @@ -38,6 +35,7 @@ from .llama3.template_data import ( system_message_builtin_tools_only, system_message_custom_tools_only, ) +from .llama4.datatypes import LLMInput class TextCompletionContent(BaseModel): diff --git a/llama_stack/models/llama/sku_list.py b/llama_stack/models/llama/sku_list.py index dd3144bb0..513481831 100644 --- a/llama_stack/models/llama/sku_list.py +++ b/llama_stack/models/llama/sku_list.py @@ -4,24 +4,15 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
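The tokenizer hunk above adds reasoning special tokens and exposes thinking_start_id / thinking_end_id. A hedged sketch of one way a caller might hide reasoning spans from displayed output; tokenizer construction and the `generated_ids` list are assumed, not shown in this diff:

def strip_thinking(token_ids, tokenizer):
    # drop everything between <|reasoning_thinking_start|> and <|reasoning_thinking_end|>
    visible, in_thinking = [], False
    for t in token_ids:
        if t == tokenizer.thinking_start_id:
            in_thinking = True
        elif t == tokenizer.thinking_end_id:
            in_thinking = False
        elif not in_thinking:
            visible.append(t)
    return visible

display_text = tokenizer.decode(strip_thinking(generated_ids, tokenizer))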
-# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. - from dataclasses import dataclass from functools import lru_cache from typing import List, Optional -from .datatypes import ( +from .sku_types import ( CheckpointQuantizationFormat, CoreModelId, Model, ModelFamily, - SamplingParams, - TopPSamplingStrategy, ) LLAMA2_VOCAB_SIZE = 32000 @@ -47,15 +38,6 @@ def all_registered_models() -> List[Model]: ) -def recommended_sampling_params() -> SamplingParams: - return SamplingParams( - strategy=TopPSamplingStrategy( - temperature=1.0, - top_p=0.9, - ) - ) - - def llama2_family() -> List[Model]: return [ *llama2_base_models(), @@ -150,7 +132,6 @@ def llama2_base_models() -> List[Model]: core_model_id=CoreModelId.llama2_7b, description="Llama 2 7b model", huggingface_repo="meta-llama/Llama-2-7b", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 4096, "n_layers": 32, @@ -169,7 +150,6 @@ def llama2_base_models() -> List[Model]: core_model_id=CoreModelId.llama2_13b, description="Llama 2 13b model", huggingface_repo="meta-llama/Llama-2-13b", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 5120, "n_layers": 40, @@ -188,7 +168,6 @@ def llama2_base_models() -> List[Model]: core_model_id=CoreModelId.llama2_70b, description="Llama 2 70b model", huggingface_repo="meta-llama/Llama-2-70b", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 8192, "n_layers": 80, @@ -230,7 +209,6 @@ def llama3_base_models() -> List[Model]: core_model_id=CoreModelId.llama3_70b, description="Llama 3 70b model", huggingface_repo="meta-llama/Llama-3-70B", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 8192, "n_layers": 80, @@ -254,7 +232,6 @@ def llama3_1_base_models() -> List[Model]: core_model_id=CoreModelId.llama3_1_8b, description="Llama 3.1 8b model", huggingface_repo="meta-llama/Llama-3.1-8B", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 4096, "n_layers": 32, @@ -273,7 +250,6 @@ def llama3_1_base_models() -> List[Model]: core_model_id=CoreModelId.llama3_1_70b, description="Llama 3.1 70b model", huggingface_repo="meta-llama/Llama-3.1-70B", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 8192, "n_layers": 80, @@ -293,7 +269,6 @@ def llama3_1_base_models() -> List[Model]: variant="bf16-mp8", description="Llama 3.1 405b model (BF16 weights)", huggingface_repo="meta-llama/Llama-3.1-405B", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 16384, "n_layers": 126, @@ -313,7 +288,6 @@ def llama3_1_base_models() -> List[Model]: description="Llama 3.1 405b model (FP8 quantized)", huggingface_repo="meta-llama/Llama-3.1-405B-FP8", quantization_format=CheckpointQuantizationFormat.fp8_mixed, - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 16384, "n_layers": 126, @@ -333,7 +307,6 @@ def llama3_1_base_models() -> List[Model]: variant="bf16-mp16", description="Llama 3.1 405b model (BF16 weights for mp16)", huggingface_repo="meta-llama/Llama-3.1-405B", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 16384, "n_layers": 126, @@ -357,7 +330,6 @@ def llama3_2_base_models() -> List[Model]: core_model_id=CoreModelId.llama3_2_1b, description="Llama 3.2 1b model", huggingface_repo="meta-llama/Llama-3.2-1B", - 
recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 2048, "n_layers": 16, @@ -376,7 +348,6 @@ def llama3_2_base_models() -> List[Model]: core_model_id=CoreModelId.llama3_2_3b, description="Llama 3.2 3b model", huggingface_repo="meta-llama/Llama-3.2-3B", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 3072, "n_layers": 28, @@ -395,7 +366,6 @@ def llama3_2_base_models() -> List[Model]: core_model_id=CoreModelId.llama3_2_11b_vision, description="Llama 3.2 11b vision model", huggingface_repo="meta-llama/Llama-3.2-11B-Vision", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 4096, "n_layers": 32, @@ -417,7 +387,6 @@ def llama3_2_base_models() -> List[Model]: core_model_id=CoreModelId.llama3_2_90b_vision, description="Llama 3.2 90b vision model", huggingface_repo="meta-llama/Llama-3.2-90B-Vision", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 8192, "n_layers": 80, @@ -444,7 +413,6 @@ def llama2_instruct_models() -> List[Model]: core_model_id=CoreModelId.llama2_7b_chat, description="Llama 2 7b chat model", huggingface_repo="meta-llama/Llama-2-7b-chat", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 4096, "n_layers": 32, @@ -463,7 +431,6 @@ def llama2_instruct_models() -> List[Model]: core_model_id=CoreModelId.llama2_13b_chat, description="Llama 2 13b chat model", huggingface_repo="meta-llama/Llama-2-13b-chat", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 5120, "n_layers": 40, @@ -482,7 +449,6 @@ def llama2_instruct_models() -> List[Model]: core_model_id=CoreModelId.llama2_70b_chat, description="Llama 2 70b chat model", huggingface_repo="meta-llama/Llama-2-70b-chat", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 8192, "n_layers": 80, @@ -506,7 +472,6 @@ def llama3_instruct_models() -> List[Model]: core_model_id=CoreModelId.llama3_8b_instruct, description="Llama 3 8b instruct model", huggingface_repo="meta-llama/Llama-3-8B-Instruct", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 4096, "n_layers": 32, @@ -525,7 +490,6 @@ def llama3_instruct_models() -> List[Model]: core_model_id=CoreModelId.llama3_70b_instruct, description="Llama 3 70b instruct model", huggingface_repo="meta-llama/Llama-3-70B-Instruct", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 8192, "n_layers": 80, @@ -549,7 +513,6 @@ def llama3_1_instruct_models() -> List[Model]: core_model_id=CoreModelId.llama3_1_8b_instruct, description="Llama 3.1 8b instruct model", huggingface_repo="meta-llama/Llama-3.1-8B-Instruct", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 4096, "n_layers": 32, @@ -568,7 +531,6 @@ def llama3_1_instruct_models() -> List[Model]: core_model_id=CoreModelId.llama3_1_70b_instruct, description="Llama 3.1 70b instruct model", huggingface_repo="meta-llama/Llama-3.1-70B-Instruct", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 8192, "n_layers": 80, @@ -588,7 +550,6 @@ def llama3_1_instruct_models() -> List[Model]: variant="bf16-mp8", description="Llama 3.1 405b instruct model (BF16 weights)", huggingface_repo="meta-llama/Llama-3.1-405B-Instruct", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 16384, "n_layers": 126, @@ -608,7 +569,6 @@ def llama3_1_instruct_models() -> List[Model]: description="Llama 3.1 405b instruct model (FP8 quantized)", 
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct-FP8", quantization_format=CheckpointQuantizationFormat.fp8_mixed, - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 16384, "n_layers": 126, @@ -628,7 +588,6 @@ def llama3_1_instruct_models() -> List[Model]: variant="bf16-mp16", description="Llama 3.1 405b instruct model (BF16 weights for mp16)", huggingface_repo="meta-llama/Llama-3.1-405B-Instruct", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 16384, "n_layers": 126, @@ -684,7 +643,6 @@ def llama3_2_quantized_models() -> List[Model]: quantization_format=CheckpointQuantizationFormat.int4, description="Llama 3.2 1b INT4 quantized LoRA", huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8", - recommended_sampling_params=recommended_sampling_params(), arch_args={ **arch_args_1b(), "quantization_args": { @@ -703,7 +661,6 @@ def llama3_2_quantized_models() -> List[Model]: quantization_format=CheckpointQuantizationFormat.int4, description="Llama 3.2 1b INT4 quantized SpinQuant", huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8", - recommended_sampling_params=recommended_sampling_params(), arch_args={ **arch_args_1b(), "quantization_args": { @@ -718,7 +675,6 @@ def llama3_2_quantized_models() -> List[Model]: quantization_format=CheckpointQuantizationFormat.int4, description="Llama 3.2 3b INT4 quantized LoRA", huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-QLORA_INT4_EO8", - recommended_sampling_params=recommended_sampling_params(), arch_args={ **arch_args_3b(), "quantization_args": { @@ -737,7 +693,6 @@ def llama3_2_quantized_models() -> List[Model]: quantization_format=CheckpointQuantizationFormat.int4, description="Llama 3.2 3b INT4 quantized SpinQuant", huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-SpinQuant_INT4_EO8", - recommended_sampling_params=recommended_sampling_params(), arch_args={ **arch_args_3b(), "quantization_args": { @@ -755,7 +710,6 @@ def llama3_2_instruct_models() -> List[Model]: core_model_id=CoreModelId.llama3_2_1b_instruct, description="Llama 3.2 1b instruct model", huggingface_repo="meta-llama/Llama-3.2-1B-Instruct", - recommended_sampling_params=recommended_sampling_params(), arch_args=arch_args_1b(), pth_file_count=1, ), @@ -763,7 +717,6 @@ def llama3_2_instruct_models() -> List[Model]: core_model_id=CoreModelId.llama3_2_3b_instruct, description="Llama 3.2 3b instruct model", huggingface_repo="meta-llama/Llama-3.2-3B-Instruct", - recommended_sampling_params=recommended_sampling_params(), arch_args=arch_args_3b(), pth_file_count=1, ), @@ -772,7 +725,6 @@ def llama3_2_instruct_models() -> List[Model]: core_model_id=CoreModelId.llama3_2_11b_vision_instruct, description="Llama 3.2 11b vision instruct model", huggingface_repo="meta-llama/Llama-3.2-11B-Vision-Instruct", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 4096, "n_layers": 32, @@ -794,7 +746,6 @@ def llama3_2_instruct_models() -> List[Model]: core_model_id=CoreModelId.llama3_2_90b_vision_instruct, description="Llama 3.2 90b vision instruct model", huggingface_repo="meta-llama/Llama-3.2-90B-Vision-Instruct", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 8192, "n_layers": 80, @@ -821,7 +772,6 @@ def llama3_3_instruct_models() -> List[Model]: core_model_id=CoreModelId.llama3_3_70b_instruct, description="Llama 3.3 70b instruct", huggingface_repo="meta-llama/Llama-3.3-70B-Instruct", - recommended_sampling_params=recommended_sampling_params(), 
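Since recommended_sampling_params is dropped from every SKU entry in this file, callers that relied on the per-model default presumably construct sampling parameters explicitly through the inference API instead. A minimal sketch mirroring the default being removed (top-p 0.9 at temperature 1.0):

from llama_stack.apis.inference import SamplingParams, TopPSamplingStrategy

params = SamplingParams(
    strategy=TopPSamplingStrategy(temperature=1.0, top_p=0.9),
)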
arch_args={ "dim": 8192, "n_layers": 80, @@ -846,7 +796,6 @@ def safety_models() -> List[Model]: core_model_id=CoreModelId.llama_guard_3_11b_vision, description="Llama Guard v3 11b vision system safety model", huggingface_repo="meta-llama/Llama-Guard-3-11B-Vision", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 4096, "n_layers": 32, @@ -870,7 +819,6 @@ def safety_models() -> List[Model]: description="Llama Guard v3 1b 'int4' quantized system safety model", huggingface_repo="meta-llama/Llama-Guard-3-1B-INT4", quantization_format=CheckpointQuantizationFormat.int4, - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 2048, "n_layers": 12, @@ -888,7 +836,6 @@ def safety_models() -> List[Model]: core_model_id=CoreModelId.llama_guard_3_1b, description="Llama Guard v3 1b system safety model", huggingface_repo="meta-llama/Llama-Guard-3-1B", - recommended_sampling_params=recommended_sampling_params(), arch_args={ "dim": 2048, "n_layers": 16, diff --git a/llama_stack/models/llama/sku_types.py b/llama_stack/models/llama/sku_types.py new file mode 100644 index 000000000..88799b66d --- /dev/null +++ b/llama_stack/models/llama/sku_types.py @@ -0,0 +1,229 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from enum import Enum +from typing import Any, Dict, Optional + +from pydantic import BaseModel, ConfigDict, Field + + +class CheckpointQuantizationFormat(Enum): + # default format + bf16 = "bf16" + + # used for enabling fp8_rowwise inference, some weights are bf16 + fp8_mixed = "fp8-mixed" + + int8 = "int8" + + int4 = "int4" + + +class ModelFamily(Enum): + llama2 = "llama2" + llama3 = "llama3" + llama3_1 = "llama3_1" + llama3_2 = "llama3_2" + llama3_3 = "llama3_3" + llama4 = "llama4" + safety = "safety" + + +class CoreModelId(Enum): + """Each of these models is a unique "SKU". 
These root models can be served in various garbs (especially by quantizing them)""" + + # Llama 2 family + llama2_7b = "Llama-2-7b" + llama2_13b = "Llama-2-13b" + llama2_70b = "Llama-2-70b" + llama2_7b_chat = "Llama-2-7b-chat" + llama2_13b_chat = "Llama-2-13b-chat" + llama2_70b_chat = "Llama-2-70b-chat" + + # Llama 3 family + llama3_8b = "Llama-3-8B" + llama3_70b = "Llama-3-70B" + llama3_8b_instruct = "Llama-3-8B-Instruct" + llama3_70b_instruct = "Llama-3-70B-Instruct" + + # Llama 3.1 family + llama3_1_8b = "Llama3.1-8B" + llama3_1_70b = "Llama3.1-70B" + llama3_1_405b = "Llama3.1-405B" + llama3_1_8b_instruct = "Llama3.1-8B-Instruct" + llama3_1_70b_instruct = "Llama3.1-70B-Instruct" + llama3_1_405b_instruct = "Llama3.1-405B-Instruct" + + # Llama 3.2 family + llama3_2_1b = "Llama3.2-1B" + llama3_2_3b = "Llama3.2-3B" + llama3_2_1b_instruct = "Llama3.2-1B-Instruct" + llama3_2_3b_instruct = "Llama3.2-3B-Instruct" + llama3_2_11b_vision = "Llama3.2-11B-Vision" + llama3_2_90b_vision = "Llama3.2-90B-Vision" + llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct" + llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct" + + # Llama 3.3 family + llama3_3_70b_instruct = "Llama3.3-70B-Instruct" + + # Llama 4 family + llama4_scout_17b_16e = "Llama-4-Scout-17B-16E" + llama4_scout_17b_16e_instruct = "Llama-4-Scout-17B-16E-Instruct" + llama4_maverick_17b_128e = "Llama-4-Maverick-17B-128E" + llama4_maverick_17b_128e_instruct = "Llama-4-Maverick-17B-128E-Instruct" + + # Safety models + llama_guard_3_8b = "Llama-Guard-3-8B" + llama_guard_2_8b = "Llama-Guard-2-8B" + llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision" + llama_guard_3_1b = "Llama-Guard-3-1B" + + +def is_multimodal(model_id) -> bool: + if model_id in [ + CoreModelId.llama3_2_11b_vision, + CoreModelId.llama3_2_90b_vision, + CoreModelId.llama3_2_11b_vision_instruct, + CoreModelId.llama3_2_90b_vision_instruct, + ]: + return True + else: + return False + + +def model_family(model_id) -> ModelFamily: + if model_id in [ + CoreModelId.llama2_7b, + CoreModelId.llama2_13b, + CoreModelId.llama2_70b, + CoreModelId.llama2_7b_chat, + CoreModelId.llama2_13b_chat, + CoreModelId.llama2_70b_chat, + ]: + return ModelFamily.llama2 + elif model_id in [ + CoreModelId.llama3_8b, + CoreModelId.llama3_70b, + CoreModelId.llama3_8b_instruct, + CoreModelId.llama3_70b_instruct, + ]: + return ModelFamily.llama3 + elif model_id in [ + CoreModelId.llama3_1_8b, + CoreModelId.llama3_1_70b, + CoreModelId.llama3_1_405b, + CoreModelId.llama3_1_8b_instruct, + CoreModelId.llama3_1_70b_instruct, + CoreModelId.llama3_1_405b_instruct, + ]: + return ModelFamily.llama3_1 + elif model_id in [ + CoreModelId.llama3_2_1b, + CoreModelId.llama3_2_3b, + CoreModelId.llama3_2_1b_instruct, + CoreModelId.llama3_2_3b_instruct, + CoreModelId.llama3_2_11b_vision, + CoreModelId.llama3_2_90b_vision, + CoreModelId.llama3_2_11b_vision_instruct, + CoreModelId.llama3_2_90b_vision_instruct, + ]: + return ModelFamily.llama3_2 + elif model_id in [ + CoreModelId.llama3_3_70b_instruct, + ]: + return ModelFamily.llama3_3 + elif model_id in [ + CoreModelId.llama4_scout_17b_16e, + CoreModelId.llama4_scout_17b_16e_instruct, + CoreModelId.llama4_maverick_17b_128e, + CoreModelId.llama4_maverick_17b_128e_instruct, + ]: + return ModelFamily.llama4 + elif model_id in [ + CoreModelId.llama_guard_3_8b, + CoreModelId.llama_guard_2_8b, + CoreModelId.llama_guard_3_11b_vision, + CoreModelId.llama_guard_3_1b, + ]: + return ModelFamily.safety + else: + raise ValueError(f"Unknown model family for {model_id}") + + 
+class Model(BaseModel): + core_model_id: CoreModelId + description: str + huggingface_repo: Optional[str] = None + arch_args: Dict[str, Any] + variant: str = "" + + quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16 + pth_file_count: int + metadata: Dict[str, Any] = Field(default_factory=dict) + + # silence pydantic until we remove the `model_` fields + model_config = ConfigDict(protected_namespaces=()) + + @property + def model_family(self) -> ModelFamily: + return model_family(self.core_model_id) + + # The SKU is uniquely identified by (model_id, variant) combo + def descriptor(self, shorten_default_variant: bool = True) -> str: + if not self.variant: + return self.core_model_id.value + return f"{self.core_model_id.value}:{self.variant}" + + @property + def is_instruct_model(self) -> bool: + return "instruct" in self.core_model_id.value + + # Featured models are shown in the non-exhaustive model list + @property + def is_featured(self) -> bool: + return self.model_family in [ + ModelFamily.llama3_1, + ModelFamily.llama3_2, + ModelFamily.llama3_3, + ModelFamily.llama4, + ModelFamily.safety, + ] + + @property + def max_seq_length(self) -> int: + if self.model_family == ModelFamily.llama2: + return 4096 + elif self.core_model_id == CoreModelId.llama_guard_2_8b: + return 4096 + elif self.model_family == ModelFamily.llama3: + return 8192 + elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]: + return 131072 + elif self.model_family == ModelFamily.llama3_2: + if self.quantization_format == CheckpointQuantizationFormat.int4: + return 8192 + return 131072 + elif self.model_family == ModelFamily.llama4: + if self.core_model_id in { + CoreModelId.llama4_scout_17b_16e, + CoreModelId.llama4_maverick_17b_128e, + }: + return 262144 + if self.core_model_id == CoreModelId.llama4_scout_17b_16e_instruct: + return 10485760 + if self.core_model_id == CoreModelId.llama4_maverick_17b_128e_instruct: + return 1048576 + + raise AssertionError(f"Unexpected core model id: {self.core_model_id}") + elif self.core_model_id in [ + CoreModelId.llama_guard_3_8b, + CoreModelId.llama_guard_3_11b_vision, + CoreModelId.llama_guard_3_1b, + ]: + return 131072 + else: + raise ValueError(f"Unknown max_seq_len for {self.core_model_id}") diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index e1af4ab71..6840da89f 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -52,6 +52,7 @@ from llama_stack.apis.inference import ( StopReason, SystemMessage, ToolDefinition, + ToolParamDefinition, ToolResponse, ToolResponseMessage, UserMessage, @@ -63,7 +64,6 @@ from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ( BuiltinTool, ToolCall, - ToolParamDefinition, ) from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.telemetry import tracing diff --git a/llama_stack/providers/inline/inference/meta_reference/generators.py b/llama_stack/providers/inline/inference/meta_reference/generators.py index 809351164..10e597665 100644 --- a/llama_stack/providers/inline/inference/meta_reference/generators.py +++ b/llama_stack/providers/inline/inference/meta_reference/generators.py @@ -12,20 +12,19 @@ from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerToken from llama_stack.apis.inference import ( 
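The new sku_types module above carries the SKU-level Model metadata that used to live in datatypes. A short usage sketch combining it with resolve_model from sku_list, which the hunks below keep importing:

from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily

model = resolve_model("Llama3.2-3B-Instruct")        # returns None for unknown descriptors
if model is not None:
    assert model.model_family == ModelFamily.llama3_2
    print(model.descriptor())                        # "Llama3.2-3B-Instruct"
    print(model.max_seq_length)                      # 131072 for non-int4 Llama 3.2 checkpoints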
Fp8QuantizationConfig, + GreedySamplingStrategy, Int4QuantizationConfig, JsonSchemaResponseFormat, ResponseFormat, -) -from llama_stack.models.llama.datatypes import ( - GreedySamplingStrategy, - Model, SamplingParams, TopPSamplingStrategy, ) +from llama_stack.models.llama.datatypes import QuantizationMode from llama_stack.models.llama.llama3.generation import Llama3 from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer from llama_stack.models.llama.llama4.generation import Llama4 from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer +from llama_stack.models.llama.sku_types import Model from llama_stack.providers.utils.inference.prompt_adapter import ( ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent, @@ -136,9 +135,9 @@ class Llama4Generator: if isinstance(config, MetaReferenceQuantizedInferenceConfig): if isinstance(config.quantization, Fp8QuantizationConfig): - quantization_mode = "fp8_mixed" + quantization_mode = QuantizationMode.fp8_mixed elif isinstance(config.quantization, Int4QuantizationConfig): - quantization_mode = "int4_mixed" + quantization_mode = QuantizationMode.int4_mixed else: raise ValueError(f"Unsupported quantization mode {config.quantization}") else: @@ -225,9 +224,9 @@ class Llama3Generator: if isinstance(config, MetaReferenceQuantizedInferenceConfig): if isinstance(config.quantization, Fp8QuantizationConfig): - quantization_mode = "fp8_mixed" + quantization_mode = QuantizationMode.fp8_mixed elif isinstance(config.quantization, Int4QuantizationConfig): - quantization_mode = "int4_mixed" + quantization_mode = QuantizationMode.int4_mixed else: raise ValueError(f"Unsupported quantization mode {config.quantization}") else: @@ -240,6 +239,9 @@ class Llama3Generator: world_size=llama_model.pth_file_count, quantization_mode=quantization_mode, ) + self.tokenizer = self.inner_generator.tokenizer + self.args = self.inner_generator.args + self.formatter = self.inner_generator.formatter def completion( self, diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index da217728b..ca2f51ac7 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -31,23 +31,21 @@ from llama_stack.apis.inference import ( LogProbConfig, Message, ResponseFormat, + SamplingParams, + StopReason, TokenLogProbs, ToolChoice, ToolConfig, -) -from llama_stack.apis.models import Model, ModelType -from llama_stack.models.llama.datatypes import ( - ModelFamily, - SamplingParams, - StopReason, ToolDefinition, ToolPromptFormat, ) +from llama_stack.apis.models import Model, ModelType from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer from llama_stack.models.llama.sku_list import resolve_model +from llama_stack.models.llama.sku_types import ModelFamily from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.embedding_mixin import ( SentenceTransformerEmbeddingMixin, diff --git a/llama_stack/providers/inline/inference/vllm/openai_utils.py b/llama_stack/providers/inline/inference/vllm/openai_utils.py index 
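Both generators above now map the provider quantization config onto the shared QuantizationMode enum. A hedged sketch of that mapping as a standalone helper; the function name is hypothetical, the types come from the imports shown in this hunk:

from typing import Optional

from llama_stack.apis.inference import Fp8QuantizationConfig, Int4QuantizationConfig
from llama_stack.models.llama.datatypes import QuantizationMode

def quantization_mode_from_config(quantization) -> Optional[QuantizationMode]:
    # mirrors the branches used by Llama3Generator and Llama4Generator above
    if quantization is None:
        return None
    if isinstance(quantization, Fp8QuantizationConfig):
        return QuantizationMode.fp8_mixed
    if isinstance(quantization, Int4QuantizationConfig):
        return QuantizationMode.int4_mixed
    raise ValueError(f"Unsupported quantization mode {quantization}")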
90b5398f9..d34f5ad5f 100644 --- a/llama_stack/providers/inline/inference/vllm/openai_utils.py +++ b/llama_stack/providers/inline/inference/vllm/openai_utils.py @@ -14,9 +14,10 @@ from llama_stack.apis.inference import ( JsonSchemaResponseFormat, Message, ToolChoice, + ToolDefinition, UserMessage, ) -from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition +from llama_stack.models.llama.datatypes import BuiltinTool from llama_stack.providers.utils.inference.openai_compat import ( convert_message_to_openai_dict, get_sampling_options, diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py index 256e0f821..ea2643b7a 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -46,6 +46,8 @@ from llama_stack.apis.inference import ( TokenLogProbs, ToolChoice, ToolConfig, + TopKSamplingStrategy, + TopPSamplingStrategy, ) from llama_stack.apis.models import Model from llama_stack.log import get_logger @@ -55,8 +57,6 @@ from llama_stack.models.llama.datatypes import ( ToolCall, ToolDefinition, ToolPromptFormat, - TopKSamplingStrategy, - TopPSamplingStrategy, ) from llama_stack.models.llama.llama3.chat_format import ChatFormat from llama_stack.models.llama.llama3.tokenizer import Tokenizer diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py index f8a1c0436..a040ca1b0 100644 --- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py +++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py @@ -22,8 +22,8 @@ from torchtune.models.llama3_2 import lora_llama3_2_3b from torchtune.modules.transforms import Transform from llama_stack.apis.post_training import DatasetFormat -from llama_stack.models.llama.datatypes import Model from llama_stack.models.llama.sku_list import resolve_model +from llama_stack.models.llama.sku_types import Model BuildLoraModelCallable = Callable[..., torch.nn.Module] BuildTokenizerCallable = Callable[..., Llama3Tokenizer] diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py index e514e3781..d95c40976 100644 --- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -23,7 +23,8 @@ from llama_stack.apis.safety import ( ) from llama_stack.apis.shields import Shield from llama_stack.distribution.datatypes import Api -from llama_stack.models.llama.datatypes import CoreModelId, Role +from llama_stack.models.llama.datatypes import Role +from llama_stack.models.llama.sku_types import CoreModelId from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, diff --git a/llama_stack/providers/remote/inference/bedrock/models.py b/llama_stack/providers/remote/inference/bedrock/models.py index c5079799f..ec8120049 100644 --- a/llama_stack/providers/remote/inference/bedrock/models.py +++ b/llama_stack/providers/remote/inference/bedrock/models.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py
index a53e6e5a5..43d986b86 100644
--- a/llama_stack/providers/remote/inference/cerebras/cerebras.py
+++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py
@@ -28,8 +28,8 @@ from llama_stack.apis.inference import (
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
+    TopKSamplingStrategy,
 )
-from llama_stack.models.llama.datatypes import TopKSamplingStrategy
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
diff --git a/llama_stack/providers/remote/inference/cerebras/models.py b/llama_stack/providers/remote/inference/cerebras/models.py
index 37419bf4c..38301b32a 100644
--- a/llama_stack/providers/remote/inference/cerebras/models.py
+++ b/llama_stack/providers/remote/inference/cerebras/models.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     build_hf_repo_model_entry,
 )
diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py
index 53a9c04f4..0eaf0135b 100644
--- a/llama_stack/providers/remote/inference/databricks/databricks.py
+++ b/llama_stack/providers/remote/inference/databricks/databricks.py
@@ -28,7 +28,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_hf_repo_model_entry,
diff --git a/llama_stack/providers/remote/inference/fireworks/models.py b/llama_stack/providers/remote/inference/fireworks/models.py
index a0dc11768..4975d061f 100644
--- a/llama_stack/providers/remote/inference/fireworks/models.py
+++ b/llama_stack/providers/remote/inference/fireworks/models.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from llama_stack.apis.models.models import ModelType
-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     ProviderModelEntry,
     build_hf_repo_model_entry,
diff --git a/llama_stack/providers/remote/inference/nvidia/models.py b/llama_stack/providers/remote/inference/nvidia/models.py
index 879855003..964125148 100644
--- a/llama_stack/providers/remote/inference/nvidia/models.py
+++ b/llama_stack/providers/remote/inference/nvidia/models.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from llama_stack.apis.models import ModelType
-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     ProviderModelEntry,
     build_hf_repo_model_entry,
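Editor's note: several of the adapters above now take the sampling-strategy types straight from `llama_stack.apis.inference`. A small usage sketch of the discriminated union; the field names follow the imports in these hunks, and the exact validation behavior should be treated as an assumption:

```python
from llama_stack.apis.inference import SamplingParams, TopKSamplingStrategy

# Construct explicitly with a strategy object ...
params = SamplingParams(strategy=TopKSamplingStrategy(top_k=40), max_tokens=256)

# ... or let pydantic resolve the union from the "type" discriminator.
parsed = SamplingParams(**{"strategy": {"type": "top_k", "top_k": 40}, "max_tokens": 256})

assert parsed.strategy.type == "top_k"
assert parsed.strategy.top_k == 40
```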
diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index 5caf19fda..e1f5d7a6a 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -29,15 +29,13 @@ from llama_stack.apis.inference import (
     LogProbConfig,
     Message,
     ResponseFormat,
+    SamplingParams,
     TextTruncation,
     ToolChoice,
     ToolConfig,
-)
-from llama_stack.models.llama.datatypes import (
-    SamplingParams,
     ToolDefinition,
-    ToolPromptFormat,
 )
+from llama_stack.models.llama.datatypes import ToolPromptFormat
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
diff --git a/llama_stack/providers/remote/inference/nvidia/openai_utils.py b/llama_stack/providers/remote/inference/nvidia/openai_utils.py
index 0582cb816..3f2769b26 100644
--- a/llama_stack/providers/remote/inference/nvidia/openai_utils.py
+++ b/llama_stack/providers/remote/inference/nvidia/openai_utils.py
@@ -19,11 +19,9 @@ from llama_stack.apis.inference import (
     CompletionRequest,
     CompletionResponse,
     CompletionResponseStreamChunk,
+    GreedySamplingStrategy,
     JsonSchemaResponseFormat,
     TokenLogProbs,
-)
-from llama_stack.models.llama.datatypes import (
-    GreedySamplingStrategy,
     TopKSamplingStrategy,
     TopPSamplingStrategy,
 )
diff --git a/llama_stack/providers/remote/inference/ollama/models.py b/llama_stack/providers/remote/inference/ollama/models.py
index be556762c..42e364105 100644
--- a/llama_stack/providers/remote/inference/ollama/models.py
+++ b/llama_stack/providers/remote/inference/ollama/models.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from llama_stack.apis.models.models import ModelType
-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     ProviderModelEntry,
     build_hf_repo_model_entry,
diff --git a/llama_stack/providers/remote/inference/sambanova/models.py b/llama_stack/providers/remote/inference/sambanova/models.py
index 2231be22d..9589ea268 100644
--- a/llama_stack/providers/remote/inference/sambanova/models.py
+++ b/llama_stack/providers/remote/inference/sambanova/models.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     build_hf_repo_model_entry,
 )
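Editor's note: the NVIDIA and SambaNova adapters touched here translate the sampling-strategy union into provider request options. A simplified sketch of the kind of mapping these OpenAI-compatible adapters perform; this is illustrative only and is not the library's `get_sampling_options` or any per-provider converter:

```python
from typing import Any, Dict

from llama_stack.apis.inference import (
    GreedySamplingStrategy,
    SamplingParams,
    TopKSamplingStrategy,
    TopPSamplingStrategy,
)


def sampling_options_for_openai_style_api(params: SamplingParams) -> Dict[str, Any]:
    # Illustrative mapping; the real adapters may differ in defaults and field names.
    options: Dict[str, Any] = {}
    strategy = params.strategy
    if isinstance(strategy, GreedySamplingStrategy):
        options["temperature"] = 0.0
    elif isinstance(strategy, TopPSamplingStrategy):
        options["temperature"] = strategy.temperature
        options["top_p"] = strategy.top_p
    elif isinstance(strategy, TopKSamplingStrategy):
        options["top_k"] = strategy.top_k
    if params.max_tokens:
        options["max_tokens"] = params.max_tokens
    return options
```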
diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py
index 635a42d38..a3badd468 100644
--- a/llama_stack/providers/remote/inference/sambanova/sambanova.py
+++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py
@@ -21,6 +21,7 @@ from llama_stack.apis.inference import (
     CompletionMessage,
     EmbeddingsResponse,
     EmbeddingTaskType,
+    GreedySamplingStrategy,
     Inference,
     LogProbConfig,
     Message,
@@ -35,12 +36,9 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
     ToolResponseMessage,
-    UserMessage,
-)
-from llama_stack.models.llama.datatypes import (
-    GreedySamplingStrategy,
     TopKSamplingStrategy,
     TopPSamplingStrategy,
+    UserMessage,
 )
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
diff --git a/llama_stack/providers/remote/inference/together/models.py b/llama_stack/providers/remote/inference/together/models.py
index 63d3d94b5..f014c03f0 100644
--- a/llama_stack/providers/remote/inference/together/models.py
+++ b/llama_stack/providers/remote/inference/together/models.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from llama_stack.apis.models.models import ModelType
-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     ProviderModelEntry,
     build_hf_repo_model_entry,
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 6a828322f..e2f3a7b33 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -96,7 +96,6 @@ def _convert_to_vllm_tool_calls_in_response(
             call_id=call.id,
             tool_name=call.function.name,
             arguments=json.loads(call.function.arguments),
-            arguments_json=call.function.arguments,
         )
         for call in tool_calls
     ]
@@ -176,7 +175,6 @@ async def _process_vllm_chat_completion_stream_response(
                     call_id=tool_call_buf.call_id,
                     tool_name=tool_call_buf.tool_name,
                     arguments=args,
-                    arguments_json=args_str,
                 ),
                 parse_status=ToolCallParseStatus.succeeded,
             ),
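Editor's note: with `arguments_json` removed from `ToolCall` (see the spec change and the remote vLLM hunks above), converters keep only the parsed arguments. A hedged sketch of that conversion, using a hypothetical OpenAI-style payload for illustration:

```python
import json

from llama_stack.models.llama.datatypes import ToolCall

# Hypothetical OpenAI-style tool call payload, for illustration only.
openai_tool_call = {
    "id": "call_0",
    "function": {
        "name": "get_weather",
        "arguments": '{"city": "Paris", "unit": "celsius"}',
    },
}

# The JSON string is parsed once and only the dict is kept; there is no
# separate arguments_json field to populate anymore.
tool_call = ToolCall(
    call_id=openai_tool_call["id"],
    tool_name=openai_tool_call["function"]["name"],
    arguments=json.loads(openai_tool_call["function"]["arguments"]),
)
```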
diff --git a/llama_stack/providers/remote/post_training/nvidia/models.py b/llama_stack/providers/remote/post_training/nvidia/models.py
index 04a9af38c..7c696ac20 100644
--- a/llama_stack/providers/remote/post_training/nvidia/models.py
+++ b/llama_stack/providers/remote/post_training/nvidia/models.py
@@ -6,7 +6,7 @@
 
 from typing import List
 
-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     ProviderModelEntry,
     build_hf_repo_model_entry,
diff --git a/llama_stack/providers/tests/report.py b/llama_stack/providers/tests/report.py
index c9a7f69a8..bc29534be 100644
--- a/llama_stack/providers/tests/report.py
+++ b/llama_stack/providers/tests/report.py
@@ -12,8 +12,8 @@ import pytest
 from pytest import ExitCode
 from pytest_html.basereport import _process_outcome
 
-from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.models.llama.sku_list import all_registered_models
+from llama_stack.models.llama.sku_types import CoreModelId
 
 INFERENCE_APIS = ["chat_completion"]
 FUNCTIONALITIES = ["streaming", "structured_output", "tool_calling"]
diff --git a/llama_stack/providers/utils/inference/__init__.py b/llama_stack/providers/utils/inference/__init__.py
index a885da235..e36be9404 100644
--- a/llama_stack/providers/utils/inference/__init__.py
+++ b/llama_stack/providers/utils/inference/__init__.py
@@ -6,8 +6,8 @@
 
 from typing import List
 
-from llama_stack.models.llama.datatypes import *  # noqa: F403
 from llama_stack.models.llama.sku_list import all_registered_models
+from llama_stack.models.llama.sku_types import *  # noqa: F403
 
 
 def is_supported_safety_model(model: Model) -> bool:
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index e475d77b6..44a89dfb0 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -73,21 +73,21 @@ from llama_stack.apis.inference import (
     CompletionMessage,
     CompletionResponse,
     CompletionResponseStreamChunk,
+    GreedySamplingStrategy,
     Message,
+    SamplingParams,
     SystemMessage,
     TokenLogProbs,
     ToolResponseMessage,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
     UserMessage,
 )
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
-    GreedySamplingStrategy,
-    SamplingParams,
     StopReason,
     ToolCall,
     ToolDefinition,
-    TopKSamplingStrategy,
-    TopPSamplingStrategy,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_content_to_url,
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 0231312cc..4f9c4927a 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -34,7 +34,6 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import (
-    ModelFamily,
     RawContent,
     RawContentItem,
     RawMediaItem,
@@ -43,7 +42,6 @@ from llama_stack.models.llama.datatypes import (
     Role,
     StopReason,
     ToolPromptFormat,
-    is_multimodal,
 )
 from llama_stack.models.llama.llama3.chat_format import ChatFormat
 from llama_stack.models.llama.llama3.prompt_templates import (
@@ -55,6 +53,7 @@ from llama_stack.models.llama.llama3.prompt_templates import (
 )
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.models.llama.sku_list import resolve_model
+from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
 from llama_stack.providers.utils.inference import supported_inference_models
 
 log = get_logger(name=__name__, category="inference")
diff --git a/pyproject.toml b/pyproject.toml
index 8d8ff4338..8ae7ddbb6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -224,9 +224,9 @@ exclude = [
     "^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
     "^llama_stack/providers/inline/inference/meta_reference/config\\.py$",
     "^llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
-    "^llama_stack/providers/inline/inference/meta_reference/llama3/generation\\.py$",
-    "^llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/model\\.py$",
-    "^llama_stack/providers/inline/inference/meta_reference/llama4/",
+    "^llama_stack/models/llama/llama3/generation\\.py$",
+    "^llama_stack/models/llama/llama3/multimodal/model\\.py$",
+    "^llama_stack/models/llama/llama4/",
     "^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$",
     "^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$",
    "^llama_stack/providers/inline/inference/meta_reference/quantization/loader\\.py$",
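Editor's note: the `prompt_adapter` and `__init__` hunks above move the model-capability helpers (`ModelFamily`, `is_multimodal`) next to the SKU definitions in `sku_types`. A small sketch of how a provider might gate model-specific handling after the move; the helper signatures and attribute names are assumptions inferred from the imports in this diff, not verified against the tree:

```python
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal

# Assumed usage: resolve a descriptor to a SKU entry, then branch on its family
# and modality. Treat resolve_model/is_multimodal signatures as hypothetical.
llama_model = resolve_model("Llama-4-Scout-17B-16E-Instruct")
if llama_model is not None:
    if llama_model.model_family == ModelFamily.llama4:
        print("use the Llama 4 chat format")
    if is_multimodal(llama_model.core_model_id):
        print("attach image preprocessing")
```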
"^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$", "^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$", "^llama_stack/providers/inline/inference/meta_reference/quantization/loader\\.py$", diff --git a/tests/integration/report.py b/tests/integration/report.py index c07338ce6..a50f51d3f 100644 --- a/tests/integration/report.py +++ b/tests/integration/report.py @@ -11,7 +11,6 @@ import pytest from pytest import CollectReport from termcolor import cprint -from llama_stack.models.llama.datatypes import CoreModelId from llama_stack.models.llama.sku_list import ( all_registered_models, llama3_1_instruct_models, @@ -20,6 +19,7 @@ from llama_stack.models.llama.sku_list import ( llama3_instruct_models, safety_models, ) +from llama_stack.models.llama.sku_types import CoreModelId from llama_stack.providers.datatypes import Api from .metadata import API_MAPS