diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 567110829..f94f2a578 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -4163,80 +4163,70 @@
]
},
"arguments": {
- "oneOf": [
- {
- "type": "string"
- },
- {
- "type": "object",
- "additionalProperties": {
- "oneOf": [
- {
- "type": "string"
- },
- {
- "type": "integer"
- },
- {
- "type": "number"
- },
- {
- "type": "boolean"
- },
- {
- "type": "null"
- },
- {
- "type": "array",
- "items": {
- "oneOf": [
- {
- "type": "string"
- },
- {
- "type": "integer"
- },
- {
- "type": "number"
- },
- {
- "type": "boolean"
- },
- {
- "type": "null"
- }
- ]
+ "type": "object",
+ "additionalProperties": {
+ "oneOf": [
+ {
+ "type": "string"
+ },
+ {
+ "type": "integer"
+ },
+ {
+ "type": "number"
+ },
+ {
+ "type": "boolean"
+ },
+ {
+ "type": "null"
+ },
+ {
+ "type": "array",
+ "items": {
+ "oneOf": [
+ {
+ "type": "string"
+ },
+ {
+ "type": "integer"
+ },
+ {
+ "type": "number"
+ },
+ {
+ "type": "boolean"
+ },
+ {
+ "type": "null"
}
- },
- {
- "type": "object",
- "additionalProperties": {
- "oneOf": [
- {
- "type": "string"
- },
- {
- "type": "integer"
- },
- {
- "type": "number"
- },
- {
- "type": "boolean"
- },
- {
- "type": "null"
- }
- ]
+ ]
+ }
+ },
+ {
+ "type": "object",
+ "additionalProperties": {
+ "oneOf": [
+ {
+ "type": "string"
+ },
+ {
+ "type": "integer"
+ },
+ {
+ "type": "number"
+ },
+ {
+ "type": "boolean"
+ },
+ {
+ "type": "null"
}
- }
- ]
+ ]
+ }
}
- }
- ]
- },
- "arguments_json": {
- "type": "string"
+ ]
+ }
}
},
"additionalProperties": false,
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 1dfd17f55..238f8dcd0 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -2890,34 +2890,30 @@ components:
title: BuiltinTool
- type: string
arguments:
- oneOf:
- - type: string
- - type: object
- additionalProperties:
- oneOf:
- - type: string
- - type: integer
- - type: number
- - type: boolean
- - type: 'null'
- - type: array
- items:
- oneOf:
- - type: string
- - type: integer
- - type: number
- - type: boolean
- - type: 'null'
- - type: object
- additionalProperties:
- oneOf:
- - type: string
- - type: integer
- - type: number
- - type: boolean
- - type: 'null'
- arguments_json:
- type: string
+ type: object
+ additionalProperties:
+ oneOf:
+ - type: string
+ - type: integer
+ - type: number
+ - type: boolean
+ - type: 'null'
+ - type: array
+ items:
+ oneOf:
+ - type: string
+ - type: integer
+ - type: number
+ - type: boolean
+ - type: 'null'
+ - type: object
+ additionalProperties:
+ oneOf:
+ - type: string
+ - type: integer
+ - type: number
+ - type: boolean
+ - type: 'null'
additionalProperties: false
required:
- call_id
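
Both regenerated spec files above encode the same change: `arguments` is now always a JSON object (the old string-or-object `oneOf` is gone), and the separate `arguments_json` field has been removed. A minimal sketch of a tool call that validates against the updated shape, using the Pydantic `ToolCall` model that appears later in this diff; the field values are illustrative only:

from llama_stack.models.llama.datatypes import ToolCall

call = ToolCall(
    call_id="call_0",
    tool_name="get_weather",  # unknown names stay plain strings; builtin names coerce to BuiltinTool
    # arguments must be a dict whose values are primitives, lists of primitives,
    # or dicts of primitives -- exactly the shape the regenerated spec describes
    arguments={"city": "Tokyo", "days": 3, "units": None, "coords": [35.6, 139.7]},
)
# there is no longer an arguments_json field to keep in sync
print(call.model_dump_json())
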
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 1d4012c19..216935ede 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -25,15 +25,64 @@ from llama_stack.apis.models import Model
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
from llama_stack.models.llama.datatypes import (
BuiltinTool,
- SamplingParams,
StopReason,
ToolCall,
ToolDefinition,
+ ToolParamDefinition,
ToolPromptFormat,
)
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
+register_schema(ToolCall)
+register_schema(ToolParamDefinition)
+register_schema(ToolDefinition)
+
+
+@json_schema_type
+class GreedySamplingStrategy(BaseModel):
+ type: Literal["greedy"] = "greedy"
+
+
+@json_schema_type
+class TopPSamplingStrategy(BaseModel):
+ type: Literal["top_p"] = "top_p"
+ temperature: Optional[float] = Field(..., gt=0.0)
+ top_p: Optional[float] = 0.95
+
+
+@json_schema_type
+class TopKSamplingStrategy(BaseModel):
+ type: Literal["top_k"] = "top_k"
+ top_k: int = Field(..., ge=1)
+
+
+SamplingStrategy = Annotated[
+ Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
+ Field(discriminator="type"),
+]
+register_schema(SamplingStrategy, name="SamplingStrategy")
+
+
+@json_schema_type
+class SamplingParams(BaseModel):
+ """Sampling parameters.
+
+ :param strategy: The sampling strategy.
+ :param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
+ your prompt plus max_tokens cannot exceed the model's context length.
+ :param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
+ based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
+ :param stop: Up to 4 sequences where the API will stop generating further tokens.
+ The returned text will not contain the stop sequence.
+ """
+
+ strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
+
+ max_tokens: Optional[int] = 0
+ repetition_penalty: Optional[float] = 1.0
+ stop: Optional[List[str]] = None
+
class LogProbConfig(BaseModel):
"""
diff --git a/llama_stack/cli/download.py b/llama_stack/cli/download.py
index fc3e7008f..9694bf22d 100644
--- a/llama_stack/cli/download.py
+++ b/llama_stack/cli/download.py
@@ -29,8 +29,8 @@ from rich.progress import (
from termcolor import cprint
from llama_stack.cli.subcommand import Subcommand
-from llama_stack.models.llama.datatypes import Model
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
+from llama_stack.models.llama.sku_types import Model
class Download(Subcommand):
diff --git a/llama_stack/cli/model/describe.py b/llama_stack/cli/model/describe.py
index f347bdf8d..62dde36e8 100644
--- a/llama_stack/cli/model/describe.py
+++ b/llama_stack/cli/model/describe.py
@@ -63,17 +63,6 @@ class ModelDescribe(Subcommand):
("Model params.json", json.dumps(model.arch_args, indent=4)),
]
- if model.recommended_sampling_params is not None:
- sampling_params = model.recommended_sampling_params.model_dump()
- for k in ("max_tokens", "repetition_penalty"):
- del sampling_params[k]
- rows.append(
- (
- "Recommended sampling params",
- json.dumps(sampling_params, indent=4),
- )
- )
-
print_table(
rows,
headers,
diff --git a/llama_stack/cli/model/prompt_format.py b/llama_stack/cli/model/prompt_format.py
index 3ce77655b..673487812 100644
--- a/llama_stack/cli/model/prompt_format.py
+++ b/llama_stack/cli/model/prompt_format.py
@@ -11,7 +11,7 @@ from pathlib import Path
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
-from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
+from llama_stack.models.llama.sku_types import CoreModelId, ModelFamily, is_multimodal, model_family
ROOT_DIR = Path(__file__).parent.parent.parent
diff --git a/llama_stack/cli/model/safety_models.py b/llama_stack/cli/model/safety_models.py
index c81783f60..131d055aa 100644
--- a/llama_stack/cli/model/safety_models.py
+++ b/llama_stack/cli/model/safety_models.py
@@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from typing import Any, Dict, Optional
+from typing import Any, Dict
from pydantic import BaseModel, ConfigDict, Field
-from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat, SamplingParams
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
+from llama_stack.models.llama.sku_types import CheckpointQuantizationFormat
class PromptGuardModel(BaseModel):
@@ -23,7 +23,6 @@ class PromptGuardModel(BaseModel):
is_instruct_model: bool = False
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
arch_args: Dict[str, Any] = Field(default_factory=dict)
- recommended_sampling_params: Optional[SamplingParams] = None
def descriptor(self) -> str:
return self.model_id
diff --git a/llama_stack/models/llama/checkpoint.py b/llama_stack/models/llama/checkpoint.py
new file mode 100644
index 000000000..2bae08a69
--- /dev/null
+++ b/llama_stack/models/llama/checkpoint.py
@@ -0,0 +1,164 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import concurrent.futures
+import re
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from fairscale.nn.model_parallel.initialize import get_model_parallel_rank, get_model_parallel_world_size
+
+
+def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[int]:
+ """Map a new MP rank to a list of old MP ranks given a change in MP size."""
+ if new_mp_size % old_mp_size == 0:
+ # Read old MP shard and split it into smaller ones
+ return [new_mp_rank * old_mp_size // new_mp_size]
+ elif old_mp_size % new_mp_size == 0:
+ # Merge old MP shards into a single one
+ mp_factor = old_mp_size // new_mp_size
+ return list(range(new_mp_rank * mp_factor, (new_mp_rank + 1) * mp_factor))
+ else:
+ raise ValueError(
+ f"Either old MP size or new MP size should be a multiple of the other: "
+ f"{old_mp_size} % {new_mp_size} != 0 and {new_mp_size} % {old_mp_size} != 0"
+ )
+
+
+def maybe_reshard_state_dict(
+ ckpt_paths: List[Path],
+ n_kv_heads: int,
+ moe_num_experts: Optional[int] = None,
+ map_location: Union[str, torch.device] = "cpu",
+ mmap: bool = True,
+) -> Dict[str, torch.Tensor]:
+ if str(map_location) == "cpu":
+ torch.set_default_tensor_type(torch.BFloat16Tensor)
+ else:
+ torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
+
+ ckpt_paths = np.array(sorted(ckpt_paths))
+
+ new_mp_size, new_mp_rank = get_model_parallel_world_size(), get_model_parallel_rank()
+ old_mp_size = len(ckpt_paths)
+ old_mp_ranks = map_mp_rank(old_mp_size, new_mp_size, new_mp_rank)
+
+ print(f"Loading checkpoint shards:\n{str(ckpt_paths[old_mp_ranks])}") # type: ignore
+ paths = ckpt_paths[old_mp_ranks] # type: ignore
+ state_dicts = [torch.load(str(p), map_location=map_location, mmap=mmap) for p in paths]
+
+ if new_mp_size == old_mp_size:
+ return state_dicts[0] # type: ignore
+
+ if moe_num_experts is not None:
+ state_dicts = [convert_moe_weights(d, moe_num_experts) for d in state_dicts]
+
+ print(f"Resharding {len(state_dicts)} state dicts from MP size {old_mp_size} to MP size {new_mp_size}")
+ return reshard_mp(
+ state_dicts,
+ size=max(new_mp_size // old_mp_size, 1),
+ rank=new_mp_rank % max(new_mp_size // old_mp_size, 1),
+ repeat_qk_qv=max(new_mp_size // n_kv_heads, 1),
+ )
+
+
+_WEIGHT_ROW_KEY = {
+ "feed_forward.w2",
+ "feed_forward.mlp.fc2",
+ "attention.wo",
+ "feed_forward.mlp.fc2_weight",
+ "feed_forward.w_out_shared_DF.weight",
+ "attn.wo.weight",
+ "mlp.c_proj.weight",
+}
+_MOE_WEIGHT_ROW_KEY = {"feed_forward.experts.(moe_w_in_eD_F|moe_w_swiglu_eD_F)"}
+
+_WEIGHT_COLUMN_KEY = {
+ "output",
+ "feed_forward.(w1|w3)",
+ "feed_forward.mlp.(fc1|fc3)",
+ "feed_forward.mlp.fc1_weight",
+ "attention.(wk|wq|wv|wqkv).weight",
+ "feed_forward.(w_in_shared_FD|w_swiglu_FD)",
+ "attn.(wk|wq|wv).weight",
+ "attn.(wk|wq|wv).bias",
+ "mlp.c_fc.weight",
+ "mlp.c_fc.bias",
+ "conv1._linear.weight",
+ "tok_embeddings.weight",
+ "vision_projection.weight",
+}
+_MOE_WEIGHT_COLUMN_KEY = {"feed_forward.experts.moe_w_out_eF_D"}
+
+
+def reshard_mp(
+ state_dicts: List[Dict[str, torch.Tensor]],
+ size: int,
+ rank: int,
+ repeat_qk_qv: int = 1,
+) -> Dict[str, torch.Tensor]:
+ """
+ Reshard a list of state dicts into a single state dict given a change in MP size.
+ If the list has more than one state dict, we concatenate the values of the same
+ key across all state dicts. Otherwise, we just slice it for the current MP rank.
+ """
+
+ def concat_or_chunk(tensors: List[torch.Tensor], dim: int) -> torch.Tensor:
+ if len(tensors) > 1:
+ return torch.cat(tensors, dim=dim)
+ return tensors[0].chunk(size, dim=dim)[rank].clone()
+
+ def process_key(key: str) -> torch.Tensor:
+ if row_regex.search(key):
+ return concat_or_chunk([s[key] for s in state_dicts], dim=-1)
+ elif column_regex.search(key):
+ if "w13" in key or "fc1_weight" in key:
+ dims = state_dicts[0][key].size()
+ values = [s[key].view(2, dims[0] // 2, *dims[1:]) for s in state_dicts]
+ return concat_or_chunk(values, dim=1).flatten(0, 1)
+ elif "qkv" in key:
+ q_dim = state_dicts[0][key.replace("qkv", "o")].size(1)
+ kv_dim = (state_dicts[0][key].size(0) - q_dim) // 2
+ values = [s[key].split((q_dim, kv_dim, kv_dim)) for s in state_dicts]
+ return torch.cat([concat_or_chunk(x, dim=0) for x in zip(*values, strict=False)]) # type: ignore
+ elif "wk.weight" in key or "wv.weight" in key:
+ # Support MP > #kv_head
+ return concat_or_chunk([s[key].repeat(repeat_qk_qv, 1) for s in state_dicts], dim=0)
+ elif key == "output.bias" or key == "fc.weight":
+ return concat_or_chunk([s[key] for s in state_dicts], dim=0)
+ elif "w_" in key:
+ return concat_or_chunk([s[key] for s in state_dicts], dim=-2)
+ else:
+ return concat_or_chunk([s[key] for s in state_dicts], dim=0)
+ else:
+ return state_dicts[0][key].clone()
+
+ row_keys = _WEIGHT_ROW_KEY | _MOE_WEIGHT_ROW_KEY
+ column_keys = _WEIGHT_COLUMN_KEY | _MOE_WEIGHT_COLUMN_KEY
+
+ column_regex = re.compile("|".join(column_keys))
+ row_regex = re.compile("|".join(row_keys))
+
+ output: Dict[str, torch.Tensor] = {}
+ with concurrent.futures.ThreadPoolExecutor() as executor:
+ # Note: only processes keys in the first state dict.
+ # Assumes keys are the same across all state dicts.
+ mappings = {executor.submit(process_key, key): key for key in state_dicts[0]}
+ for future in concurrent.futures.as_completed(mappings):
+ output[mappings[future]] = future.result()
+ return output
+
+
+def convert_moe_weights(state_dict: Dict[str, Any], num_experts: int) -> Dict[str, Any]:
+ routed_keys = _MOE_WEIGHT_ROW_KEY | _MOE_WEIGHT_COLUMN_KEY
+ routed_regex = re.compile("|".join(routed_keys))
+ keys = list(state_dict.keys())
+ for key in keys:
+ if routed_regex.search(key):
+ state_dict[key] = state_dict.pop(key).unflatten(0, (num_experts, -1)).squeeze(dim=0)
+ return state_dict
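
A quick worked example of how `map_mp_rank` assigns old checkpoint shards to new model-parallel ranks (importing `checkpoint.py` requires `torch` and `fairscale`, since they are pulled in at module level); the sizes below are illustrative:

from llama_stack.models.llama.checkpoint import map_mp_rank

# splitting: 2 old shards resharded across 4 ranks -- each new rank reads one old shard
assert [map_mp_rank(2, 4, rank) for rank in range(4)] == [[0], [0], [1], [1]]

# merging: 4 old shards resharded across 2 ranks -- each new rank concatenates two old shards
assert [map_mp_rank(4, 2, rank) for rank in range(2)] == [[0, 1], [2, 3]]

# anything else (e.g. 3 -> 2) raises ValueError, since neither size divides the other
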
diff --git a/llama_stack/models/llama/datatypes.py b/llama_stack/models/llama/datatypes.py
index ef791da8f..106875bb2 100644
--- a/llama_stack/models/llama/datatypes.py
+++ b/llama_stack/models/llama/datatypes.py
@@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
-
import base64
from enum import Enum
from io import BytesIO
@@ -19,8 +12,6 @@ from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
from typing_extensions import Annotated
-from llama_stack.schema_utils import json_schema_type, register_schema
-
# The goal is that these set of types are relevant for all Llama models.
# That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to
# the llama3 series of models.
@@ -47,14 +38,7 @@ RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
class ToolCall(BaseModel):
call_id: str
tool_name: Union[BuiltinTool, str]
- # Plan is to deprecate the Dict in favor of a JSON string
- # that is parsed on the client side instead of trying to manage
- # the recursive type here.
- # Making this a union so that client side can start prepping for this change.
- # Eventually, we will remove both the Dict and arguments_json field,
- # and arguments will just be a str
- arguments: Union[str, Dict[str, RecursiveType]]
- arguments_json: Optional[str] = None
+ arguments: Dict[str, RecursiveType]
@field_validator("tool_name", mode="before")
@classmethod
@@ -98,6 +82,29 @@ class StopReason(Enum):
out_of_tokens = "out_of_tokens"
+class ToolParamDefinition(BaseModel):
+ param_type: str
+ description: Optional[str] = None
+ required: Optional[bool] = True
+ default: Optional[Any] = None
+
+
+class ToolDefinition(BaseModel):
+ tool_name: Union[BuiltinTool, str]
+ description: Optional[str] = None
+ parameters: Optional[Dict[str, ToolParamDefinition]] = None
+
+ @field_validator("tool_name", mode="before")
+ @classmethod
+ def validate_field(cls, v):
+ if isinstance(v, str):
+ try:
+ return BuiltinTool(v)
+ except ValueError:
+ return v
+ return v
+
+
class RawMediaItem(BaseModel):
type: Literal["image"] = "image"
data: bytes | BytesIO
@@ -140,292 +147,25 @@ class RawMessage(BaseModel):
tool_calls: List[ToolCall] = Field(default_factory=list)
-register_schema(ToolCall)
+class GenerationResult(BaseModel):
+ token: int
+ text: str
+ logprobs: Optional[List[float]] = None
+
+ source: Literal["input"] | Literal["output"]
+
+ # index within the batch
+ batch_idx: int
+    # Whether generation for this item has finished. Note that tokens can still be
+    # returned afterwards, since other items in the batch may still be generating.
+ finished: bool
+    # Because a batch is processed in parallel, a decoding step for one item can correspond to
+    # pad tokens or tokens beyond EOS for other items. We could have returned None for such tokens,
+    # but it is more convenient to emit a GenerationResult and let callers filter out the ignored ones.
+ ignore_token: bool
-@json_schema_type
-class ToolParamDefinition(BaseModel):
- param_type: str
- description: Optional[str] = None
- required: Optional[bool] = True
- default: Optional[Any] = None
-
-
-@json_schema_type
-class ToolDefinition(BaseModel):
- tool_name: Union[BuiltinTool, str]
- description: Optional[str] = None
- parameters: Optional[Dict[str, ToolParamDefinition]] = None
-
- @field_validator("tool_name", mode="before")
- @classmethod
- def validate_field(cls, v):
- if isinstance(v, str):
- try:
- return BuiltinTool(v)
- except ValueError:
- return v
- return v
-
-
-@json_schema_type
-class GreedySamplingStrategy(BaseModel):
- type: Literal["greedy"] = "greedy"
-
-
-@json_schema_type
-class TopPSamplingStrategy(BaseModel):
- type: Literal["top_p"] = "top_p"
- temperature: Optional[float] = Field(..., gt=0.0)
- top_p: Optional[float] = 0.95
-
-
-@json_schema_type
-class TopKSamplingStrategy(BaseModel):
- type: Literal["top_k"] = "top_k"
- top_k: int = Field(..., ge=1)
-
-
-SamplingStrategy = Annotated[
- Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
- Field(discriminator="type"),
-]
-register_schema(SamplingStrategy, name="SamplingStrategy")
-
-
-@json_schema_type
-class SamplingParams(BaseModel):
- """Sampling parameters.
-
- :param strategy: The sampling strategy.
- :param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
- your prompt plus max_tokens cannot exceed the model's context length.
- :param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
- based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
- :param stop: Up to 4 sequences where the API will stop generating further tokens.
- The returned text will not contain the stop sequence.
- """
-
- strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
-
- max_tokens: Optional[int] = 0
- repetition_penalty: Optional[float] = 1.0
- stop: Optional[List[str]] = None
-
-
-class CheckpointQuantizationFormat(Enum):
- # default format
- bf16 = "bf16"
-
- # used for enabling fp8_rowwise inference, some weights are bf16
- fp8_mixed = "fp8-mixed"
-
- int8 = "int8"
-
- int4 = "int4"
-
-
-class ModelFamily(Enum):
- llama2 = "llama2"
- llama3 = "llama3"
- llama3_1 = "llama3_1"
- llama3_2 = "llama3_2"
- llama3_3 = "llama3_3"
- llama4 = "llama4"
- safety = "safety"
-
-
-class CoreModelId(Enum):
- """Each of these models is a unique "SKU". These root models can be served in various garbs (especially by quantizing them)"""
-
- # Llama 2 family
- llama2_7b = "Llama-2-7b"
- llama2_13b = "Llama-2-13b"
- llama2_70b = "Llama-2-70b"
- llama2_7b_chat = "Llama-2-7b-chat"
- llama2_13b_chat = "Llama-2-13b-chat"
- llama2_70b_chat = "Llama-2-70b-chat"
-
- # Llama 3 family
- llama3_8b = "Llama-3-8B"
- llama3_70b = "Llama-3-70B"
- llama3_8b_instruct = "Llama-3-8B-Instruct"
- llama3_70b_instruct = "Llama-3-70B-Instruct"
-
- # Llama 3.1 family
- llama3_1_8b = "Llama3.1-8B"
- llama3_1_70b = "Llama3.1-70B"
- llama3_1_405b = "Llama3.1-405B"
- llama3_1_8b_instruct = "Llama3.1-8B-Instruct"
- llama3_1_70b_instruct = "Llama3.1-70B-Instruct"
- llama3_1_405b_instruct = "Llama3.1-405B-Instruct"
-
- # Llama 3.2 family
- llama3_2_1b = "Llama3.2-1B"
- llama3_2_3b = "Llama3.2-3B"
- llama3_2_1b_instruct = "Llama3.2-1B-Instruct"
- llama3_2_3b_instruct = "Llama3.2-3B-Instruct"
- llama3_2_11b_vision = "Llama3.2-11B-Vision"
- llama3_2_90b_vision = "Llama3.2-90B-Vision"
- llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct"
- llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct"
-
- # Llama 3.3 family
- llama3_3_70b_instruct = "Llama3.3-70B-Instruct"
-
- # Llama 4 family
- llama4_scout_17b_16e = "Llama-4-Scout-17B-16E"
- llama4_scout_17b_16e_instruct = "Llama-4-Scout-17B-16E-Instruct"
- llama4_maverick_17b_128e = "Llama-4-Maverick-17B-128E"
- llama4_maverick_17b_128e_instruct = "Llama-4-Maverick-17B-128E-Instruct"
-
- # Safety models
- llama_guard_3_8b = "Llama-Guard-3-8B"
- llama_guard_2_8b = "Llama-Guard-2-8B"
- llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision"
- llama_guard_3_1b = "Llama-Guard-3-1B"
-
-
-def is_multimodal(model_id) -> bool:
- if model_id in [
- CoreModelId.llama3_2_11b_vision,
- CoreModelId.llama3_2_90b_vision,
- CoreModelId.llama3_2_11b_vision_instruct,
- CoreModelId.llama3_2_90b_vision_instruct,
- ]:
- return True
- else:
- return False
-
-
-def model_family(model_id) -> ModelFamily:
- if model_id in [
- CoreModelId.llama2_7b,
- CoreModelId.llama2_13b,
- CoreModelId.llama2_70b,
- CoreModelId.llama2_7b_chat,
- CoreModelId.llama2_13b_chat,
- CoreModelId.llama2_70b_chat,
- ]:
- return ModelFamily.llama2
- elif model_id in [
- CoreModelId.llama3_8b,
- CoreModelId.llama3_70b,
- CoreModelId.llama3_8b_instruct,
- CoreModelId.llama3_70b_instruct,
- ]:
- return ModelFamily.llama3
- elif model_id in [
- CoreModelId.llama3_1_8b,
- CoreModelId.llama3_1_70b,
- CoreModelId.llama3_1_405b,
- CoreModelId.llama3_1_8b_instruct,
- CoreModelId.llama3_1_70b_instruct,
- CoreModelId.llama3_1_405b_instruct,
- ]:
- return ModelFamily.llama3_1
- elif model_id in [
- CoreModelId.llama3_2_1b,
- CoreModelId.llama3_2_3b,
- CoreModelId.llama3_2_1b_instruct,
- CoreModelId.llama3_2_3b_instruct,
- CoreModelId.llama3_2_11b_vision,
- CoreModelId.llama3_2_90b_vision,
- CoreModelId.llama3_2_11b_vision_instruct,
- CoreModelId.llama3_2_90b_vision_instruct,
- ]:
- return ModelFamily.llama3_2
- elif model_id in [
- CoreModelId.llama3_3_70b_instruct,
- ]:
- return ModelFamily.llama3_3
- elif model_id in [
- CoreModelId.llama4_scout_17b_16e,
- CoreModelId.llama4_scout_17b_16e_instruct,
- CoreModelId.llama4_maverick_17b_128e,
- CoreModelId.llama4_maverick_17b_128e_instruct,
- ]:
- return ModelFamily.llama4
- elif model_id in [
- CoreModelId.llama_guard_3_8b,
- CoreModelId.llama_guard_2_8b,
- CoreModelId.llama_guard_3_11b_vision,
- CoreModelId.llama_guard_3_1b,
- ]:
- return ModelFamily.safety
- else:
- raise ValueError(f"Unknown model family for {model_id}")
-
-
-class Model(BaseModel):
- core_model_id: CoreModelId
- description: str
- huggingface_repo: Optional[str] = None
- recommended_sampling_params: Optional[SamplingParams] = None
- arch_args: Dict[str, Any]
- variant: str = ""
-
- quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
- pth_file_count: int
- metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
-
- # silence pydantic until we remove the `model_` fields
- model_config = ConfigDict(protected_namespaces=())
-
- @property
- def model_family(self) -> ModelFamily:
- return model_family(self.core_model_id)
-
- # The SKU is uniquely identified by (model_id, variant) combo
- def descriptor(self, shorten_default_variant: bool = True) -> str:
- if not self.variant:
- return self.core_model_id.value
- return f"{self.core_model_id.value}:{self.variant}"
-
- @property
- def is_instruct_model(self) -> bool:
- return "instruct" in self.id.name
-
- # Featured models are shown in the non-exhaustive model list
- @property
- def is_featured(self) -> bool:
- return self.model_family in [
- ModelFamily.llama3_1,
- ModelFamily.llama3_2,
- ModelFamily.llama3_3,
- ModelFamily.llama4,
- ModelFamily.safety,
- ]
-
- @property
- def max_seq_length(self) -> int:
- if self.model_family == ModelFamily.llama2:
- return 4096
- elif self.core_model_id == CoreModelId.llama_guard_2_8b:
- return 4096
- elif self.model_family == ModelFamily.llama3:
- return 8192
- elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]:
- return 131072
- elif self.model_family == ModelFamily.llama3_2:
- if self.quantization_format == CheckpointQuantizationFormat.int4:
- return 8192
- return 131072
- elif self.model_family == ModelFamily.llama4:
- if self.core_model_id in {
- CoreModelId.llama4_scout_17b_16e,
- CoreModelId.llama4_maverick_17b_128e,
- }:
- return 262144
- if self.core_model_id == CoreModelId.llama4_scout_17b_16e_instruct:
- return 10485760
- if self.core_model_id == CoreModelId.llama4_maverick_17b_128e_instruct:
- return 1048576
- elif self.core_model_id in [
- CoreModelId.llama_guard_3_8b,
- CoreModelId.llama_guard_3_11b_vision,
- CoreModelId.llama_guard_3_1b,
- ]:
- return 131072
- else:
- raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")
+class QuantizationMode(str, Enum):
+ none = "none"
+ fp8_mixed = "fp8_mixed"
+ int4_mixed = "int4_mixed"
diff --git a/llama_stack/models/llama/llama3/args.py b/llama_stack/models/llama/llama3/args.py
index e96eaca61..f7e4b4557 100644
--- a/llama_stack/models/llama/llama3/args.py
+++ b/llama_stack/models/llama/llama3/args.py
@@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
-
from dataclasses import dataclass
from enum import Enum
from typing import Optional
diff --git a/llama_stack/models/llama/llama3/chat_format.py b/llama_stack/models/llama/llama3/chat_format.py
index 8ae911fc3..f55cd5e1c 100644
--- a/llama_stack/models/llama/llama3/chat_format.py
+++ b/llama_stack/models/llama/llama3/chat_format.py
@@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
-
import io
import json
import uuid
diff --git a/llama_stack/models/llama/llama3/generation.py b/llama_stack/models/llama/llama3/generation.py
index b4e0d39b9..ee99a07ba 100644
--- a/llama_stack/models/llama/llama3/generation.py
+++ b/llama_stack/models/llama/llama3/generation.py
@@ -4,59 +4,37 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# top-level folder for each specific model found within the models/ directory at
+# the top-level of this source tree.
import json
import os
import sys
import time
-from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Generator, List, Optional
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
- get_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from termcolor import cprint
-from ..datatypes import RawContent, RawMessage, StopReason, ToolPromptFormat
+from ..checkpoint import maybe_reshard_state_dict
+from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat
from .args import ModelArgs
from .chat_format import ChatFormat, LLMInput
from .model import Transformer
+from .multimodal.model import CrossAttentionTransformer
from .tokenizer import Tokenizer
-@dataclass
-class CompletionPrediction:
- generation: str
- decoded_tokens: Optional[List[str]] = None
- logprobs: Optional[List[List[float]]] = None
-
-
-@dataclass
-class ChatPrediction:
- generation: RawMessage
- decoded_tokens: Optional[List[str]] = None
- logprobs: Optional[List[List[float]]] = None
-
-
-@dataclass
-class TokenResult:
- token: int
- text: str
- logprobs: Optional[List[float]] = None
-
-
-# TODO: make this completely parallel to the llama4 generation.py file and share common code
-# from llama-models also
class Llama3:
@staticmethod
def build(
@@ -64,7 +42,7 @@ class Llama3:
max_seq_len: int,
max_batch_size: int,
world_size: Optional[int] = None,
- tokenizer_path: Optional[str] = None,
+ quantization_mode: Optional[QuantizationMode] = None,
seed: int = 1,
device: str = "cuda",
):
@@ -101,13 +79,9 @@ class Llama3:
start_time = time.time()
- checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
- assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
- assert world_size == len(checkpoints), (
- f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
- )
- ckpt_path = checkpoints[get_model_parallel_rank()]
- checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
+ ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
+ assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
+ print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
@@ -116,40 +90,58 @@ class Llama3:
max_batch_size=max_batch_size,
**params,
)
- if tokenizer_path:
- tokenizer = Tokenizer(model_path=tokenizer_path)
- else:
- tokenizer = Tokenizer.get_instance()
+ tokenizer = Tokenizer.get_instance()
+
+ state_dict = maybe_reshard_state_dict(
+ ckpt_paths,
+ n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
+ )
assert model_args.vocab_size == tokenizer.n_words
- torch.set_default_device(device)
- if device.type == "cuda":
- if torch.cuda.is_bf16_supported():
- torch.set_default_dtype(torch.bfloat16)
- else:
- torch.set_default_dtype(torch.half)
- elif device.type == "xpu":
- if torch.xpu.is_bf16_supported():
- torch.set_default_dtype(torch.bfloat16)
- else:
- torch.set_default_dtype(torch.half)
- else:
- torch.set_default_dtype(torch.half)
- if model_args.vision_chunk_size > 0:
- from .multimodal.model import CrossAttentionTransformer
+ def build_model():
+ if model_args.vision_chunk_size > 0:
+ model = CrossAttentionTransformer(model_args)
+ model.setup_cache(model_args.max_batch_size, device=device, dtype=torch.get_default_dtype())
+ else:
+ model = Transformer(model_args)
+ return model
- model = CrossAttentionTransformer(model_args)
- model.setup_cache(model_args.max_batch_size, torch.get_default_dtype())
+ if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
+ from .quantization.loader import convert_to_quantized_model
+
+ torch.set_default_tensor_type(torch.BFloat16Tensor)
+ model = build_model()
+ print("Loading state dict...")
+ model.load_state_dict(state_dict, strict=False)
+ print("Done...")
+ model = convert_to_quantized_model(model, ckpt_dir, quantization_mode, device=device)
+ torch.set_default_device(device)
else:
- model = Transformer(model_args)
- model.load_state_dict(checkpoint, strict=True)
- model.to(device)
+ print(f"Setting default device to {device}")
+ torch.set_default_device(device)
+ if device.type == "cuda":
+ if torch.cuda.is_bf16_supported():
+ torch.set_default_dtype(torch.bfloat16)
+ else:
+ torch.set_default_dtype(torch.half)
+ elif device.type == "xpu":
+ if torch.xpu.is_bf16_supported():
+ torch.set_default_dtype(torch.bfloat16)
+ else:
+ torch.set_default_dtype(torch.half)
+
+ model = build_model()
+ print("Loading state dict...")
+ model.load_state_dict(state_dict, strict=True)
+ model.to(device)
+ print("Done...")
+
print(f"Loaded in {time.time() - start_time:.2f} seconds")
- return Llama(model, tokenizer, model_args)
+ return Llama3(model, tokenizer, model_args)
- def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
+ def __init__(self, model: Transformer | CrossAttentionTransformer, tokenizer: Tokenizer, args: ModelArgs):
self.args = args
self.model = model
self.tokenizer = tokenizer
@@ -158,26 +150,30 @@ class Llama3:
@torch.inference_mode()
def generate(
self,
- model_input: LLMInput,
- max_gen_len: int,
+ model_inputs: List[LLMInput],
temperature: float = 0.6,
top_p: float = 0.9,
+ max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
print_model_input: bool = False,
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
- ) -> Generator:
+ ) -> Generator[List[GenerationResult], None, None]:
+ if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
+ max_gen_len = self.args.max_seq_len - 1
params = self.model.params
+ print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
if print_model_input:
- tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in model_input.tokens]
- cprint(
- "Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
- "red",
- )
- prompt_tokens = [model_input.tokens]
+ for inp in model_inputs:
+ tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
+ cprint(
+ "Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
+ "red",
+ )
+ prompt_tokens = [inp.tokens for inp in model_inputs]
- bsz = 1
+ bsz = len(model_inputs)
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)
@@ -189,18 +185,6 @@ class Llama3:
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
- is_vision = not isinstance(self.model, Transformer)
- if is_vision:
- images = model_input.vision.images if model_input.vision is not None else []
- mask = model_input.vision.mask if model_input.vision is not None else []
-
- # the method works for bsz > 1 so add a batch dimension
- xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
- batch_images=[images],
- batch_masks=[mask],
- total_len=total_len,
- )
-
pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
for k, t in enumerate(prompt_tokens):
@@ -208,23 +192,45 @@ class Llama3:
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
- prev_pos = 0
+ is_vision = not isinstance(self.model, Transformer)
+ if is_vision:
+ images = [inp.vision.images if inp.vision is not None else [] for inp in model_inputs]
+ mask = [inp.vision.mask if inp.vision is not None else [] for inp in model_inputs]
+
+ xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
+ batch_images=images,
+ batch_masks=mask,
+ total_len=total_len,
+ device=tokens.device,
+ )
+
eos_reached = torch.tensor([False] * bsz)
input_text_mask = tokens != pad_id
if echo:
- for i, t in enumerate(model_input.tokens):
- yield TokenResult(
- token=t,
- text=self.tokenizer.decode([t]),
- logprobs=(token_logprobs[0, i : i + 1].tolist() if logprobs else None),
- )
+ for i in range(max_prompt_len):
+ results = []
+ for j, t in enumerate(tokens[:, i]):
+ results.append(
+ GenerationResult(
+ token=t.item(),
+ text=self.tokenizer.decode([t.item()]),
+ source="input",
+ logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
+ batch_idx=j,
+ finished=False,
+ ignore_token=t.item() == pad_id,
+ )
+ )
+ yield results
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
+
+ prev_pos = 0
for cur_pos in range(min_prompt_len, total_len):
if is_vision:
position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
- text_only_inference = model_input.vision is None
+ text_only_inference = all(inp.vision is None for inp in model_inputs)
logits = self.model.forward(
position_ids,
tokens,
@@ -271,155 +277,69 @@ class Llama3:
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
- yield TokenResult(
- token=next_token[0].item(),
- text=self.tokenizer.decode(next_token.tolist()),
- logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None),
- )
+ results = []
+ for idx, t in enumerate(next_token):
+ results.append(
+ GenerationResult(
+ token=t.item(),
+ text=self.tokenizer.decode([t.item()]),
+ source="output",
+ logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
+ batch_idx=idx,
+ finished=eos_reached[idx],
+ ignore_token=cur_pos < len(prompt_tokens[idx]),
+ )
+ )
+ yield results
prev_pos = cur_pos
if all(eos_reached):
break
- def text_completion(
+ def completion(
self,
- content: RawContent,
+ contents: List[RawContent],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
- ) -> CompletionPrediction:
- if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
- max_gen_len = self.model.params.max_seq_len - 1
-
- model_input = self.formatter.encode_content(content)
-
- tokens = []
- token_logprobs = []
- decoded_tokens = []
+ ) -> Generator[List[GenerationResult], None, None]:
+ model_inputs = [self.formatter.encode_content(c) for c in contents]
for result in self.generate(
- model_input=model_input,
- max_gen_len=max_gen_len,
+ model_inputs=model_inputs,
temperature=temperature,
top_p=top_p,
+ max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
- tokens.append(result.token)
- if logprobs:
- decoded_tokens.append(result.text)
- token_logprobs.append(result.logprobs)
-
- generation = self.tokenizer.decode(tokens)
- if logprobs:
- return CompletionPrediction(
- generation=generation,
- logprobs=token_logprobs,
- decoded_tokens=decoded_tokens,
- )
-
- return CompletionPrediction(generation=generation)
+ yield result
+ if all(r.finished for r in result):
+ break
def chat_completion(
self,
- messages: List[RawMessage],
+ messages_batch: List[List[RawMessage]],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
echo: bool = False,
- ) -> ChatPrediction:
- if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
- max_gen_len = self.model.params.max_seq_len - 1
-
- tokens = []
- token_logprobs = []
- decoded_tokens = []
-
- stop_reason = None
+ ) -> Generator[List[GenerationResult], None, None]:
+ model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
for result in self.generate(
- model_input=self.formatter.encode_dialog_prompt(messages, tool_prompt_format),
- max_gen_len=max_gen_len,
+ model_inputs=model_inputs,
temperature=temperature,
top_p=top_p,
+ max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
- tokens.append(result.token)
- if result.text == "<|eot_id|>":
- stop_reason = StopReason.end_of_turn
- elif result.text == "<|eom_id|>":
- stop_reason = StopReason.end_of_message
-
- if logprobs:
- decoded_tokens.append(result.text)
- token_logprobs.append(result.logprobs)
-
- if stop_reason is None:
- stop_reason = StopReason.out_of_tokens
-
- message = self.formatter.decode_assistant_message(tokens, stop_reason)
-
- if logprobs:
- return ChatPrediction(
- generation=message,
- logprobs=token_logprobs,
- decoded_tokens=decoded_tokens,
- )
-
- return ChatPrediction(generation=message)
-
- def chat_completion_raw(
- self,
- messages: List[RawMessage],
- temperature: float = 0.6,
- top_p: float = 0.9,
- max_gen_len: Optional[int] = None,
- tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
- ) -> List[int]:
- if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
- max_gen_len = self.model.params.max_seq_len - 1
-
- output_tokens = []
- model_input = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
- input_tokens = model_input.tokens
- for result in self.generate(
- model_input=model_input,
- max_gen_len=max_gen_len,
- temperature=temperature,
- top_p=top_p,
- logprobs=False,
- ):
- output_tokens.append(result.token)
-
- return input_tokens, output_tokens
-
- def text_completion_raw(
- self,
- content: RawContent,
- temperature: float = 0.6,
- top_p: float = 0.9,
- max_gen_len: Optional[int] = None,
- ):
- if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
- max_gen_len = self.model.params.max_seq_len - 1
-
- model_input = self.formatter.encode_content(content)
- input_tokens = model_input.tokens
-
- output_tokens = []
- for result in self.generate(
- model_input=model_input,
- max_gen_len=max_gen_len,
- temperature=temperature,
- top_p=top_p,
- logprobs=False,
- ):
- output_tokens.append(result.token)
-
- return input_tokens, output_tokens
+ yield result
+ if all(r.finished for r in result):
+ break
def sample_top_p(probs, p):
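
The generator-based API above streams one `GenerationResult` per batch item per step. A consumption sketch, assuming an already built `Llama3` instance and that `RawMessage` accepts a plain string as content; illustrative, not a tested call:

from llama_stack.models.llama.datatypes import RawMessage

# llama = Llama3.build(ckpt_dir="...", max_seq_len=512, max_batch_size=2)  # assumed to exist
dialogs = [
    [RawMessage(role="user", content="Say hello in one word.")],
    [RawMessage(role="user", content="Name a prime number.")],
]
outputs = ["" for _ in dialogs]
for results in llama.chat_completion(dialogs, temperature=0.6, max_gen_len=32):
    for r in results:                       # one GenerationResult per batch item
        if r.ignore_token or r.finished:    # skip padding / EOS bookkeeping tokens
            continue
        outputs[r.batch_idx] += r.text
print(outputs)
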
diff --git a/llama_stack/models/llama/llama3/model.py b/llama_stack/models/llama/llama3/model.py
index a49167980..2562673e2 100644
--- a/llama_stack/models/llama/llama3/model.py
+++ b/llama_stack/models/llama/llama3/model.py
@@ -4,16 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
-
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
-
import math
from typing import Optional, Tuple
@@ -29,6 +19,10 @@ from torch import nn
from .args import ModelArgs
+# **NOTE**: This code is not runnable without installing `torch` and `fairscale`
+# dependencies. These dependencies are not part of the default dependencies
+# (requirements.txt) of the `llama-models` package.
+
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
@@ -111,9 +105,9 @@ class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
- model_parallel_size = fs_init.get_model_parallel_world_size()
- self.n_local_heads = args.n_heads // model_parallel_size
- self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
+ world_size = fs_init.get_model_parallel_world_size()
+ self.n_local_heads = args.n_heads // world_size
+ self.n_local_kv_heads = self.n_kv_heads // world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
diff --git a/llama_stack/models/llama/llama3/multimodal/model.py b/llama_stack/models/llama/llama3/multimodal/model.py
index 3d0d77c87..0cb18b948 100644
--- a/llama_stack/models/llama/llama3/multimodal/model.py
+++ b/llama_stack/models/llama/llama3/multimodal/model.py
@@ -4,16 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
-
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-
import logging
import math
from functools import partial
@@ -180,14 +170,14 @@ class ImageAttention(nn.Module):
n_heads,
):
super().__init__()
- model_parallel_size = fs_init.get_model_parallel_world_size()
+ world_size = fs_init.get_model_parallel_world_size()
qkvo_replication = 1
- if model_parallel_size > 16:
- qkvo_replication = model_parallel_size // 8
+ if world_size > 16:
+ qkvo_replication = world_size // 8
self.n_kv_heads = n_heads
- self.n_local_heads = n_heads * qkvo_replication // model_parallel_size
- self.n_local_kv_heads = self.n_kv_heads * qkvo_replication // model_parallel_size
+ self.n_local_heads = n_heads * qkvo_replication // world_size
+ self.n_local_kv_heads = self.n_kv_heads * qkvo_replication // world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = dim // n_heads
@@ -536,16 +526,16 @@ class Attention(nn.Module):
cache_v (torch.Tensor): Cached values for attention.
"""
super().__init__()
- model_parallel_size = fs_init.get_model_parallel_world_size()
+ world_size = fs_init.get_model_parallel_world_size()
replication_factor = 1
- if model_parallel_size > 8:
- replication_factor = model_parallel_size // MP_SCALE
+ if world_size > 8:
+ replication_factor = world_size // MP_SCALE
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
self.n_kv_heads *= replication_factor
- self.n_local_heads = args.n_heads // model_parallel_size
- self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
+ self.n_local_heads = args.n_heads // world_size
+ self.n_local_kv_heads = self.n_kv_heads // world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.max_seq_len = args.max_seq_len
@@ -587,13 +577,11 @@ class Attention(nn.Module):
self.n_local_kv_heads,
self.head_dim,
)
- device = next(self.parameters()).device
self.register_buffer(
"key_cache",
torch.zeros(
cache_shape,
dtype=dtype,
- device=device,
),
persistent=False,
)
@@ -602,7 +590,6 @@ class Attention(nn.Module):
torch.zeros(
cache_shape,
dtype=dtype,
- device=device,
),
persistent=False,
)
@@ -614,6 +601,9 @@ class Attention(nn.Module):
freqs_cis: torch.Tensor,
position_ids: torch.LongTensor,
):
+ self.key_cache = self.key_cache.to(x.device)
+ self.value_cache = self.value_cache.to(x.device)
+
xq, xk, xv = [F.linear(x, w) for w in [self.wq.weight, self.wk.weight, self.wv.weight]]
bs, slen, _ = xq.shape
@@ -832,10 +822,10 @@ class CrossAttention(torch.nn.Module):
norm_eps: float,
):
super().__init__()
- self.model_parallel_size = fs_init.get_model_parallel_world_size()
+ self.world_size = fs_init.get_model_parallel_world_size()
replication_factor = 1
- if self.model_parallel_size > 8:
- replication_factor = self.model_parallel_size // MP_SCALE
+ if self.world_size > 8:
+ replication_factor = self.world_size // MP_SCALE
n_kv_heads *= replication_factor
assert n_heads % n_kv_heads == 0
@@ -889,10 +879,10 @@ class CrossAttention(torch.nn.Module):
# trunk LLM (i.e., group query attention) -- @dubeya
# local heads
assert self.n_heads % self.n_kv_heads == 0
- assert self.n_heads % self.model_parallel_size == 0
- assert self.n_kv_heads % self.model_parallel_size == 0
- self.n_local_heads = self.n_heads // self.model_parallel_size
- self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size
+ assert self.n_heads % self.world_size == 0
+ assert self.n_kv_heads % self.world_size == 0
+ self.n_local_heads = self.n_heads // self.world_size
+ self.n_local_kv_heads = self.n_kv_heads // self.world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
def _compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor:
@@ -1041,7 +1031,7 @@ class CrossAttentionTransformerVision(torch.nn.Module):
self.image_res = args.vision_chunk_size
self.max_num_chunks = args.vision_max_num_chunks
if return_intermediate is not None:
- return_intermediate = [int(level) for level in return_intermediate.split(",")]
+ return_intermediate = [int(layer) for layer in return_intermediate.split(",")]
self.vision_input_dim = (len(return_intermediate) + 1) * self.vision_input_dim
self.patch_size = 14
self.vision_encoder = VisionEncoder(
@@ -1076,15 +1066,15 @@ class CrossAttentionTransformerText(torch.nn.Module):
def __init__(self, args: ModelArgs) -> None:
super().__init__()
- self.model_parallel_size = fs_init.get_model_parallel_world_size()
+ self.world_size = fs_init.get_model_parallel_world_size()
assert args.vocab_size > 0
self.vocab_size = args.vocab_size
self.n_layers = args.n_layers
self.dim = args.dim
self.head_dim = args.dim // args.n_heads
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
- self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size
- assert self.vocab_size % self.model_parallel_size == 0
+ self.n_local_kv_heads = self.n_kv_heads // self.world_size
+ assert self.vocab_size % self.world_size == 0
self.tok_embeddings = VocabParallelEmbedding(args.vocab_size, args.dim, init_method=lambda x: x)
self.pos_embeddings = None
# final norm layer (not necessary for post-norm)
@@ -1184,6 +1174,8 @@ class CrossAttentionTransformerText(torch.nn.Module):
text_only_inference: bool = False,
):
assert self.cache_is_setup, "Please set up cache before calling forward"
+ self.mask_cache = self.mask_cache.to(h.device)
+ self.freqs_cis = self.freqs_cis.to(h.device)
mask = self.mask_cache.index_select(2, position_ids)
freqs_cis = self.freqs_cis.index_select(0, position_ids)
@@ -1212,9 +1204,8 @@ class CrossAttentionTransformerText(torch.nn.Module):
output = gather_from_tensor_model_parallel_region(output)
return output.float()
- def setup_cache(self, max_batch_size: int, dtype=torch.bfloat16):
+ def setup_cache(self, max_batch_size: int, device: torch.device, dtype=torch.bfloat16):
# Set up the text kv caches
- device = next(self.parameters()).device
ones = torch.ones(
(self.max_seq_len, self.max_seq_len),
dtype=torch.bool,
@@ -1265,7 +1256,7 @@ class CrossAttentionTransformerText(torch.nn.Module):
return (
cross_attention_masks.to(device=text_device, dtype=text_dtype),
- full_text_row_masked_out_mask,
+ full_text_row_masked_out_mask.to(device=text_device),
)
@@ -1284,14 +1275,15 @@ class CrossAttentionTransformer(torch.nn.Module):
max_num_chunks=args.vision_max_num_chunks,
)
- def setup_cache(self, max_batch_size: int, dtype: torch.dtype):
- self.text_model.setup_cache(max_batch_size, dtype)
+ def setup_cache(self, max_batch_size: int, device: torch.device, dtype: torch.dtype):
+ self.text_model.setup_cache(max_batch_size, device, dtype)
def compute_vision_tokens_masks(
self,
batch_images: List[List[PIL_Image.Image]],
batch_masks: List[List[List[int]]],
total_len: int,
+ device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
skip_vision_encoder = False
@@ -1318,6 +1310,7 @@ class CrossAttentionTransformer(torch.nn.Module):
image_res=self.params.vision_chunk_size,
max_num_images=max_num_images,
)
+ stacked_images = stacked_images.to(device=device)
if skip_vision_encoder:
vision_tokens = torch.zeros(
@@ -1330,7 +1323,7 @@ class CrossAttentionTransformer(torch.nn.Module):
),
)
else:
- vision_tokens = self.vision_model(stacked_images, aspect_ratios)
+ vision_tokens = self.vision_model(stacked_images, aspect_ratios).to(device=device)
bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape)
xattn_caches = torch.stack(
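
Worked arithmetic behind the KV-head replication in the `Attention` changes above: once the model-parallel size exceeds the number of KV heads, heads are replicated so each rank still owns at least one. The value `MP_SCALE = 8` is an assumption; the constant's definition is not shown in this diff:

MP_SCALE = 8                # assumed value of the module-level constant
world_size = 16
n_heads, n_kv_heads = 32, 8

replication_factor = world_size // MP_SCALE if world_size > 8 else 1   # -> 2
n_kv_heads *= replication_factor                                       # -> 16
n_local_heads = n_heads // world_size                                  # -> 2 query heads per rank
n_local_kv_heads = n_kv_heads // world_size                            # -> 1 KV head per rank
n_rep = n_local_heads // n_local_kv_heads                              # -> each KV head serves 2 query heads
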
diff --git a/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py b/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py
index e03fcfc93..d4e825a22 100644
--- a/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py
+++ b/llama_stack/models/llama/llama3/prompt_templates/system_prompts.py
@@ -15,7 +15,7 @@ import textwrap
from datetime import datetime
from typing import Any, List, Optional
-from llama_stack.models.llama.datatypes import (
+from llama_stack.apis.inference import (
BuiltinTool,
ToolDefinition,
ToolParamDefinition,
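
The prompt templates now pull the tool types from `llama_stack.apis.inference`. A small sketch of defining a tool with the relocated types; the tool itself is made up for illustration:

from llama_stack.apis.inference import ToolDefinition, ToolParamDefinition

weather_tool = ToolDefinition(
    tool_name="get_weather",
    description="Look up the current weather for a city.",
    parameters={
        "city": ToolParamDefinition(param_type="string", description="City name", required=True),
        "unit": ToolParamDefinition(param_type="string", required=False, default="celsius"),
    },
)
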
diff --git a/llama_stack/models/llama/llama3/quantization/loader.py b/llama_stack/models/llama/llama3/quantization/loader.py
index f4d94c382..771fd02be 100644
--- a/llama_stack/models/llama/llama3/quantization/loader.py
+++ b/llama_stack/models/llama/llama3/quantization/loader.py
@@ -4,9 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
-
# type: ignore
import os
from typing import Any, Dict, List, Optional, cast
@@ -18,22 +15,15 @@ from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_regi
from torch import Tensor, nn
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
-from llama_stack.apis.inference import QuantizationType
-from llama_stack.log import get_logger
-from llama_stack.models.llama.sku_list import resolve_model
-
-from ...config import MetaReferenceQuantizedInferenceConfig
-from ...datatypes import CheckpointQuantizationFormat
+from ...datatypes import QuantizationMode
from ...quantize_impls import (
Fp8ScaledWeights,
ffn_swiglu,
load_fp8,
quantize_fp8,
)
-from ..args import ModelArgs
from ..model import Transformer, TransformerBlock
-
-log = get_logger(__name__, category="quantization")
+from ..multimodal.model import CrossAttentionTransformer
def swiglu_wrapper(
@@ -44,30 +34,34 @@ def swiglu_wrapper(
return reduce_from_model_parallel_region(out)
+def convert_to_quantized_model(
+ model: Transformer | CrossAttentionTransformer,
+ checkpoint_dir: str,
+ quantization_mode: Optional[str] = None,
+ fp8_activation_scale_ub: Optional[float] = 1200.0,
+ device: Optional[torch.device] = None,
+) -> Transformer | CrossAttentionTransformer:
+ if quantization_mode == QuantizationMode.fp8_mixed:
+ return convert_to_fp8_quantized_model(model, checkpoint_dir, fp8_activation_scale_ub, device)
+ elif quantization_mode == QuantizationMode.int4_mixed:
+ return convert_to_int4_quantized_model(model, checkpoint_dir, device)
+ else:
+ raise ValueError(f"Unsupported quantization mode: {quantization_mode}")
+
+
def convert_to_fp8_quantized_model(
model: Transformer,
- config: MetaReferenceQuantizedInferenceConfig,
checkpoint_dir: str,
fp8_activation_scale_ub: Optional[float] = 1200.0,
+ device: Optional[torch.device] = None,
) -> Transformer:
- if config.quantization.type == QuantizationType.bf16.value:
- return model
-
- elif config.quantization.type != QuantizationType.fp8.value:
- raise ValueError("Only FP8 quantization is supported")
-
- assert config.model is not None, "Model must be specified for quantized inference"
- llama_model = resolve_model(config.model)
- assert llama_model is not None, f"Model {config.model} not found"
-
# Move weights to GPU with quantization
- if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
- log.info("Loading fp8 scales...")
- fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
- assert os.path.isfile(fp8_scales_path), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
+ fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
+ if os.path.isfile(fp8_scales_path):
+ print("Loading fp8 scales...")
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
- for block in model.layers:
+ for _, block in model.named_modules():
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
@@ -81,8 +75,8 @@ def convert_to_fp8_quantized_model(
fp8_activation_scale_ub,
)
else:
- log.info("Quantizing fp8 weights from bf16...")
- for block in model.layers:
+ print("Quantizing fp8 weights from bf16...")
+ for _, block in model.named_modules():
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
@@ -92,12 +86,12 @@ def convert_to_fp8_quantized_model(
param.weight = quantize_fp8(
param.weight,
fp8_activation_scale_ub,
- output_device=torch.device("cuda"),
+ output_device=device,
)
for _, parameter in model.named_parameters():
if not isinstance(parameter, Fp8ScaledWeights):
- parameter.data = parameter.to(device="cuda")
+ parameter.data = parameter.to(device=device)
return model
@@ -290,11 +284,12 @@ def _prepare_model_int4_weight_int8_dynamic_activation(
def convert_to_int4_quantized_model(
- model: Transformer,
- model_args: ModelArgs,
-) -> Transformer:
+ model: Transformer | CrossAttentionTransformer,
+ checkpoint_dir: str,
+ device: Optional[torch.device] = None,
+) -> Transformer | CrossAttentionTransformer:
"""Convert the model to int4 quantized model."""
-
+ model_args = model.params
assert model_args.quantization_args is not None, "Quantization args must be specified."
quantization_args = model_args.quantization_args
if quantization_args.scheme is None:
@@ -318,5 +313,4 @@ def convert_to_int4_quantized_model(
lora_scale = model_args.lora_args.scale
_prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale)
- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
- return cast(Transformer, model.to(device))
+ return cast(Transformer | CrossAttentionTransformer, model.to(device=device))
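A minimal sketch of driving the new convert_to_quantized_model entry point defined above; `model` and `ckpt_dir` are placeholders, and the QuantizationMode import follows the location this patch uses elsewhere:

    import torch

    from llama_stack.models.llama.datatypes import QuantizationMode

    # Placeholders: an already-built Transformer (or CrossAttentionTransformer) and the
    # directory holding its checkpoint shards / optional fp8_scales_<rank>.pt files.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = convert_to_quantized_model(
        model,
        checkpoint_dir=ckpt_dir,
        quantization_mode=QuantizationMode.fp8_mixed,  # or QuantizationMode.int4_mixed
        device=device,
    )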
diff --git a/llama_stack/models/llama/llama3/tokenizer.py b/llama_stack/models/llama/llama3/tokenizer.py
index b240fa246..d3cc4fc07 100644
--- a/llama_stack/models/llama/llama3/tokenizer.py
+++ b/llama_stack/models/llama/llama3/tokenizer.py
@@ -4,16 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
-
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
-
import os
from logging import getLogger
from pathlib import Path
diff --git a/llama_stack/models/llama/llama3_2/__init__.py b/llama_stack/models/llama/llama3_2/__init__.py
index 38ee47d66..756f351d8 100644
--- a/llama_stack/models/llama/llama3_2/__init__.py
+++ b/llama_stack/models/llama/llama3_2/__init__.py
@@ -3,10 +3,3 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
diff --git a/llama_stack/models/llama/llama3_2/prompts_text.py b/llama_stack/models/llama/llama3_2/prompts_text.py
index 7bc7e3219..7a1f9887c 100644
--- a/llama_stack/models/llama/llama3_2/prompts_text.py
+++ b/llama_stack/models/llama/llama3_2/prompts_text.py
@@ -4,12 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
import json
import textwrap
diff --git a/llama_stack/models/llama/llama3_2/prompts_vision.py b/llama_stack/models/llama/llama3_2/prompts_vision.py
index b1ede418b..b0f11cab6 100644
--- a/llama_stack/models/llama/llama3_2/prompts_vision.py
+++ b/llama_stack/models/llama/llama3_2/prompts_vision.py
@@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
-
import textwrap
from pathlib import Path
diff --git a/llama_stack/models/llama/llama4/args.py b/llama_stack/models/llama/llama4/args.py
index 046448ef6..6d7c1d409 100644
--- a/llama_stack/models/llama/llama4/args.py
+++ b/llama_stack/models/llama/llama4/args.py
@@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
-
from enum import Enum
from typing import Optional
diff --git a/llama_stack/models/llama/llama4/chat_format.py b/llama_stack/models/llama/llama4/chat_format.py
index ebae2b8e5..8f08b3a9e 100644
--- a/llama_stack/models/llama/llama4/chat_format.py
+++ b/llama_stack/models/llama/llama4/chat_format.py
@@ -12,6 +12,7 @@ from typing import Dict, List, Optional, Tuple
import torch
from PIL import Image as PIL_Image
+# TODO: either fork these or move them to the common package
from ..datatypes import (
BuiltinTool,
RawContent,
@@ -26,10 +27,7 @@ from ..datatypes import (
from ..llama3.tool_utils import ToolUtils
from .args import VisionArgs
from .datatypes import LLMInput
-from .preprocess import (
- ResizeNormalizeImageTransform,
- VariableSizeImageTransform,
-)
+from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform
from .tokenizer import Tokenizer
@@ -50,7 +48,7 @@ class TransformedImage:
aspect_ratio: Tuple[int, int]
-def convert_rgba_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
+def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
if image.mode == "RGBA":
image.load() # for png.split()
new_img = PIL_Image.new("RGB", image.size, bg)
@@ -167,7 +165,7 @@ class ChatFormat:
bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
image = PIL_Image.open(bytes_io)
- image = convert_rgba_to_rgb(image)
+ image = convert_image_to_rgb(image)
image_tiles, ar = self.dynamic_image_transform(image, max_num_chunks=self.max_num_chunks)
if image_tiles.shape[0] > 1:
@@ -212,12 +210,9 @@ class ChatFormat:
content = ToolUtils.encode_tool_call(t, tool_prompt_format)
_process_content(content)
- # Tool calls and Tool Response messages should be eom
eom = False
if message.role == "assistant":
- eom = message.stop_reason == StopReason.end_of_message or message.tool_calls
- elif message.role == "tool":
- eom = True
+ eom = message.stop_reason == StopReason.end_of_message
tokens.append(self.tokenizer.special_tokens["<|eom|>" if eom else "<|eot|>"])
return tokens, images
@@ -252,11 +247,6 @@ class ChatFormat:
if content.startswith(header_str):
content = content[len(header_str) :]
- ipython = content.startswith("<|python_start|>")
- if ipython:
- content = content[len("<|python_start|>") :]
- content = content.replace("<|python_end|>", "")
-
if content.endswith("<|eot|>"):
content = content[: -len("<|eot|>")]
stop_reason = StopReason.end_of_turn
@@ -287,11 +277,6 @@ class ChatFormat:
}
if tool_name in BuiltinTool.__members__:
tool_name = BuiltinTool[tool_name]
- elif ipython:
- tool_name = BuiltinTool.code_interpreter
- tool_arguments = {
- "code": content,
- }
tool_calls = []
if tool_name is not None and tool_arguments is not None:
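The rename of convert_rgba_to_rgb to convert_image_to_rgb above is behavior-preserving; a tiny standalone check, assuming the helper is imported from this module:

    from PIL import Image as PIL_Image

    rgba = PIL_Image.new("RGBA", (2, 2), (255, 0, 0, 128))
    rgb = convert_image_to_rgb(rgba)  # flattens the alpha channel onto the default white background
    assert rgb.mode == "RGB"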
diff --git a/llama_stack/models/llama/llama4/datatypes.py b/llama_stack/models/llama/llama4/datatypes.py
index bb1c19a12..27174db63 100644
--- a/llama_stack/models/llama/llama4/datatypes.py
+++ b/llama_stack/models/llama/llama4/datatypes.py
@@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
-
from dataclasses import dataclass
from typing import List, Optional, Union
diff --git a/llama_stack/models/llama/llama4/generation.py b/llama_stack/models/llama/llama4/generation.py
index 9c516d967..8971835aa 100644
--- a/llama_stack/models/llama/llama4/generation.py
+++ b/llama_stack/models/llama/llama4/generation.py
@@ -4,32 +4,43 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# top-level folder for each specific model found within the models/ directory at
+# the top-level of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
+
import codecs
import io
import json
import os
import sys
import time
-from enum import Enum
from pathlib import Path
from typing import Callable, Generator, List, Optional
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
- get_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from termcolor import cprint
-from ..common import TokenResult
+from ..checkpoint import maybe_reshard_state_dict
+from ..datatypes import GenerationResult, QuantizationMode
from .args import ModelArgs
-from .chat_format import (
- ChatFormat,
- RawContent,
- RawMessage,
-)
+from .chat_format import ChatFormat, RawContent, RawMessage
from .datatypes import LLMInput, MaskedEmbedding, TransformerInput
from .model import Transformer
from .tokenizer import Tokenizer
@@ -37,12 +48,6 @@ from .tokenizer import Tokenizer
torch.serialization.add_safe_globals([io.BytesIO, codecs.encode])
-class QuantizationMode(str, Enum):
- none = "none"
- fp8_mixed = "fp8_mixed"
- int4_mixed = "int4_mixed"
-
-
class Llama4:
@staticmethod
def build(
@@ -50,7 +55,7 @@ class Llama4:
max_seq_len: int,
max_batch_size: int,
world_size: Optional[int] = None,
- quantization_mode: Optional[str] = None,
+ quantization_mode: Optional[QuantizationMode] = None,
seed: int = 1,
):
if not torch.distributed.is_initialized():
@@ -71,11 +76,9 @@ class Llama4:
start_time = time.time()
- checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
- assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
- assert world_size == len(checkpoints), (
- f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
- )
+ ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
+ assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
+ print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
@@ -92,10 +95,11 @@ class Llama4:
assert model_args.vocab_size == tokenizer.n_words, f"{model_args.vocab_size=} vs. {tokenizer.n_words=} mismatch"
print("Model args:\n", model_args.model_dump_json(indent=2))
- ckpt_path = checkpoints[get_model_parallel_rank()]
- print(f"Loading checkpoint from {ckpt_dir}...")
- with open(ckpt_path, "rb") as f:
- checkpoint = torch.load(f, map_location="cpu", weights_only=True)
+ state_dict = maybe_reshard_state_dict(
+ ckpt_paths,
+ n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
+ moe_num_experts=model_args.moe_args.num_experts,
+ )
print("Loaded checkpoint")
if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
from .quantization.loader import convert_to_quantized_model
@@ -103,9 +107,9 @@ class Llama4:
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = Transformer(model_args)
print("Loading state dict...")
- model.load_state_dict(checkpoint, strict=False)
+ model.load_state_dict(state_dict, strict=False)
print("Done...")
- model = convert_to_quantized_model(model, ckpt_dir)
+ model = convert_to_quantized_model(model, ckpt_dir, quantization_mode)
else:
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
@@ -114,7 +118,7 @@ class Llama4:
model = Transformer(model_args)
print("Loading state dict...")
- model.load_state_dict(checkpoint, strict=False)
+ model.load_state_dict(state_dict, strict=False)
print("Done...")
print(f"Loaded in {time.time() - start_time:.2f} seconds")
@@ -129,7 +133,7 @@ class Llama4:
@torch.inference_mode()
def generate(
self,
- llm_input: LLMInput,
+ llm_inputs: List[LLMInput],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
@@ -137,22 +141,20 @@ class Llama4:
echo: bool = False,
print_model_input: bool = False,
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
- ) -> Generator:
+ ) -> Generator[List[GenerationResult], None, None]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.args.max_seq_len:
max_gen_len = self.model.args.max_seq_len - 1
params = self.model.args
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
- if print_model_input and get_model_parallel_rank() == 0:
- tokens_to_print = list(llm_input.tokens)
- cprint(
- "Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
- "red",
- )
- prompt_tokens = [llm_input.tokens]
+ if print_model_input:
+ cprint("Input to model:\n", "yellow")
+ for inp in llm_inputs:
+ cprint(self.tokenizer.decode(inp.tokens.tolist()), "grey")
+ prompt_tokens = [inp.tokens for inp in llm_inputs]
- bsz = 1
+ bsz = len(llm_inputs)
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)
@@ -175,24 +177,33 @@ class Llama4:
input_text_mask = tokens != pad_id
if echo:
- for i, t in enumerate(llm_input.tokens):
- yield TokenResult(
- token=t,
- text=self.tokenizer.decode([t]),
- logprobs=(token_logprobs[0, i : i + 1].tolist() if logprobs else None),
- )
+ for i in range(max_prompt_len):
+ results = []
+ for j, t in enumerate(tokens[:, i]):
+ results.append(
+ GenerationResult(
+ token=t.item(),
+ text=self.tokenizer.decode([t.item()]),
+ source="input",
+ logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
+ batch_idx=j,
+ finished=False,
+ ignore_token=t.item() == pad_id,
+ )
+ )
+ yield results
stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")
prev_pos = 0
for cur_pos in range(min_prompt_len, total_len):
image_embedding = None
- if prev_pos == 0 and llm_input.images is not None and len(llm_input.images) > 0:
+ if prev_pos == 0 and any(inp.images is not None and len(inp.images) > 0 for inp in llm_inputs):
image_mask = tokens[:, prev_pos:cur_pos] == self.tokenizer.special_tokens["<|patch|>"]
image_mask = image_mask.unsqueeze(-1)
h = self.model.tok_embeddings(tokens[:, prev_pos:cur_pos])
- image_batch = [llm_input.images]
+ image_batch = [inp.images if inp.images is not None else [] for inp in llm_inputs]
image_embedding = MaskedEmbedding(
embedding=self.model.vision_embeddings(image_batch, image_mask, h),
mask=image_mask,
@@ -228,11 +239,21 @@ class Llama4:
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
- yield TokenResult(
- token=next_token[0].item(),
- text=self.tokenizer.decode(next_token.tolist()),
- logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None),
- )
+
+ results = []
+ for idx, t in enumerate(next_token):
+ results.append(
+ GenerationResult(
+ token=t.item(),
+ text=self.tokenizer.decode([t.item()]),
+ source="output",
+ logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
+ batch_idx=idx,
+ finished=eos_reached[idx],
+ ignore_token=cur_pos < len(prompt_tokens[idx]),
+ )
+ )
+ yield results
prev_pos = cur_pos
if all(eos_reached):
@@ -240,68 +261,47 @@ class Llama4:
def completion(
self,
- content: RawContent,
+ contents: List[RawContent],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
- ) -> Generator:
- llm_input = self.formatter.encode_content(content)
+ ) -> Generator[List[GenerationResult], None, None]:
+ llm_inputs = [self.formatter.encode_content(c) for c in contents]
for result in self.generate(
- llm_input=llm_input,
+ llm_inputs=llm_inputs,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
- if result.token in self.tokenizer.stop_tokens:
- break
yield result
+ if all(r.finished for r in result):
+ break
def chat_completion(
self,
- messages: List[RawMessage],
+ messages_batch: List[List[RawMessage]],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
- ) -> Generator:
- llm_input = self.formatter.encode_dialog_prompt(messages)
+ ) -> Generator[List[GenerationResult], None, None]:
+ llm_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
for result in self.generate(
- llm_input=llm_input,
+ llm_inputs=llm_inputs,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
- if result.token in self.tokenizer.stop_tokens:
- break
yield result
-
- def chat_completion_raw(
- self,
- messages: List[RawMessage],
- temperature: float = 0.6,
- top_p: float = 0.9,
- max_gen_len: Optional[int] = None,
- logprobs: bool = False,
- ):
- llm_input = self.formatter.encode_dialog_prompt(messages)
- output_tokens = []
- for result in self.generate(
- llm_input=llm_input,
- temperature=temperature,
- top_p=top_p,
- max_gen_len=max_gen_len,
- logprobs=logprobs,
- ):
- output_tokens.append(result.token)
-
- return llm_input.tokens, output_tokens
+ if all(r.finished for r in result):
+ break
def sample_top_p(probs, p):
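With the batched API, callers pass one dialog per batch slot and each generation step yields a list of GenerationResult objects; a minimal sketch (the ckpt_dir keyword for build() and the RawMessage import path are assumed, since they are not shown in these hunks):

    from llama_stack.models.llama.datatypes import RawMessage  # assumed location of RawMessage
    from llama_stack.models.llama.llama4.generation import Llama4

    gen = Llama4.build(ckpt_dir="/path/to/checkpoint", max_seq_len=512, max_batch_size=2)
    dialogs = [
        [RawMessage(role="user", content="Say hello.")],
        [RawMessage(role="user", content="Name one GPU vendor.")],
    ]
    outputs = ["", ""]
    for results in gen.chat_completion(dialogs, temperature=0.0):
        for r in results:
            if not r.ignore_token and not r.finished:  # skip padding and the stop token itself
                outputs[r.batch_idx] += r.text
    print(outputs)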
diff --git a/llama_stack/models/llama/llama4/model.py b/llama_stack/models/llama/llama4/model.py
index a35d6857f..08fac7714 100644
--- a/llama_stack/models/llama/llama4/model.py
+++ b/llama_stack/models/llama/llama4/model.py
@@ -4,16 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
-
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
-
import math
from typing import Any, Dict, List, Optional, Tuple
@@ -184,7 +174,6 @@ class Attention(nn.Module):
self.head_dim,
)
).cuda()
-
self.qk_norm = None
if self.use_qk_norm:
self.qk_norm = L2Norm(args.norm_eps)
diff --git a/llama_stack/models/llama/llama4/moe.py b/llama_stack/models/llama/llama4/moe.py
index 8cecab7dd..2ce49e915 100644
--- a/llama_stack/models/llama/llama4/moe.py
+++ b/llama_stack/models/llama/llama4/moe.py
@@ -100,31 +100,21 @@ class Experts(nn.Module):
class MoE(torch.nn.Module):
"""
- This EC implementation is modified from the original EC module.
- We refactored the token permutation and unpermutation logic and added support to tp and dp2ep sharding.
- This module supports 3 sharding methods of the experts:
- - tp: each TP rank has n_experts experts. Experts are sharded following the conventional row/column-parallel TP sharding.
- - tp2ep: each TP rank has n_experts/tp experts. Experts are not sharded.
- - dp2ep: each EP rank has n_experts/ep experts. Experts are sharded following the row/column-parallel TP sharding.
Tensors used in this module are annotated with the suffixes that indicate the shape of the tensor.
Several commonly used annotations include:
- a: bsz*slen
- E: number of experts
- e: number of local experts per ep (n_experts/ep)
- - et: number of local experts per tp (n_experts/tp)
- D: hidden dimension
- d: D/tp
- F: model dimension
- - f: F/tp (used in column/row-parallel linear)
- G: number of tokens per expert (a * capacity_factor / E)
- g: number of tokens per expert per TP rank (i.e., G/TP)
- - GG: G*EP (number of tokens per expert received via inter-EP a2a when ag_along_first_dim=False)
- - gg: g*EP (number of tokens per expert received via inter-EP a2a when ag_along_first_dim=True)
Examples:
x_aD [a, D]
routed_in_etG_D [et*G, D]
- x_eGGD: [e, GG, D]
+ x_eGD: [e, G, D]
"""
def __init__(
@@ -207,13 +197,13 @@ class MoE(torch.nn.Module):
routed_in_EG_D = routed_in_EG_D * router_scores.reshape(-1, 1)
out_aD = self.shared_expert(x_aD)
- routed_out_egg_D = self.experts(routed_in_EG_D.detach())
+ routed_out_eg_D = self.experts(routed_in_EG_D.detach())
router_indices_EG_D = router_indices.reshape(-1, 1).expand(-1, D)
out_aD.scatter_add_(
dim=0,
index=router_indices_EG_D,
- src=routed_out_egg_D.view(-1, D),
+ src=routed_out_eg_D.view(-1, D),
)
out_aD = reduce_from_model_parallel_region(out_aD)
return out_aD.view(-1, slen, D)
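The routed-expert outputs are merged back into the per-token accumulator via scatter_add_ with an index expanded across the hidden dimension; a small self-contained illustration of that indexing pattern:

    import torch

    D = 4
    out_aD = torch.zeros(3, D)              # shared-expert output for a = 3 tokens
    routed_out = torch.ones(2, D)           # routed outputs for 2 (expert, token) rows
    router_indices = torch.tensor([0, 2])   # each routed row maps back to a token index
    index = router_indices.reshape(-1, 1).expand(-1, D)
    out_aD.scatter_add_(dim=0, index=index, src=routed_out)
    # rows 0 and 2 now hold shared + routed contributions; row 1 is untouched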
diff --git a/llama_stack/models/llama/llama4/prompts.py b/llama_stack/models/llama/llama4/prompts.py
index d4e48e80a..13b96359a 100644
--- a/llama_stack/models/llama/llama4/prompts.py
+++ b/llama_stack/models/llama/llama4/prompts.py
@@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
-
import textwrap
from io import BytesIO
from pathlib import Path
diff --git a/llama_stack/models/llama/llama4/quantization/loader.py b/llama_stack/models/llama/llama4/quantization/loader.py
index 69aa309fa..f11d83c60 100644
--- a/llama_stack/models/llama/llama4/quantization/loader.py
+++ b/llama_stack/models/llama/llama4/quantization/loader.py
@@ -6,20 +6,29 @@
import logging
import os
-from typing import Optional
+from typing import Callable, Optional
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
-from torch import Tensor
+from torch import Tensor, nn
from torch.nn import functional as F
-from ..generation import QuantizationMode
+from ...datatypes import QuantizationMode
from ..model import Transformer, TransformerBlock
from ..moe import MoE
log = logging.getLogger(__name__)
+def swiglu_wrapper_no_reduce(
+ self,
+ x: Tensor,
+):
+ from ...quantize_impls import ffn_swiglu
+
+ return ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
+
+
def experts_batched_swiglu_wrapper(
self,
x: Tensor, # (e, g, D)
@@ -51,24 +60,30 @@ def convert_to_quantized_model(
rank = get_model_parallel_rank()
+ def should_quantize_block(block: nn.Module) -> bool:
+ if not isinstance(block, TransformerBlock):
+ return False
+
+ is_moe = isinstance(block.feed_forward, MoE)
+ if quantization_mode == QuantizationMode.fp8_mixed:
+ # skip quantization on first and last layers
+ return is_moe and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1))
+
+ return is_moe
+
use_rich_progress = use_rich_progress and rank == 0
- progress, log_status, update_status = logging_callbacks(use_rich_progress, rank, model)
+ progress, log_status, update_status = logging_callbacks(use_rich_progress, rank, model, should_quantize_block)
if quantization_mode == QuantizationMode.int4_mixed:
int4_scales_path = os.path.join(checkpoint_dir, f"int4_scales_{rank}.pt")
- int4_zero_points_path = os.path.join(checkpoint_dir, f"int4_zero_points_{rank}.pt")
if os.path.isfile(int4_scales_path):
log_status(f"Rank {rank}: Loading int4 scales")
int4_scales = torch.load(int4_scales_path, weights_only=True)
- int4_zero_points = torch.load(int4_zero_points_path, weights_only=True)
def apply_quantization(key, weight):
scale = int4_scales[key]
- zero_point = int4_zero_points[key]
return load_int4(
weight,
scale,
- zero_point,
- fp8_activation_scale_ub,
output_device=torch.device("cuda"),
)
@@ -76,7 +91,8 @@ def convert_to_quantized_model(
log_status(f"Rank {rank}: Quantizing int4 weights from bf16")
def apply_quantization(_, weight):
- return quantize_int4(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))
+ return quantize_int4(weight, output_device=torch.device("cuda"))
+
else:
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
if os.path.isfile(fp8_scales_path):
@@ -104,33 +120,38 @@ def convert_to_quantized_model(
progress.start()
for _, block in model.named_modules():
- if isinstance(block, TransformerBlock):
- # Skip quantization on first and last layers
- if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
- continue
+ if not should_quantize_block(block):
+ continue
- # Skip quantization on dense layers
- if not isinstance(block.feed_forward, MoE):
- continue
+ update_status(f"Rank {rank} - Layer {block.layer_id}")
- update_status(f"Rank {rank} - Layer {block.layer_id}")
+ # Quantize only routed experts, not shared
+ prefix = f"layers.{block.layer_id}.feed_forward"
+ moe = block.feed_forward
+ moe.experts.batched_swiglu = experts_batched_swiglu_wrapper.__get__(moe.experts)
- # Quantize only routed experts, not shared
- prefix = f"layers.{block.layer_id}.feed_forward"
- moe = block.feed_forward
- moe.experts.batched_swiglu = experts_batched_swiglu_wrapper.__get__(moe.experts)
+ for key in ("w1", "w3", "w2"):
+ param = getattr(moe.experts, key)
+ update_status(f"Rank {rank} - Layer {block.layer_id} - MoE {key}")
+ setattr(
+ moe.experts,
+ key,
+ apply_quantization(
+ f"{prefix}.experts.{key}",
+ param.transpose(1, 2).contiguous(),
+ ),
+ )
+ if quantization_mode == QuantizationMode.int4_mixed:
+ # Quantize shared experts
+ moe.shared_expert.forward = swiglu_wrapper_no_reduce.__get__(moe.shared_expert)
for key in ("w1", "w3", "w2"):
- param = getattr(moe.experts, key)
- update_status(f"Rank {rank} - Layer {block.layer_id} - MoE {key}")
- setattr(
- moe.experts,
- key,
- apply_quantization(f"{prefix}.experts.{key}", param.transpose(1, 2).contiguous()),
- )
+ param = getattr(moe.shared_expert, key)
+ update_status(f"Rank {rank} - Layer {block.layer_id} - MoE shared expert {key}")
+ param.weight = apply_quantization(f"{prefix}.shared_expert.{key}", param.weight)
- processed_blocks += 1
- update_status(message=None, completed=processed_blocks)
+ processed_blocks += 1
+ update_status(message=None, completed=processed_blocks)
update_status(f"Rank {rank} - Moving parameters to CUDA")
@@ -149,7 +170,12 @@ def convert_to_quantized_model(
# fp8/int4 loading can be very slow so we add progress bars to make life slightly better
-def logging_callbacks(use_rich_progress: bool, rank: int, model: Transformer):
+def logging_callbacks(
+ use_rich_progress: bool,
+ rank: int,
+ model: Transformer,
+ should_quantize_block: Callable[[nn.Module], bool],
+):
console = None
if use_rich_progress:
from rich.console import Console
@@ -162,15 +188,7 @@ def logging_callbacks(use_rich_progress: bool, rank: int, model: Transformer):
elif rank == 0: # Only log from rank 0 for non-rich logging
log.info(message)
- total_blocks = sum(
- 1
- for _, block in model.named_modules()
- if (
- isinstance(block, TransformerBlock)
- and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1))
- and isinstance(block.feed_forward, MoE)
- )
- )
+ total_blocks = sum(1 for _, block in model.named_modules() if should_quantize_block(block))
progress = None
if use_rich_progress:
from rich.progress import (
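The batched_swiglu and shared-expert swaps above rely on descriptor binding: func.__get__(obj) returns func bound to obj as self for that one instance. A tiny standalone illustration of the pattern:

    class Box:
        def __init__(self, v):
            self.v = v

    def double(self):  # a plain function, not defined on Box
        return self.v * 2

    b = Box(3)
    b.double = double.__get__(b)  # instance-level bound method, as with swiglu_wrapper_no_reduce
    assert b.double() == 6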
diff --git a/llama_stack/models/llama/llama4/tokenizer.py b/llama_stack/models/llama/llama4/tokenizer.py
index c1347daca..14250f681 100644
--- a/llama_stack/models/llama/llama4/tokenizer.py
+++ b/llama_stack/models/llama/llama4/tokenizer.py
@@ -4,6 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# top-level folder for each specific model found within the models/ directory at
+# the top-level of this source tree.
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
@@ -59,8 +66,6 @@ LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS = [
"<|text_post_train_reserved_special_token_3|>",
"<|text_post_train_reserved_special_token_4|>",
"<|text_post_train_reserved_special_token_5|>",
- "<|python_start|>",
- "<|python_end|>",
"<|finetune_right_pad|>",
] + get_reserved_special_tokens(
"text_post_train", 61, 6
@@ -85,8 +90,23 @@ LLAMA4_VISION_SPECIAL_TOKENS = [
"vision", 1041, 7
) # <|vision_reserved_special_token_7|>, ..., <|vision_reserved_special_token_1047|>
+# 201134, ..., 201143
+LLAMA4_REASONING_SPECIAL_TOKENS = [
+ "<|reasoning_reserved_special_token_0|>",
+ "<|reasoning_reserved_special_token_1|>",
+ "<|reasoning_reserved_special_token_2|>",
+ "<|reasoning_reserved_special_token_3|>",
+ "<|reasoning_reserved_special_token_4|>",
+ "<|reasoning_reserved_special_token_5|>",
+ "<|reasoning_reserved_special_token_6|>",
+ "<|reasoning_reserved_special_token_7|>",
+ "<|reasoning_thinking_start|>",
+ "<|reasoning_thinking_end|>",
+]
-LLAMA4_SPECIAL_TOKENS = LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS + LLAMA4_VISION_SPECIAL_TOKENS
+LLAMA4_SPECIAL_TOKENS = (
+ LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS + LLAMA4_VISION_SPECIAL_TOKENS + LLAMA4_REASONING_SPECIAL_TOKENS
+)
BASIC_SPECIAL_TOKENS = [
"<|begin_of_text|>",
@@ -155,6 +175,9 @@ class Tokenizer:
self.eot_id: int = self.special_tokens["<|eot|>"]
self.eom_id: int = self.special_tokens["<|eom|>"]
+ self.thinking_start_id: int = self.special_tokens["<|reasoning_thinking_start|>"]
+ self.thinking_end_id: int = self.special_tokens["<|reasoning_thinking_end|>"]
+
self.stop_tokens = [
self.eos_id,
self.special_tokens["<|eom|>"],
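A short sketch of using the new reasoning ids; the constructor argument and the keyword-only bos/eos flags on encode() are assumed to match the existing llama tokenizers:

    tok = Tokenizer(model_path="/path/to/tokenizer.model")  # placeholder path
    body = tok.encode("intermediate reasoning goes here", bos=False, eos=False)
    tokens = [tok.thinking_start_id] + body + [tok.thinking_end_id]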
diff --git a/llama_stack/models/llama/llama4/vision/embedding.py b/llama_stack/models/llama/llama4/vision/embedding.py
index 73b29cbef..ed7659a73 100644
--- a/llama_stack/models/llama/llama4/vision/embedding.py
+++ b/llama_stack/models/llama/llama4/vision/embedding.py
@@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
-
import math
from typing import Any, Callable, Dict, List
diff --git a/llama_stack/models/llama/prompt_format.py b/llama_stack/models/llama/prompt_format.py
index 695c0bf74..6756aebfe 100644
--- a/llama_stack/models/llama/prompt_format.py
+++ b/llama_stack/models/llama/prompt_format.py
@@ -28,9 +28,6 @@ from llama_stack.models.llama.datatypes import (
ToolPromptFormat,
)
from llama_stack.models.llama.llama4.tokenizer import Tokenizer
-from llama_stack.providers.inline.inference.meta_reference.llama4.datatypes import (
- LLMInput,
-)
from .llama3.interface import LLama31Interface
from .llama3.template_data import (
@@ -38,6 +35,7 @@ from .llama3.template_data import (
system_message_builtin_tools_only,
system_message_custom_tools_only,
)
+from .llama4.datatypes import LLMInput
class TextCompletionContent(BaseModel):
diff --git a/llama_stack/models/llama/sku_list.py b/llama_stack/models/llama/sku_list.py
index dd3144bb0..513481831 100644
--- a/llama_stack/models/llama/sku_list.py
+++ b/llama_stack/models/llama/sku_list.py
@@ -4,24 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
-
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional
-from .datatypes import (
+from .sku_types import (
CheckpointQuantizationFormat,
CoreModelId,
Model,
ModelFamily,
- SamplingParams,
- TopPSamplingStrategy,
)
LLAMA2_VOCAB_SIZE = 32000
@@ -47,15 +38,6 @@ def all_registered_models() -> List[Model]:
)
-def recommended_sampling_params() -> SamplingParams:
- return SamplingParams(
- strategy=TopPSamplingStrategy(
- temperature=1.0,
- top_p=0.9,
- )
- )
-
-
def llama2_family() -> List[Model]:
return [
*llama2_base_models(),
@@ -150,7 +132,6 @@ def llama2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama2_7b,
description="Llama 2 7b model",
huggingface_repo="meta-llama/Llama-2-7b",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@@ -169,7 +150,6 @@ def llama2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama2_13b,
description="Llama 2 13b model",
huggingface_repo="meta-llama/Llama-2-13b",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 5120,
"n_layers": 40,
@@ -188,7 +168,6 @@ def llama2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama2_70b,
description="Llama 2 70b model",
huggingface_repo="meta-llama/Llama-2-70b",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@@ -230,7 +209,6 @@ def llama3_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_70b,
description="Llama 3 70b model",
huggingface_repo="meta-llama/Llama-3-70B",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@@ -254,7 +232,6 @@ def llama3_1_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_1_8b,
description="Llama 3.1 8b model",
huggingface_repo="meta-llama/Llama-3.1-8B",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@@ -273,7 +250,6 @@ def llama3_1_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_1_70b,
description="Llama 3.1 70b model",
huggingface_repo="meta-llama/Llama-3.1-70B",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@@ -293,7 +269,6 @@ def llama3_1_base_models() -> List[Model]:
variant="bf16-mp8",
description="Llama 3.1 405b model (BF16 weights)",
huggingface_repo="meta-llama/Llama-3.1-405B",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@@ -313,7 +288,6 @@ def llama3_1_base_models() -> List[Model]:
description="Llama 3.1 405b model (FP8 quantized)",
huggingface_repo="meta-llama/Llama-3.1-405B-FP8",
quantization_format=CheckpointQuantizationFormat.fp8_mixed,
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@@ -333,7 +307,6 @@ def llama3_1_base_models() -> List[Model]:
variant="bf16-mp16",
description="Llama 3.1 405b model (BF16 weights for mp16)",
huggingface_repo="meta-llama/Llama-3.1-405B",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@@ -357,7 +330,6 @@ def llama3_2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_1b,
description="Llama 3.2 1b model",
huggingface_repo="meta-llama/Llama-3.2-1B",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 2048,
"n_layers": 16,
@@ -376,7 +348,6 @@ def llama3_2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_3b,
description="Llama 3.2 3b model",
huggingface_repo="meta-llama/Llama-3.2-3B",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 3072,
"n_layers": 28,
@@ -395,7 +366,6 @@ def llama3_2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_11b_vision,
description="Llama 3.2 11b vision model",
huggingface_repo="meta-llama/Llama-3.2-11B-Vision",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@@ -417,7 +387,6 @@ def llama3_2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_90b_vision,
description="Llama 3.2 90b vision model",
huggingface_repo="meta-llama/Llama-3.2-90B-Vision",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@@ -444,7 +413,6 @@ def llama2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama2_7b_chat,
description="Llama 2 7b chat model",
huggingface_repo="meta-llama/Llama-2-7b-chat",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@@ -463,7 +431,6 @@ def llama2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama2_13b_chat,
description="Llama 2 13b chat model",
huggingface_repo="meta-llama/Llama-2-13b-chat",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 5120,
"n_layers": 40,
@@ -482,7 +449,6 @@ def llama2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama2_70b_chat,
description="Llama 2 70b chat model",
huggingface_repo="meta-llama/Llama-2-70b-chat",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@@ -506,7 +472,6 @@ def llama3_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_8b_instruct,
description="Llama 3 8b instruct model",
huggingface_repo="meta-llama/Llama-3-8B-Instruct",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@@ -525,7 +490,6 @@ def llama3_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_70b_instruct,
description="Llama 3 70b instruct model",
huggingface_repo="meta-llama/Llama-3-70B-Instruct",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@@ -549,7 +513,6 @@ def llama3_1_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_1_8b_instruct,
description="Llama 3.1 8b instruct model",
huggingface_repo="meta-llama/Llama-3.1-8B-Instruct",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@@ -568,7 +531,6 @@ def llama3_1_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_1_70b_instruct,
description="Llama 3.1 70b instruct model",
huggingface_repo="meta-llama/Llama-3.1-70B-Instruct",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@@ -588,7 +550,6 @@ def llama3_1_instruct_models() -> List[Model]:
variant="bf16-mp8",
description="Llama 3.1 405b instruct model (BF16 weights)",
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@@ -608,7 +569,6 @@ def llama3_1_instruct_models() -> List[Model]:
description="Llama 3.1 405b instruct model (FP8 quantized)",
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct-FP8",
quantization_format=CheckpointQuantizationFormat.fp8_mixed,
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@@ -628,7 +588,6 @@ def llama3_1_instruct_models() -> List[Model]:
variant="bf16-mp16",
description="Llama 3.1 405b instruct model (BF16 weights for mp16)",
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@@ -684,7 +643,6 @@ def llama3_2_quantized_models() -> List[Model]:
quantization_format=CheckpointQuantizationFormat.int4,
description="Llama 3.2 1b INT4 quantized LoRA",
huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
**arch_args_1b(),
"quantization_args": {
@@ -703,7 +661,6 @@ def llama3_2_quantized_models() -> List[Model]:
quantization_format=CheckpointQuantizationFormat.int4,
description="Llama 3.2 1b INT4 quantized SpinQuant",
huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
**arch_args_1b(),
"quantization_args": {
@@ -718,7 +675,6 @@ def llama3_2_quantized_models() -> List[Model]:
quantization_format=CheckpointQuantizationFormat.int4,
description="Llama 3.2 3b INT4 quantized LoRA",
huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-QLORA_INT4_EO8",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
**arch_args_3b(),
"quantization_args": {
@@ -737,7 +693,6 @@ def llama3_2_quantized_models() -> List[Model]:
quantization_format=CheckpointQuantizationFormat.int4,
description="Llama 3.2 3b INT4 quantized SpinQuant",
huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-SpinQuant_INT4_EO8",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
**arch_args_3b(),
"quantization_args": {
@@ -755,7 +710,6 @@ def llama3_2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_1b_instruct,
description="Llama 3.2 1b instruct model",
huggingface_repo="meta-llama/Llama-3.2-1B-Instruct",
- recommended_sampling_params=recommended_sampling_params(),
arch_args=arch_args_1b(),
pth_file_count=1,
),
@@ -763,7 +717,6 @@ def llama3_2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_3b_instruct,
description="Llama 3.2 3b instruct model",
huggingface_repo="meta-llama/Llama-3.2-3B-Instruct",
- recommended_sampling_params=recommended_sampling_params(),
arch_args=arch_args_3b(),
pth_file_count=1,
),
@@ -772,7 +725,6 @@ def llama3_2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_11b_vision_instruct,
description="Llama 3.2 11b vision instruct model",
huggingface_repo="meta-llama/Llama-3.2-11B-Vision-Instruct",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@@ -794,7 +746,6 @@ def llama3_2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_90b_vision_instruct,
description="Llama 3.2 90b vision instruct model",
huggingface_repo="meta-llama/Llama-3.2-90B-Vision-Instruct",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@@ -821,7 +772,6 @@ def llama3_3_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_3_70b_instruct,
description="Llama 3.3 70b instruct",
huggingface_repo="meta-llama/Llama-3.3-70B-Instruct",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@@ -846,7 +796,6 @@ def safety_models() -> List[Model]:
core_model_id=CoreModelId.llama_guard_3_11b_vision,
description="Llama Guard v3 11b vision system safety model",
huggingface_repo="meta-llama/Llama-Guard-3-11B-Vision",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@@ -870,7 +819,6 @@ def safety_models() -> List[Model]:
description="Llama Guard v3 1b 'int4' quantized system safety model",
huggingface_repo="meta-llama/Llama-Guard-3-1B-INT4",
quantization_format=CheckpointQuantizationFormat.int4,
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 2048,
"n_layers": 12,
@@ -888,7 +836,6 @@ def safety_models() -> List[Model]:
core_model_id=CoreModelId.llama_guard_3_1b,
description="Llama Guard v3 1b system safety model",
huggingface_repo="meta-llama/Llama-Guard-3-1B",
- recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 2048,
"n_layers": 16,
diff --git a/llama_stack/models/llama/sku_types.py b/llama_stack/models/llama/sku_types.py
new file mode 100644
index 000000000..88799b66d
--- /dev/null
+++ b/llama_stack/models/llama/sku_types.py
@@ -0,0 +1,229 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from enum import Enum
+from typing import Any, Dict, Optional
+
+from pydantic import BaseModel, ConfigDict, Field
+
+
+class CheckpointQuantizationFormat(Enum):
+ # default format
+ bf16 = "bf16"
+
+ # used for enabling fp8_rowwise inference, some weights are bf16
+ fp8_mixed = "fp8-mixed"
+
+ int8 = "int8"
+
+ int4 = "int4"
+
+
+class ModelFamily(Enum):
+ llama2 = "llama2"
+ llama3 = "llama3"
+ llama3_1 = "llama3_1"
+ llama3_2 = "llama3_2"
+ llama3_3 = "llama3_3"
+ llama4 = "llama4"
+ safety = "safety"
+
+
+class CoreModelId(Enum):
+ """Each of these models is a unique "SKU". These root models can be served in various garbs (especially by quantizing them)"""
+
+ # Llama 2 family
+ llama2_7b = "Llama-2-7b"
+ llama2_13b = "Llama-2-13b"
+ llama2_70b = "Llama-2-70b"
+ llama2_7b_chat = "Llama-2-7b-chat"
+ llama2_13b_chat = "Llama-2-13b-chat"
+ llama2_70b_chat = "Llama-2-70b-chat"
+
+ # Llama 3 family
+ llama3_8b = "Llama-3-8B"
+ llama3_70b = "Llama-3-70B"
+ llama3_8b_instruct = "Llama-3-8B-Instruct"
+ llama3_70b_instruct = "Llama-3-70B-Instruct"
+
+ # Llama 3.1 family
+ llama3_1_8b = "Llama3.1-8B"
+ llama3_1_70b = "Llama3.1-70B"
+ llama3_1_405b = "Llama3.1-405B"
+ llama3_1_8b_instruct = "Llama3.1-8B-Instruct"
+ llama3_1_70b_instruct = "Llama3.1-70B-Instruct"
+ llama3_1_405b_instruct = "Llama3.1-405B-Instruct"
+
+ # Llama 3.2 family
+ llama3_2_1b = "Llama3.2-1B"
+ llama3_2_3b = "Llama3.2-3B"
+ llama3_2_1b_instruct = "Llama3.2-1B-Instruct"
+ llama3_2_3b_instruct = "Llama3.2-3B-Instruct"
+ llama3_2_11b_vision = "Llama3.2-11B-Vision"
+ llama3_2_90b_vision = "Llama3.2-90B-Vision"
+ llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct"
+ llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct"
+
+ # Llama 3.3 family
+ llama3_3_70b_instruct = "Llama3.3-70B-Instruct"
+
+ # Llama 4 family
+ llama4_scout_17b_16e = "Llama-4-Scout-17B-16E"
+ llama4_scout_17b_16e_instruct = "Llama-4-Scout-17B-16E-Instruct"
+ llama4_maverick_17b_128e = "Llama-4-Maverick-17B-128E"
+ llama4_maverick_17b_128e_instruct = "Llama-4-Maverick-17B-128E-Instruct"
+
+ # Safety models
+ llama_guard_3_8b = "Llama-Guard-3-8B"
+ llama_guard_2_8b = "Llama-Guard-2-8B"
+ llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision"
+ llama_guard_3_1b = "Llama-Guard-3-1B"
+
+
+def is_multimodal(model_id) -> bool:
+ if model_id in [
+ CoreModelId.llama3_2_11b_vision,
+ CoreModelId.llama3_2_90b_vision,
+ CoreModelId.llama3_2_11b_vision_instruct,
+ CoreModelId.llama3_2_90b_vision_instruct,
+ ]:
+ return True
+ else:
+ return False
+
+
+def model_family(model_id) -> ModelFamily:
+ if model_id in [
+ CoreModelId.llama2_7b,
+ CoreModelId.llama2_13b,
+ CoreModelId.llama2_70b,
+ CoreModelId.llama2_7b_chat,
+ CoreModelId.llama2_13b_chat,
+ CoreModelId.llama2_70b_chat,
+ ]:
+ return ModelFamily.llama2
+ elif model_id in [
+ CoreModelId.llama3_8b,
+ CoreModelId.llama3_70b,
+ CoreModelId.llama3_8b_instruct,
+ CoreModelId.llama3_70b_instruct,
+ ]:
+ return ModelFamily.llama3
+ elif model_id in [
+ CoreModelId.llama3_1_8b,
+ CoreModelId.llama3_1_70b,
+ CoreModelId.llama3_1_405b,
+ CoreModelId.llama3_1_8b_instruct,
+ CoreModelId.llama3_1_70b_instruct,
+ CoreModelId.llama3_1_405b_instruct,
+ ]:
+ return ModelFamily.llama3_1
+ elif model_id in [
+ CoreModelId.llama3_2_1b,
+ CoreModelId.llama3_2_3b,
+ CoreModelId.llama3_2_1b_instruct,
+ CoreModelId.llama3_2_3b_instruct,
+ CoreModelId.llama3_2_11b_vision,
+ CoreModelId.llama3_2_90b_vision,
+ CoreModelId.llama3_2_11b_vision_instruct,
+ CoreModelId.llama3_2_90b_vision_instruct,
+ ]:
+ return ModelFamily.llama3_2
+ elif model_id in [
+ CoreModelId.llama3_3_70b_instruct,
+ ]:
+ return ModelFamily.llama3_3
+ elif model_id in [
+ CoreModelId.llama4_scout_17b_16e,
+ CoreModelId.llama4_scout_17b_16e_instruct,
+ CoreModelId.llama4_maverick_17b_128e,
+ CoreModelId.llama4_maverick_17b_128e_instruct,
+ ]:
+ return ModelFamily.llama4
+ elif model_id in [
+ CoreModelId.llama_guard_3_8b,
+ CoreModelId.llama_guard_2_8b,
+ CoreModelId.llama_guard_3_11b_vision,
+ CoreModelId.llama_guard_3_1b,
+ ]:
+ return ModelFamily.safety
+ else:
+ raise ValueError(f"Unknown model family for {model_id}")
+
+
+class Model(BaseModel):
+ core_model_id: CoreModelId
+ description: str
+ huggingface_repo: Optional[str] = None
+ arch_args: Dict[str, Any]
+ variant: str = ""
+
+ quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
+ pth_file_count: int
+ metadata: Dict[str, Any] = Field(default_factory=dict)
+
+ # silence pydantic until we remove the `model_` fields
+ model_config = ConfigDict(protected_namespaces=())
+
+ @property
+ def model_family(self) -> ModelFamily:
+ return model_family(self.core_model_id)
+
+ # The SKU is uniquely identified by (model_id, variant) combo
+ def descriptor(self, shorten_default_variant: bool = True) -> str:
+ if not self.variant:
+ return self.core_model_id.value
+ return f"{self.core_model_id.value}:{self.variant}"
+
+ @property
+ def is_instruct_model(self) -> bool:
+ return "instruct" in self.core_model_id.value
+
+ # Featured models are shown in the non-exhaustive model list
+ @property
+ def is_featured(self) -> bool:
+ return self.model_family in [
+ ModelFamily.llama3_1,
+ ModelFamily.llama3_2,
+ ModelFamily.llama3_3,
+ ModelFamily.llama4,
+ ModelFamily.safety,
+ ]
+
+ @property
+ def max_seq_length(self) -> int:
+ if self.model_family == ModelFamily.llama2:
+ return 4096
+ elif self.core_model_id == CoreModelId.llama_guard_2_8b:
+ return 4096
+ elif self.model_family == ModelFamily.llama3:
+ return 8192
+ elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]:
+ return 131072
+ elif self.model_family == ModelFamily.llama3_2:
+ if self.quantization_format == CheckpointQuantizationFormat.int4:
+ return 8192
+ return 131072
+ elif self.model_family == ModelFamily.llama4:
+ if self.core_model_id in {
+ CoreModelId.llama4_scout_17b_16e,
+ CoreModelId.llama4_maverick_17b_128e,
+ }:
+ return 262144
+ if self.core_model_id == CoreModelId.llama4_scout_17b_16e_instruct:
+ return 10485760
+ if self.core_model_id == CoreModelId.llama4_maverick_17b_128e_instruct:
+ return 1048576
+
+ raise AssertionError(f"Unexpected core model id: {self.core_model_id}")
+ elif self.core_model_id in [
+ CoreModelId.llama_guard_3_8b,
+ CoreModelId.llama_guard_3_11b_vision,
+ CoreModelId.llama_guard_3_1b,
+ ]:
+ return 131072
+ else:
+ raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index e1af4ab71..6840da89f 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -52,6 +52,7 @@ from llama_stack.apis.inference import (
StopReason,
SystemMessage,
ToolDefinition,
+ ToolParamDefinition,
ToolResponse,
ToolResponseMessage,
UserMessage,
@@ -63,7 +64,6 @@ from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import (
BuiltinTool,
ToolCall,
- ToolParamDefinition,
)
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.telemetry import tracing
diff --git a/llama_stack/providers/inline/inference/meta_reference/generators.py b/llama_stack/providers/inline/inference/meta_reference/generators.py
index 809351164..10e597665 100644
--- a/llama_stack/providers/inline/inference/meta_reference/generators.py
+++ b/llama_stack/providers/inline/inference/meta_reference/generators.py
@@ -12,20 +12,19 @@ from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerToken
from llama_stack.apis.inference import (
Fp8QuantizationConfig,
+ GreedySamplingStrategy,
Int4QuantizationConfig,
JsonSchemaResponseFormat,
ResponseFormat,
-)
-from llama_stack.models.llama.datatypes import (
- GreedySamplingStrategy,
- Model,
SamplingParams,
TopPSamplingStrategy,
)
+from llama_stack.models.llama.datatypes import QuantizationMode
from llama_stack.models.llama.llama3.generation import Llama3
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.generation import Llama4
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
+from llama_stack.models.llama.sku_types import Model
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
@@ -136,9 +135,9 @@ class Llama4Generator:
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
if isinstance(config.quantization, Fp8QuantizationConfig):
- quantization_mode = "fp8_mixed"
+ quantization_mode = QuantizationMode.fp8_mixed
elif isinstance(config.quantization, Int4QuantizationConfig):
- quantization_mode = "int4_mixed"
+ quantization_mode = QuantizationMode.int4_mixed
else:
raise ValueError(f"Unsupported quantization mode {config.quantization}")
else:
@@ -225,9 +224,9 @@ class Llama3Generator:
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
if isinstance(config.quantization, Fp8QuantizationConfig):
- quantization_mode = "fp8_mixed"
+ quantization_mode = QuantizationMode.fp8_mixed
elif isinstance(config.quantization, Int4QuantizationConfig):
- quantization_mode = "int4_mixed"
+ quantization_mode = QuantizationMode.int4_mixed
else:
raise ValueError(f"Unsupported quantization mode {config.quantization}")
else:
@@ -240,6 +239,9 @@ class Llama3Generator:
world_size=llama_model.pth_file_count,
quantization_mode=quantization_mode,
)
+ self.tokenizer = self.inner_generator.tokenizer
+ self.args = self.inner_generator.args
+ self.formatter = self.inner_generator.formatter
def completion(
self,
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index da217728b..ca2f51ac7 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -31,23 +31,21 @@ from llama_stack.apis.inference import (
LogProbConfig,
Message,
ResponseFormat,
+ SamplingParams,
+ StopReason,
TokenLogProbs,
ToolChoice,
ToolConfig,
-)
-from llama_stack.apis.models import Model, ModelType
-from llama_stack.models.llama.datatypes import (
- ModelFamily,
- SamplingParams,
- StopReason,
ToolDefinition,
ToolPromptFormat,
)
+from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
+from llama_stack.models.llama.sku_types import ModelFamily
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
diff --git a/llama_stack/providers/inline/inference/vllm/openai_utils.py b/llama_stack/providers/inline/inference/vllm/openai_utils.py
index 90b5398f9..d34f5ad5f 100644
--- a/llama_stack/providers/inline/inference/vllm/openai_utils.py
+++ b/llama_stack/providers/inline/inference/vllm/openai_utils.py
@@ -14,9 +14,10 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat,
Message,
ToolChoice,
+ ToolDefinition,
UserMessage,
)
-from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition
+from llama_stack.models.llama.datatypes import BuiltinTool
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index 256e0f821..ea2643b7a 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -46,6 +46,8 @@ from llama_stack.apis.inference import (
TokenLogProbs,
ToolChoice,
ToolConfig,
+ TopKSamplingStrategy,
+ TopPSamplingStrategy,
)
from llama_stack.apis.models import Model
from llama_stack.log import get_logger
@@ -55,8 +57,6 @@ from llama_stack.models.llama.datatypes import (
ToolCall,
ToolDefinition,
ToolPromptFormat,
- TopKSamplingStrategy,
- TopPSamplingStrategy,
)
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py
index f8a1c0436..a040ca1b0 100644
--- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py
+++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py
@@ -22,8 +22,8 @@ from torchtune.models.llama3_2 import lora_llama3_2_3b
from torchtune.modules.transforms import Transform
from llama_stack.apis.post_training import DatasetFormat
-from llama_stack.models.llama.datatypes import Model
from llama_stack.models.llama.sku_list import resolve_model
+from llama_stack.models.llama.sku_types import Model
BuildLoraModelCallable = Callable[..., torch.nn.Module]
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
index e514e3781..d95c40976 100644
--- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
+++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
@@ -23,7 +23,8 @@ from llama_stack.apis.safety import (
)
from llama_stack.apis.shields import Shield
from llama_stack.distribution.datatypes import Api
-from llama_stack.models.llama.datatypes import CoreModelId, Role
+from llama_stack.models.llama.datatypes import Role
+from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
diff --git a/llama_stack/providers/remote/inference/bedrock/models.py b/llama_stack/providers/remote/inference/bedrock/models.py
index c5079799f..ec8120049 100644
--- a/llama_stack/providers/remote/inference/bedrock/models.py
+++ b/llama_stack/providers/remote/inference/bedrock/models.py
@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
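This is the first of several mechanical moves in the patch: CoreModelId now lives in llama_stack.models.llama.sku_types rather than llama_stack.models.llama.datatypes. Out-of-tree providers that build model entries the same way need the identical one-line import change. A minimal sketch, assuming build_hf_repo_model_entry takes a provider alias and a model descriptor string (the call shape is not shown in this hunk, and the alias below is made up):

    from llama_stack.models.llama.sku_types import CoreModelId
    from llama_stack.providers.utils.inference.model_registry import (
        build_hf_repo_model_entry,
    )

    MODEL_ENTRIES = [
        # provider-side alias (first argument) is a made-up example
        build_hf_repo_model_entry(
            "my-provider/llama-3.1-8b-instruct",
            CoreModelId.llama3_1_8b_instruct.value,
        ),
    ]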
diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py
index a53e6e5a5..43d986b86 100644
--- a/llama_stack/providers/remote/inference/cerebras/cerebras.py
+++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py
@@ -28,8 +28,8 @@ from llama_stack.apis.inference import (
ToolConfig,
ToolDefinition,
ToolPromptFormat,
+ TopKSamplingStrategy,
)
-from llama_stack.models.llama.datatypes import TopKSamplingStrategy
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
diff --git a/llama_stack/providers/remote/inference/cerebras/models.py b/llama_stack/providers/remote/inference/cerebras/models.py
index 37419bf4c..38301b32a 100644
--- a/llama_stack/providers/remote/inference/cerebras/models.py
+++ b/llama_stack/providers/remote/inference/cerebras/models.py
@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py
index 53a9c04f4..0eaf0135b 100644
--- a/llama_stack/providers/remote/inference/databricks/databricks.py
+++ b/llama_stack/providers/remote/inference/databricks/databricks.py
@@ -28,7 +28,7 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
diff --git a/llama_stack/providers/remote/inference/fireworks/models.py b/llama_stack/providers/remote/inference/fireworks/models.py
index a0dc11768..4975d061f 100644
--- a/llama_stack/providers/remote/inference/fireworks/models.py
+++ b/llama_stack/providers/remote/inference/fireworks/models.py
@@ -5,7 +5,7 @@
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,
diff --git a/llama_stack/providers/remote/inference/nvidia/models.py b/llama_stack/providers/remote/inference/nvidia/models.py
index 879855003..964125148 100644
--- a/llama_stack/providers/remote/inference/nvidia/models.py
+++ b/llama_stack/providers/remote/inference/nvidia/models.py
@@ -5,7 +5,7 @@
# the root directory of this source tree.
from llama_stack.apis.models import ModelType
-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,
diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index 5caf19fda..e1f5d7a6a 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -29,15 +29,13 @@ from llama_stack.apis.inference import (
LogProbConfig,
Message,
ResponseFormat,
+ SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
-)
-from llama_stack.models.llama.datatypes import (
- SamplingParams,
ToolDefinition,
- ToolPromptFormat,
)
+from llama_stack.models.llama.datatypes import ToolPromptFormat
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
diff --git a/llama_stack/providers/remote/inference/nvidia/openai_utils.py b/llama_stack/providers/remote/inference/nvidia/openai_utils.py
index 0582cb816..3f2769b26 100644
--- a/llama_stack/providers/remote/inference/nvidia/openai_utils.py
+++ b/llama_stack/providers/remote/inference/nvidia/openai_utils.py
@@ -19,11 +19,9 @@ from llama_stack.apis.inference import (
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
+ GreedySamplingStrategy,
JsonSchemaResponseFormat,
TokenLogProbs,
-)
-from llama_stack.models.llama.datatypes import (
- GreedySamplingStrategy,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
diff --git a/llama_stack/providers/remote/inference/ollama/models.py b/llama_stack/providers/remote/inference/ollama/models.py
index be556762c..42e364105 100644
--- a/llama_stack/providers/remote/inference/ollama/models.py
+++ b/llama_stack/providers/remote/inference/ollama/models.py
@@ -5,7 +5,7 @@
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,
diff --git a/llama_stack/providers/remote/inference/sambanova/models.py b/llama_stack/providers/remote/inference/sambanova/models.py
index 2231be22d..9589ea268 100644
--- a/llama_stack/providers/remote/inference/sambanova/models.py
+++ b/llama_stack/providers/remote/inference/sambanova/models.py
@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py
index 635a42d38..a3badd468 100644
--- a/llama_stack/providers/remote/inference/sambanova/sambanova.py
+++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py
@@ -21,6 +21,7 @@ from llama_stack.apis.inference import (
CompletionMessage,
EmbeddingsResponse,
EmbeddingTaskType,
+ GreedySamplingStrategy,
Inference,
LogProbConfig,
Message,
@@ -35,12 +36,9 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
ToolResponseMessage,
- UserMessage,
-)
-from llama_stack.models.llama.datatypes import (
- GreedySamplingStrategy,
TopKSamplingStrategy,
TopPSamplingStrategy,
+ UserMessage,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
diff --git a/llama_stack/providers/remote/inference/together/models.py b/llama_stack/providers/remote/inference/together/models.py
index 63d3d94b5..f014c03f0 100644
--- a/llama_stack/providers/remote/inference/together/models.py
+++ b/llama_stack/providers/remote/inference/together/models.py
@@ -5,7 +5,7 @@
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 6a828322f..e2f3a7b33 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -96,7 +96,6 @@ def _convert_to_vllm_tool_calls_in_response(
call_id=call.id,
tool_name=call.function.name,
arguments=json.loads(call.function.arguments),
- arguments_json=call.function.arguments,
)
for call in tool_calls
]
@@ -176,7 +175,6 @@ async def _process_vllm_chat_completion_stream_response(
call_id=tool_call_buf.call_id,
tool_name=tool_call_buf.tool_name,
arguments=args,
- arguments_json=args_str,
),
parse_status=ToolCallParseStatus.succeeded,
),
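With arguments_json removed from ToolCall (see the schema change at the top of the patch), the remote vLLM adapter now stores only the parsed dict. A minimal sketch of the post-change conversion for an OpenAI-style tool call, using just the fields visible in this hunk; the helper name is illustrative:

    import json

    from llama_stack.models.llama.datatypes import ToolCall

    def convert_openai_tool_call(call) -> ToolCall:
        return ToolCall(
            call_id=call.id,
            tool_name=call.function.name,
            # arguments is always a dict now; the raw JSON string is no longer kept alongside it
            arguments=json.loads(call.function.arguments),
        )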
diff --git a/llama_stack/providers/remote/post_training/nvidia/models.py b/llama_stack/providers/remote/post_training/nvidia/models.py
index 04a9af38c..7c696ac20 100644
--- a/llama_stack/providers/remote/post_training/nvidia/models.py
+++ b/llama_stack/providers/remote/post_training/nvidia/models.py
@@ -6,7 +6,7 @@
from typing import List
-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,
diff --git a/llama_stack/providers/tests/report.py b/llama_stack/providers/tests/report.py
index c9a7f69a8..bc29534be 100644
--- a/llama_stack/providers/tests/report.py
+++ b/llama_stack/providers/tests/report.py
@@ -12,8 +12,8 @@ import pytest
from pytest import ExitCode
from pytest_html.basereport import _process_outcome
-from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.models.llama.sku_list import all_registered_models
+from llama_stack.models.llama.sku_types import CoreModelId
INFERENCE_APIS = ["chat_completion"]
FUNCTIONALITIES = ["streaming", "structured_output", "tool_calling"]
diff --git a/llama_stack/providers/utils/inference/__init__.py b/llama_stack/providers/utils/inference/__init__.py
index a885da235..e36be9404 100644
--- a/llama_stack/providers/utils/inference/__init__.py
+++ b/llama_stack/providers/utils/inference/__init__.py
@@ -6,8 +6,8 @@
from typing import List
-from llama_stack.models.llama.datatypes import * # noqa: F403
from llama_stack.models.llama.sku_list import all_registered_models
+from llama_stack.models.llama.sku_types import * # noqa: F403
def is_supported_safety_model(model: Model) -> bool:
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index e475d77b6..44a89dfb0 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -73,21 +73,21 @@ from llama_stack.apis.inference import (
CompletionMessage,
CompletionResponse,
CompletionResponseStreamChunk,
+ GreedySamplingStrategy,
Message,
+ SamplingParams,
SystemMessage,
TokenLogProbs,
ToolResponseMessage,
+ TopKSamplingStrategy,
+ TopPSamplingStrategy,
UserMessage,
)
from llama_stack.models.llama.datatypes import (
BuiltinTool,
- GreedySamplingStrategy,
- SamplingParams,
StopReason,
ToolCall,
ToolDefinition,
- TopKSamplingStrategy,
- TopPSamplingStrategy,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_content_to_url,
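After this hunk, the sampling types come from llama_stack.apis.inference rather than llama_stack.models.llama.datatypes; only the tool-call and prompt-format primitives stay on the datatypes module. A minimal sketch of building sampling parameters under the new import path; the field names (strategy, temperature, top_p, top_k) are assumptions based on the class names and are not confirmed by this diff:

    from llama_stack.apis.inference import (
        GreedySamplingStrategy,
        SamplingParams,
        TopKSamplingStrategy,
        TopPSamplingStrategy,
    )

    # Field names here are assumed, not taken from the patch.
    greedy = SamplingParams(strategy=GreedySamplingStrategy())
    nucleus = SamplingParams(strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.9))
    top_k = SamplingParams(strategy=TopKSamplingStrategy(top_k=40))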
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 0231312cc..4f9c4927a 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -34,7 +34,6 @@ from llama_stack.apis.inference import (
)
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import (
- ModelFamily,
RawContent,
RawContentItem,
RawMediaItem,
@@ -43,7 +42,6 @@ from llama_stack.models.llama.datatypes import (
Role,
StopReason,
ToolPromptFormat,
- is_multimodal,
)
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.prompt_templates import (
@@ -55,6 +53,7 @@ from llama_stack.models.llama.llama3.prompt_templates import (
)
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
+from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
from llama_stack.providers.utils.inference import supported_inference_models
log = get_logger(name=__name__, category="inference")
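prompt_adapter keeps its behavior but now reaches ModelFamily and is_multimodal through sku_types. For downstream code doing the same vision check, a minimal sketch, assuming is_multimodal still takes a CoreModelId and that resolve_model returns a descriptor with a core_model_id attribute (the model descriptor string is illustrative):

    from llama_stack.models.llama.sku_list import resolve_model
    from llama_stack.models.llama.sku_types import is_multimodal

    model = resolve_model("Llama3.2-11B-Vision-Instruct")  # illustrative descriptor string
    if model is not None and is_multimodal(model.core_model_id):
        # multimodal checkpoint: raw media items may appear in the prompt
        ...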
diff --git a/pyproject.toml b/pyproject.toml
index 8d8ff4338..8ae7ddbb6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -224,9 +224,9 @@ exclude = [
"^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
"^llama_stack/providers/inline/inference/meta_reference/config\\.py$",
"^llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
- "^llama_stack/providers/inline/inference/meta_reference/llama3/generation\\.py$",
- "^llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/model\\.py$",
- "^llama_stack/providers/inline/inference/meta_reference/llama4/",
+ "^llama_stack/models/llama/llama3/generation\\.py$",
+ "^llama_stack/models/llama/llama3/multimodal/model\\.py$",
+ "^llama_stack/models/llama/llama4/",
"^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$",
"^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$",
"^llama_stack/providers/inline/inference/meta_reference/quantization/loader\\.py$",
diff --git a/tests/integration/report.py b/tests/integration/report.py
index c07338ce6..a50f51d3f 100644
--- a/tests/integration/report.py
+++ b/tests/integration/report.py
@@ -11,7 +11,6 @@ import pytest
from pytest import CollectReport
from termcolor import cprint
-from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.models.llama.sku_list import (
all_registered_models,
llama3_1_instruct_models,
@@ -20,6 +19,7 @@ from llama_stack.models.llama.sku_list import (
llama3_instruct_models,
safety_models,
)
+from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.datatypes import Api
from .metadata import API_MAPS