several fixes

This commit is contained in:
Ashwin Bharambe 2025-04-07 10:31:20 -07:00
parent e2e2820c9a
commit 53a8086e37
60 changed files with 1006 additions and 1078 deletions

View file

@ -4163,80 +4163,70 @@
]
},
"arguments": {
"oneOf": [
{
"type": "string"
},
{
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "string"
},
{
"type": "integer"
},
{
"type": "number"
},
{
"type": "boolean"
},
{
"type": "null"
},
{
"type": "array",
"items": {
"oneOf": [
{
"type": "string"
},
{
"type": "integer"
},
{
"type": "number"
},
{
"type": "boolean"
},
{
"type": "null"
}
]
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "string"
},
{
"type": "integer"
},
{
"type": "number"
},
{
"type": "boolean"
},
{
"type": "null"
},
{
"type": "array",
"items": {
"oneOf": [
{
"type": "string"
},
{
"type": "integer"
},
{
"type": "number"
},
{
"type": "boolean"
},
{
"type": "null"
}
},
{
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "string"
},
{
"type": "integer"
},
{
"type": "number"
},
{
"type": "boolean"
},
{
"type": "null"
}
]
]
}
},
{
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "string"
},
{
"type": "integer"
},
{
"type": "number"
},
{
"type": "boolean"
},
{
"type": "null"
}
}
]
]
}
}
}
]
},
"arguments_json": {
"type": "string"
]
}
}
},
"additionalProperties": false,

View file

@ -2890,34 +2890,30 @@ components:
title: BuiltinTool
- type: string
arguments:
oneOf:
- type: string
- type: object
additionalProperties:
oneOf:
- type: string
- type: integer
- type: number
- type: boolean
- type: 'null'
- type: array
items:
oneOf:
- type: string
- type: integer
- type: number
- type: boolean
- type: 'null'
- type: object
additionalProperties:
oneOf:
- type: string
- type: integer
- type: number
- type: boolean
- type: 'null'
arguments_json:
type: string
type: object
additionalProperties:
oneOf:
- type: string
- type: integer
- type: number
- type: boolean
- type: 'null'
- type: array
items:
oneOf:
- type: string
- type: integer
- type: number
- type: boolean
- type: 'null'
- type: object
additionalProperties:
oneOf:
- type: string
- type: integer
- type: number
- type: boolean
- type: 'null'
additionalProperties: false
required:
- call_id

View file

@ -25,15 +25,64 @@ from llama_stack.apis.models import Model
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
from llama_stack.models.llama.datatypes import (
BuiltinTool,
SamplingParams,
StopReason,
ToolCall,
ToolDefinition,
ToolParamDefinition,
ToolPromptFormat,
)
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
register_schema(ToolCall)
register_schema(ToolParamDefinition)
register_schema(ToolDefinition)
@json_schema_type
class GreedySamplingStrategy(BaseModel):
type: Literal["greedy"] = "greedy"
@json_schema_type
class TopPSamplingStrategy(BaseModel):
type: Literal["top_p"] = "top_p"
temperature: Optional[float] = Field(..., gt=0.0)
top_p: Optional[float] = 0.95
@json_schema_type
class TopKSamplingStrategy(BaseModel):
type: Literal["top_k"] = "top_k"
top_k: int = Field(..., ge=1)
SamplingStrategy = Annotated[
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
Field(discriminator="type"),
]
register_schema(SamplingStrategy, name="SamplingStrategy")
@json_schema_type
class SamplingParams(BaseModel):
"""Sampling parameters.
:param strategy: The sampling strategy.
:param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
your prompt plus max_tokens cannot exceed the model's context length.
:param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
:param stop: Up to 4 sequences where the API will stop generating further tokens.
The returned text will not contain the stop sequence.
"""
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
max_tokens: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0
stop: Optional[List[str]] = None
class LogProbConfig(BaseModel):
"""

View file

@ -29,8 +29,8 @@ from rich.progress import (
from termcolor import cprint
from llama_stack.cli.subcommand import Subcommand
from llama_stack.models.llama.datatypes import Model
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
from llama_stack.models.llama.sku_types import Model
class Download(Subcommand):

View file

@ -63,17 +63,6 @@ class ModelDescribe(Subcommand):
("Model params.json", json.dumps(model.arch_args, indent=4)),
]
if model.recommended_sampling_params is not None:
sampling_params = model.recommended_sampling_params.model_dump()
for k in ("max_tokens", "repetition_penalty"):
del sampling_params[k]
rows.append(
(
"Recommended sampling params",
json.dumps(sampling_params, indent=4),
)
)
print_table(
rows,
headers,

View file

@ -11,7 +11,7 @@ from pathlib import Path
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
from llama_stack.models.llama.sku_types import CoreModelId, ModelFamily, is_multimodal, model_family
ROOT_DIR = Path(__file__).parent.parent.parent

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from typing import Any, Dict
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat, SamplingParams
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
from llama_stack.models.llama.sku_types import CheckpointQuantizationFormat
class PromptGuardModel(BaseModel):
@ -23,7 +23,6 @@ class PromptGuardModel(BaseModel):
is_instruct_model: bool = False
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
arch_args: Dict[str, Any] = Field(default_factory=dict)
recommended_sampling_params: Optional[SamplingParams] = None
def descriptor(self) -> str:
return self.model_id

View file

@ -0,0 +1,164 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import concurrent.futures
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank, get_model_parallel_world_size
def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[int]:
"""Map a new MP rank to a list of old MP ranks given a change in MP size."""
if new_mp_size % old_mp_size == 0:
# Read old MP shard and split it into smaller ones
return [new_mp_rank * old_mp_size // new_mp_size]
elif old_mp_size % new_mp_size == 0:
# Merge old MP shards into a single one
mp_factor = old_mp_size // new_mp_size
return list(range(new_mp_rank * mp_factor, (new_mp_rank + 1) * mp_factor))
else:
raise ValueError(
f"Either old MP size or new MP size should be a multiple of the other: "
f"{old_mp_size} % {new_mp_size} != 0 and {new_mp_size} % {old_mp_size} != 0"
)
def maybe_reshard_state_dict(
ckpt_paths: List[Path],
n_kv_heads: int,
moe_num_experts: Optional[int] = None,
map_location: Union[str, torch.device] = "cpu",
mmap: bool = True,
) -> Dict[str, torch.Tensor]:
if str(map_location) == "cpu":
torch.set_default_tensor_type(torch.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
ckpt_paths = np.array(sorted(ckpt_paths))
new_mp_size, new_mp_rank = get_model_parallel_world_size(), get_model_parallel_rank()
old_mp_size = len(ckpt_paths)
old_mp_ranks = map_mp_rank(old_mp_size, new_mp_size, new_mp_rank)
print(f"Loading checkpoint shards:\n{str(ckpt_paths[old_mp_ranks])}") # type: ignore
paths = ckpt_paths[old_mp_ranks] # type: ignore
state_dicts = [torch.load(str(p), map_location=map_location, mmap=mmap) for p in paths]
if new_mp_size == old_mp_size:
return state_dicts[0] # type: ignore
if moe_num_experts is not None:
state_dicts = [convert_moe_weights(d, moe_num_experts) for d in state_dicts]
print(f"Resharding {len(state_dicts)} state dicts from MP size {old_mp_size} to MP size {new_mp_size}")
return reshard_mp(
state_dicts,
size=max(new_mp_size // old_mp_size, 1),
rank=new_mp_rank % max(new_mp_size // old_mp_size, 1),
repeat_qk_qv=max(new_mp_size // n_kv_heads, 1),
)
_WEIGHT_ROW_KEY = {
"feed_forward.w2",
"feed_forward.mlp.fc2",
"attention.wo",
"feed_forward.mlp.fc2_weight",
"feed_forward.w_out_shared_DF.weight",
"attn.wo.weight",
"mlp.c_proj.weight",
}
_MOE_WEIGHT_ROW_KEY = {"feed_forward.experts.(moe_w_in_eD_F|moe_w_swiglu_eD_F)"}
_WEIGHT_COLUMN_KEY = {
"output",
"feed_forward.(w1|w3)",
"feed_forward.mlp.(fc1|fc3)",
"feed_forward.mlp.fc1_weight",
"attention.(wk|wq|wv|wqkv).weight",
"feed_forward.(w_in_shared_FD|w_swiglu_FD)",
"attn.(wk|wq|wv).weight",
"attn.(wk|wq|wv).bias",
"mlp.c_fc.weight",
"mlp.c_fc.bias",
"conv1._linear.weight",
"tok_embeddings.weight",
"vision_projection.weight",
}
_MOE_WEIGHT_COLUMN_KEY = {"feed_forward.experts.moe_w_out_eF_D"}
def reshard_mp(
state_dicts: List[Dict[str, torch.Tensor]],
size: int,
rank: int,
repeat_qk_qv: int = 1,
) -> Dict[str, torch.Tensor]:
"""
Reshard a list of state dicts into a single state dict given a change in MP size.
If the list has more than one state dict, we concatenate the values of the same
key across all state dicts. Otherwise, we just slice it for the current MP rank.
"""
def concat_or_chunk(tensors: List[torch.Tensor], dim: int) -> torch.Tensor:
if len(tensors) > 1:
return torch.cat(tensors, dim=dim)
return tensors[0].chunk(size, dim=dim)[rank].clone()
def process_key(key: str) -> torch.Tensor:
if row_regex.search(key):
return concat_or_chunk([s[key] for s in state_dicts], dim=-1)
elif column_regex.search(key):
if "w13" in key or "fc1_weight" in key:
dims = state_dicts[0][key].size()
values = [s[key].view(2, dims[0] // 2, *dims[1:]) for s in state_dicts]
return concat_or_chunk(values, dim=1).flatten(0, 1)
elif "qkv" in key:
q_dim = state_dicts[0][key.replace("qkv", "o")].size(1)
kv_dim = (state_dicts[0][key].size(0) - q_dim) // 2
values = [s[key].split((q_dim, kv_dim, kv_dim)) for s in state_dicts]
return torch.cat([concat_or_chunk(x, dim=0) for x in zip(*values, strict=False)]) # type: ignore
elif "wk.weight" in key or "wv.weight" in key:
# Support MP > #kv_head
return concat_or_chunk([s[key].repeat(repeat_qk_qv, 1) for s in state_dicts], dim=0)
elif key == "output.bias" or key == "fc.weight":
return concat_or_chunk([s[key] for s in state_dicts], dim=0)
elif "w_" in key:
return concat_or_chunk([s[key] for s in state_dicts], dim=-2)
else:
return concat_or_chunk([s[key] for s in state_dicts], dim=0)
else:
return state_dicts[0][key].clone()
row_keys = _WEIGHT_ROW_KEY | _MOE_WEIGHT_ROW_KEY
column_keys = _WEIGHT_COLUMN_KEY | _MOE_WEIGHT_COLUMN_KEY
column_regex = re.compile("|".join(column_keys))
row_regex = re.compile("|".join(row_keys))
output: Dict[str, torch.Tensor] = {}
with concurrent.futures.ThreadPoolExecutor() as executor:
# Note: only processes keys in the first state dict.
# Assumes keys are the same across all state dicts.
mappings = {executor.submit(process_key, key): key for key in state_dicts[0]}
for future in concurrent.futures.as_completed(mappings):
output[mappings[future]] = future.result()
return output
def convert_moe_weights(state_dict: Dict[str, Any], num_experts: int) -> Dict[str, Any]:
routed_keys = _MOE_WEIGHT_ROW_KEY | _MOE_WEIGHT_COLUMN_KEY
routed_regex = re.compile("|".join(routed_keys))
keys = list(state_dict.keys())
for key in keys:
if routed_regex.search(key):
state_dict[key] = state_dict.pop(key).unflatten(0, (num_experts, -1)).squeeze(dim=0)
return state_dict

View file

@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import base64
from enum import Enum
from io import BytesIO
@ -19,8 +12,6 @@ from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
from typing_extensions import Annotated
from llama_stack.schema_utils import json_schema_type, register_schema
# The goal is that these set of types are relevant for all Llama models.
# That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to
# the llama3 series of models.
@ -47,14 +38,7 @@ RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
class ToolCall(BaseModel):
call_id: str
tool_name: Union[BuiltinTool, str]
# Plan is to deprecate the Dict in favor of a JSON string
# that is parsed on the client side instead of trying to manage
# the recursive type here.
# Making this a union so that client side can start prepping for this change.
# Eventually, we will remove both the Dict and arguments_json field,
# and arguments will just be a str
arguments: Union[str, Dict[str, RecursiveType]]
arguments_json: Optional[str] = None
arguments: Dict[str, RecursiveType]
@field_validator("tool_name", mode="before")
@classmethod
@ -98,6 +82,29 @@ class StopReason(Enum):
out_of_tokens = "out_of_tokens"
class ToolParamDefinition(BaseModel):
param_type: str
description: Optional[str] = None
required: Optional[bool] = True
default: Optional[Any] = None
class ToolDefinition(BaseModel):
tool_name: Union[BuiltinTool, str]
description: Optional[str] = None
parameters: Optional[Dict[str, ToolParamDefinition]] = None
@field_validator("tool_name", mode="before")
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return BuiltinTool(v)
except ValueError:
return v
return v
class RawMediaItem(BaseModel):
type: Literal["image"] = "image"
data: bytes | BytesIO
@ -140,292 +147,25 @@ class RawMessage(BaseModel):
tool_calls: List[ToolCall] = Field(default_factory=list)
register_schema(ToolCall)
class GenerationResult(BaseModel):
token: int
text: str
logprobs: Optional[List[float]] = None
source: Literal["input"] | Literal["output"]
# index within the batch
batch_idx: int
# whether generation for this item is already finished. note that tokens can
# get returned even afterwards since other items in the batch can still be generating tokens
finished: bool
# because a batch is parallel processed, useful decoding for one item can correspond to processing
# pad tokens or tokens beyond EOS for other items. we could have decided to return None for this case
# but it's more convenient to return a list of GenerationResult and filter out the ignored tokens
ignore_token: bool
@json_schema_type
class ToolParamDefinition(BaseModel):
param_type: str
description: Optional[str] = None
required: Optional[bool] = True
default: Optional[Any] = None
@json_schema_type
class ToolDefinition(BaseModel):
tool_name: Union[BuiltinTool, str]
description: Optional[str] = None
parameters: Optional[Dict[str, ToolParamDefinition]] = None
@field_validator("tool_name", mode="before")
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return BuiltinTool(v)
except ValueError:
return v
return v
@json_schema_type
class GreedySamplingStrategy(BaseModel):
type: Literal["greedy"] = "greedy"
@json_schema_type
class TopPSamplingStrategy(BaseModel):
type: Literal["top_p"] = "top_p"
temperature: Optional[float] = Field(..., gt=0.0)
top_p: Optional[float] = 0.95
@json_schema_type
class TopKSamplingStrategy(BaseModel):
type: Literal["top_k"] = "top_k"
top_k: int = Field(..., ge=1)
SamplingStrategy = Annotated[
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
Field(discriminator="type"),
]
register_schema(SamplingStrategy, name="SamplingStrategy")
@json_schema_type
class SamplingParams(BaseModel):
"""Sampling parameters.
:param strategy: The sampling strategy.
:param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
your prompt plus max_tokens cannot exceed the model's context length.
:param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
:param stop: Up to 4 sequences where the API will stop generating further tokens.
The returned text will not contain the stop sequence.
"""
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
max_tokens: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0
stop: Optional[List[str]] = None
class CheckpointQuantizationFormat(Enum):
# default format
bf16 = "bf16"
# used for enabling fp8_rowwise inference, some weights are bf16
fp8_mixed = "fp8-mixed"
int8 = "int8"
int4 = "int4"
class ModelFamily(Enum):
llama2 = "llama2"
llama3 = "llama3"
llama3_1 = "llama3_1"
llama3_2 = "llama3_2"
llama3_3 = "llama3_3"
llama4 = "llama4"
safety = "safety"
class CoreModelId(Enum):
"""Each of these models is a unique "SKU". These root models can be served in various garbs (especially by quantizing them)"""
# Llama 2 family
llama2_7b = "Llama-2-7b"
llama2_13b = "Llama-2-13b"
llama2_70b = "Llama-2-70b"
llama2_7b_chat = "Llama-2-7b-chat"
llama2_13b_chat = "Llama-2-13b-chat"
llama2_70b_chat = "Llama-2-70b-chat"
# Llama 3 family
llama3_8b = "Llama-3-8B"
llama3_70b = "Llama-3-70B"
llama3_8b_instruct = "Llama-3-8B-Instruct"
llama3_70b_instruct = "Llama-3-70B-Instruct"
# Llama 3.1 family
llama3_1_8b = "Llama3.1-8B"
llama3_1_70b = "Llama3.1-70B"
llama3_1_405b = "Llama3.1-405B"
llama3_1_8b_instruct = "Llama3.1-8B-Instruct"
llama3_1_70b_instruct = "Llama3.1-70B-Instruct"
llama3_1_405b_instruct = "Llama3.1-405B-Instruct"
# Llama 3.2 family
llama3_2_1b = "Llama3.2-1B"
llama3_2_3b = "Llama3.2-3B"
llama3_2_1b_instruct = "Llama3.2-1B-Instruct"
llama3_2_3b_instruct = "Llama3.2-3B-Instruct"
llama3_2_11b_vision = "Llama3.2-11B-Vision"
llama3_2_90b_vision = "Llama3.2-90B-Vision"
llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct"
llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct"
# Llama 3.3 family
llama3_3_70b_instruct = "Llama3.3-70B-Instruct"
# Llama 4 family
llama4_scout_17b_16e = "Llama-4-Scout-17B-16E"
llama4_scout_17b_16e_instruct = "Llama-4-Scout-17B-16E-Instruct"
llama4_maverick_17b_128e = "Llama-4-Maverick-17B-128E"
llama4_maverick_17b_128e_instruct = "Llama-4-Maverick-17B-128E-Instruct"
# Safety models
llama_guard_3_8b = "Llama-Guard-3-8B"
llama_guard_2_8b = "Llama-Guard-2-8B"
llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision"
llama_guard_3_1b = "Llama-Guard-3-1B"
def is_multimodal(model_id) -> bool:
if model_id in [
CoreModelId.llama3_2_11b_vision,
CoreModelId.llama3_2_90b_vision,
CoreModelId.llama3_2_11b_vision_instruct,
CoreModelId.llama3_2_90b_vision_instruct,
]:
return True
else:
return False
def model_family(model_id) -> ModelFamily:
if model_id in [
CoreModelId.llama2_7b,
CoreModelId.llama2_13b,
CoreModelId.llama2_70b,
CoreModelId.llama2_7b_chat,
CoreModelId.llama2_13b_chat,
CoreModelId.llama2_70b_chat,
]:
return ModelFamily.llama2
elif model_id in [
CoreModelId.llama3_8b,
CoreModelId.llama3_70b,
CoreModelId.llama3_8b_instruct,
CoreModelId.llama3_70b_instruct,
]:
return ModelFamily.llama3
elif model_id in [
CoreModelId.llama3_1_8b,
CoreModelId.llama3_1_70b,
CoreModelId.llama3_1_405b,
CoreModelId.llama3_1_8b_instruct,
CoreModelId.llama3_1_70b_instruct,
CoreModelId.llama3_1_405b_instruct,
]:
return ModelFamily.llama3_1
elif model_id in [
CoreModelId.llama3_2_1b,
CoreModelId.llama3_2_3b,
CoreModelId.llama3_2_1b_instruct,
CoreModelId.llama3_2_3b_instruct,
CoreModelId.llama3_2_11b_vision,
CoreModelId.llama3_2_90b_vision,
CoreModelId.llama3_2_11b_vision_instruct,
CoreModelId.llama3_2_90b_vision_instruct,
]:
return ModelFamily.llama3_2
elif model_id in [
CoreModelId.llama3_3_70b_instruct,
]:
return ModelFamily.llama3_3
elif model_id in [
CoreModelId.llama4_scout_17b_16e,
CoreModelId.llama4_scout_17b_16e_instruct,
CoreModelId.llama4_maverick_17b_128e,
CoreModelId.llama4_maverick_17b_128e_instruct,
]:
return ModelFamily.llama4
elif model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_2_8b,
CoreModelId.llama_guard_3_11b_vision,
CoreModelId.llama_guard_3_1b,
]:
return ModelFamily.safety
else:
raise ValueError(f"Unknown model family for {model_id}")
class Model(BaseModel):
core_model_id: CoreModelId
description: str
huggingface_repo: Optional[str] = None
recommended_sampling_params: Optional[SamplingParams] = None
arch_args: Dict[str, Any]
variant: str = ""
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
pth_file_count: int
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
# silence pydantic until we remove the `model_` fields
model_config = ConfigDict(protected_namespaces=())
@property
def model_family(self) -> ModelFamily:
return model_family(self.core_model_id)
# The SKU is uniquely identified by (model_id, variant) combo
def descriptor(self, shorten_default_variant: bool = True) -> str:
if not self.variant:
return self.core_model_id.value
return f"{self.core_model_id.value}:{self.variant}"
@property
def is_instruct_model(self) -> bool:
return "instruct" in self.id.name
# Featured models are shown in the non-exhaustive model list
@property
def is_featured(self) -> bool:
return self.model_family in [
ModelFamily.llama3_1,
ModelFamily.llama3_2,
ModelFamily.llama3_3,
ModelFamily.llama4,
ModelFamily.safety,
]
@property
def max_seq_length(self) -> int:
if self.model_family == ModelFamily.llama2:
return 4096
elif self.core_model_id == CoreModelId.llama_guard_2_8b:
return 4096
elif self.model_family == ModelFamily.llama3:
return 8192
elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]:
return 131072
elif self.model_family == ModelFamily.llama3_2:
if self.quantization_format == CheckpointQuantizationFormat.int4:
return 8192
return 131072
elif self.model_family == ModelFamily.llama4:
if self.core_model_id in {
CoreModelId.llama4_scout_17b_16e,
CoreModelId.llama4_maverick_17b_128e,
}:
return 262144
if self.core_model_id == CoreModelId.llama4_scout_17b_16e_instruct:
return 10485760
if self.core_model_id == CoreModelId.llama4_maverick_17b_128e_instruct:
return 1048576
elif self.core_model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_3_11b_vision,
CoreModelId.llama_guard_3_1b,
]:
return 131072
else:
raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")
class QuantizationMode(str, Enum):
none = "none"
fp8_mixed = "fp8_mixed"
int4_mixed = "int4_mixed"

View file

@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from dataclasses import dataclass
from enum import Enum
from typing import Optional

View file

@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import io
import json
import uuid

View file

@ -4,59 +4,37 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import json
import os
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Generator, List, Optional
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from termcolor import cprint
from ..datatypes import RawContent, RawMessage, StopReason, ToolPromptFormat
from ..checkpoint import maybe_reshard_state_dict
from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat
from .args import ModelArgs
from .chat_format import ChatFormat, LLMInput
from .model import Transformer
from .multimodal.model import CrossAttentionTransformer
from .tokenizer import Tokenizer
@dataclass
class CompletionPrediction:
generation: str
decoded_tokens: Optional[List[str]] = None
logprobs: Optional[List[List[float]]] = None
@dataclass
class ChatPrediction:
generation: RawMessage
decoded_tokens: Optional[List[str]] = None
logprobs: Optional[List[List[float]]] = None
@dataclass
class TokenResult:
token: int
text: str
logprobs: Optional[List[float]] = None
# TODO: make this completely parallel to the llama4 generation.py file and share common code
# from llama-models also
class Llama3:
@staticmethod
def build(
@ -64,7 +42,7 @@ class Llama3:
max_seq_len: int,
max_batch_size: int,
world_size: Optional[int] = None,
tokenizer_path: Optional[str] = None,
quantization_mode: Optional[QuantizationMode] = None,
seed: int = 1,
device: str = "cuda",
):
@ -101,13 +79,9 @@ class Llama3:
start_time = time.time()
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert world_size == len(checkpoints), (
f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
)
ckpt_path = checkpoints[get_model_parallel_rank()]
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
@ -116,40 +90,58 @@ class Llama3:
max_batch_size=max_batch_size,
**params,
)
if tokenizer_path:
tokenizer = Tokenizer(model_path=tokenizer_path)
else:
tokenizer = Tokenizer.get_instance()
tokenizer = Tokenizer.get_instance()
state_dict = maybe_reshard_state_dict(
ckpt_paths,
n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
)
assert model_args.vocab_size == tokenizer.n_words
torch.set_default_device(device)
if device.type == "cuda":
if torch.cuda.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.half)
elif device.type == "xpu":
if torch.xpu.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.half)
else:
torch.set_default_dtype(torch.half)
if model_args.vision_chunk_size > 0:
from .multimodal.model import CrossAttentionTransformer
def build_model():
if model_args.vision_chunk_size > 0:
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, device=device, dtype=torch.get_default_dtype())
else:
model = Transformer(model_args)
return model
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, torch.get_default_dtype())
if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
from .quantization.loader import convert_to_quantized_model
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = build_model()
print("Loading state dict...")
model.load_state_dict(state_dict, strict=False)
print("Done...")
model = convert_to_quantized_model(model, ckpt_dir, quantization_mode, device=device)
torch.set_default_device(device)
else:
model = Transformer(model_args)
model.load_state_dict(checkpoint, strict=True)
model.to(device)
print(f"Setting default device to {device}")
torch.set_default_device(device)
if device.type == "cuda":
if torch.cuda.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.half)
elif device.type == "xpu":
if torch.xpu.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.half)
model = build_model()
print("Loading state dict...")
model.load_state_dict(state_dict, strict=True)
model.to(device)
print("Done...")
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama(model, tokenizer, model_args)
return Llama3(model, tokenizer, model_args)
def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
def __init__(self, model: Transformer | CrossAttentionTransformer, tokenizer: Tokenizer, args: ModelArgs):
self.args = args
self.model = model
self.tokenizer = tokenizer
@ -158,26 +150,30 @@ class Llama3:
@torch.inference_mode()
def generate(
self,
model_input: LLMInput,
max_gen_len: int,
model_inputs: List[LLMInput],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
print_model_input: bool = False,
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> Generator:
) -> Generator[List[GenerationResult], None, None]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
params = self.model.params
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
if print_model_input:
tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in model_input.tokens]
cprint(
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
"red",
)
prompt_tokens = [model_input.tokens]
for inp in model_inputs:
tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
cprint(
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
"red",
)
prompt_tokens = [inp.tokens for inp in model_inputs]
bsz = 1
bsz = len(model_inputs)
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)
@ -189,18 +185,6 @@ class Llama3:
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
is_vision = not isinstance(self.model, Transformer)
if is_vision:
images = model_input.vision.images if model_input.vision is not None else []
mask = model_input.vision.mask if model_input.vision is not None else []
# the method works for bsz > 1 so add a batch dimension
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
batch_images=[images],
batch_masks=[mask],
total_len=total_len,
)
pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
for k, t in enumerate(prompt_tokens):
@ -208,23 +192,45 @@ class Llama3:
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
prev_pos = 0
is_vision = not isinstance(self.model, Transformer)
if is_vision:
images = [inp.vision.images if inp.vision is not None else [] for inp in model_inputs]
mask = [inp.vision.mask if inp.vision is not None else [] for inp in model_inputs]
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
batch_images=images,
batch_masks=mask,
total_len=total_len,
device=tokens.device,
)
eos_reached = torch.tensor([False] * bsz)
input_text_mask = tokens != pad_id
if echo:
for i, t in enumerate(model_input.tokens):
yield TokenResult(
token=t,
text=self.tokenizer.decode([t]),
logprobs=(token_logprobs[0, i : i + 1].tolist() if logprobs else None),
)
for i in range(max_prompt_len):
results = []
for j, t in enumerate(tokens[:, i]):
results.append(
GenerationResult(
token=t.item(),
text=self.tokenizer.decode([t.item()]),
source="input",
logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
batch_idx=j,
finished=False,
ignore_token=t.item() == pad_id,
)
)
yield results
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
prev_pos = 0
for cur_pos in range(min_prompt_len, total_len):
if is_vision:
position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
text_only_inference = model_input.vision is None
text_only_inference = all(inp.vision is None for inp in model_inputs)
logits = self.model.forward(
position_ids,
tokens,
@ -271,155 +277,69 @@ class Llama3:
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
yield TokenResult(
token=next_token[0].item(),
text=self.tokenizer.decode(next_token.tolist()),
logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None),
)
results = []
for idx, t in enumerate(next_token):
results.append(
GenerationResult(
token=t.item(),
text=self.tokenizer.decode([t.item()]),
source="output",
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
batch_idx=idx,
finished=eos_reached[idx],
ignore_token=cur_pos < len(prompt_tokens[idx]),
)
)
yield results
prev_pos = cur_pos
if all(eos_reached):
break
def text_completion(
def completion(
self,
content: RawContent,
contents: List[RawContent],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
) -> CompletionPrediction:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
max_gen_len = self.model.params.max_seq_len - 1
model_input = self.formatter.encode_content(content)
tokens = []
token_logprobs = []
decoded_tokens = []
) -> Generator[List[GenerationResult], None, None]:
model_inputs = [self.formatter.encode_content(c) for c in contents]
for result in self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
model_inputs=model_inputs,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
tokens.append(result.token)
if logprobs:
decoded_tokens.append(result.text)
token_logprobs.append(result.logprobs)
generation = self.tokenizer.decode(tokens)
if logprobs:
return CompletionPrediction(
generation=generation,
logprobs=token_logprobs,
decoded_tokens=decoded_tokens,
)
return CompletionPrediction(generation=generation)
yield result
if all(r.finished for r in result):
break
def chat_completion(
self,
messages: List[RawMessage],
messages_batch: List[List[RawMessage]],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
echo: bool = False,
) -> ChatPrediction:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
max_gen_len = self.model.params.max_seq_len - 1
tokens = []
token_logprobs = []
decoded_tokens = []
stop_reason = None
) -> Generator[List[GenerationResult], None, None]:
model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
for result in self.generate(
model_input=self.formatter.encode_dialog_prompt(messages, tool_prompt_format),
max_gen_len=max_gen_len,
model_inputs=model_inputs,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
tokens.append(result.token)
if result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
elif result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
if logprobs:
decoded_tokens.append(result.text)
token_logprobs.append(result.logprobs)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
message = self.formatter.decode_assistant_message(tokens, stop_reason)
if logprobs:
return ChatPrediction(
generation=message,
logprobs=token_logprobs,
decoded_tokens=decoded_tokens,
)
return ChatPrediction(generation=message)
def chat_completion_raw(
self,
messages: List[RawMessage],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
) -> List[int]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
max_gen_len = self.model.params.max_seq_len - 1
output_tokens = []
model_input = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
input_tokens = model_input.tokens
for result in self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=False,
):
output_tokens.append(result.token)
return input_tokens, output_tokens
def text_completion_raw(
self,
content: RawContent,
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
):
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
max_gen_len = self.model.params.max_seq_len - 1
model_input = self.formatter.encode_content(content)
input_tokens = model_input.tokens
output_tokens = []
for result in self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=False,
):
output_tokens.append(result.token)
return input_tokens, output_tokens
yield result
if all(r.finished for r in result):
break
def sample_top_p(probs, p):

View file

@ -4,16 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import math
from typing import Optional, Tuple
@ -29,6 +19,10 @@ from torch import nn
from .args import ModelArgs
# **NOTE**: This code is not runnable without installing `torch` and `fairscale`
# dependencies. These dependencies are not part of the default dependencies
# (requirements.txt) of the `llama-models` package.
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
@ -111,9 +105,9 @@ class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
model_parallel_size = fs_init.get_model_parallel_world_size()
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
world_size = fs_init.get_model_parallel_world_size()
self.n_local_heads = args.n_heads // world_size
self.n_local_kv_heads = self.n_kv_heads // world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads

View file

@ -4,16 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import logging
import math
from functools import partial
@ -180,14 +170,14 @@ class ImageAttention(nn.Module):
n_heads,
):
super().__init__()
model_parallel_size = fs_init.get_model_parallel_world_size()
world_size = fs_init.get_model_parallel_world_size()
qkvo_replication = 1
if model_parallel_size > 16:
qkvo_replication = model_parallel_size // 8
if world_size > 16:
qkvo_replication = world_size // 8
self.n_kv_heads = n_heads
self.n_local_heads = n_heads * qkvo_replication // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads * qkvo_replication // model_parallel_size
self.n_local_heads = n_heads * qkvo_replication // world_size
self.n_local_kv_heads = self.n_kv_heads * qkvo_replication // world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = dim // n_heads
@ -536,16 +526,16 @@ class Attention(nn.Module):
cache_v (torch.Tensor): Cached values for attention.
"""
super().__init__()
model_parallel_size = fs_init.get_model_parallel_world_size()
world_size = fs_init.get_model_parallel_world_size()
replication_factor = 1
if model_parallel_size > 8:
replication_factor = model_parallel_size // MP_SCALE
if world_size > 8:
replication_factor = world_size // MP_SCALE
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
self.n_kv_heads *= replication_factor
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
self.n_local_heads = args.n_heads // world_size
self.n_local_kv_heads = self.n_kv_heads // world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.max_seq_len = args.max_seq_len
@ -587,13 +577,11 @@ class Attention(nn.Module):
self.n_local_kv_heads,
self.head_dim,
)
device = next(self.parameters()).device
self.register_buffer(
"key_cache",
torch.zeros(
cache_shape,
dtype=dtype,
device=device,
),
persistent=False,
)
@ -602,7 +590,6 @@ class Attention(nn.Module):
torch.zeros(
cache_shape,
dtype=dtype,
device=device,
),
persistent=False,
)
@ -614,6 +601,9 @@ class Attention(nn.Module):
freqs_cis: torch.Tensor,
position_ids: torch.LongTensor,
):
self.key_cache = self.key_cache.to(x.device)
self.value_cache = self.value_cache.to(x.device)
xq, xk, xv = [F.linear(x, w) for w in [self.wq.weight, self.wk.weight, self.wv.weight]]
bs, slen, _ = xq.shape
@ -832,10 +822,10 @@ class CrossAttention(torch.nn.Module):
norm_eps: float,
):
super().__init__()
self.model_parallel_size = fs_init.get_model_parallel_world_size()
self.world_size = fs_init.get_model_parallel_world_size()
replication_factor = 1
if self.model_parallel_size > 8:
replication_factor = self.model_parallel_size // MP_SCALE
if self.world_size > 8:
replication_factor = self.world_size // MP_SCALE
n_kv_heads *= replication_factor
assert n_heads % n_kv_heads == 0
@ -889,10 +879,10 @@ class CrossAttention(torch.nn.Module):
# trunk LLM (i.e., group query attention) -- @dubeya
# local heads
assert self.n_heads % self.n_kv_heads == 0
assert self.n_heads % self.model_parallel_size == 0
assert self.n_kv_heads % self.model_parallel_size == 0
self.n_local_heads = self.n_heads // self.model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size
assert self.n_heads % self.world_size == 0
assert self.n_kv_heads % self.world_size == 0
self.n_local_heads = self.n_heads // self.world_size
self.n_local_kv_heads = self.n_kv_heads // self.world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
def _compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor:
@ -1041,7 +1031,7 @@ class CrossAttentionTransformerVision(torch.nn.Module):
self.image_res = args.vision_chunk_size
self.max_num_chunks = args.vision_max_num_chunks
if return_intermediate is not None:
return_intermediate = [int(level) for level in return_intermediate.split(",")]
return_intermediate = [int(layer) for layer in return_intermediate.split(",")]
self.vision_input_dim = (len(return_intermediate) + 1) * self.vision_input_dim
self.patch_size = 14
self.vision_encoder = VisionEncoder(
@ -1076,15 +1066,15 @@ class CrossAttentionTransformerText(torch.nn.Module):
def __init__(self, args: ModelArgs) -> None:
super().__init__()
self.model_parallel_size = fs_init.get_model_parallel_world_size()
self.world_size = fs_init.get_model_parallel_world_size()
assert args.vocab_size > 0
self.vocab_size = args.vocab_size
self.n_layers = args.n_layers
self.dim = args.dim
self.head_dim = args.dim // args.n_heads
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size
assert self.vocab_size % self.model_parallel_size == 0
self.n_local_kv_heads = self.n_kv_heads // self.world_size
assert self.vocab_size % self.world_size == 0
self.tok_embeddings = VocabParallelEmbedding(args.vocab_size, args.dim, init_method=lambda x: x)
self.pos_embeddings = None
# final norm layer (not necessary for post-norm)
@ -1184,6 +1174,8 @@ class CrossAttentionTransformerText(torch.nn.Module):
text_only_inference: bool = False,
):
assert self.cache_is_setup, "Please set up cache before calling forward"
self.mask_cache = self.mask_cache.to(h.device)
self.freqs_cis = self.freqs_cis.to(h.device)
mask = self.mask_cache.index_select(2, position_ids)
freqs_cis = self.freqs_cis.index_select(0, position_ids)
@ -1212,9 +1204,8 @@ class CrossAttentionTransformerText(torch.nn.Module):
output = gather_from_tensor_model_parallel_region(output)
return output.float()
def setup_cache(self, max_batch_size: int, dtype=torch.bfloat16):
def setup_cache(self, max_batch_size: int, device: torch.device, dtype=torch.bfloat16):
# Set up the text kv caches
device = next(self.parameters()).device
ones = torch.ones(
(self.max_seq_len, self.max_seq_len),
dtype=torch.bool,
@ -1265,7 +1256,7 @@ class CrossAttentionTransformerText(torch.nn.Module):
return (
cross_attention_masks.to(device=text_device, dtype=text_dtype),
full_text_row_masked_out_mask,
full_text_row_masked_out_mask.to(device=text_device),
)
@ -1284,14 +1275,15 @@ class CrossAttentionTransformer(torch.nn.Module):
max_num_chunks=args.vision_max_num_chunks,
)
def setup_cache(self, max_batch_size: int, dtype: torch.dtype):
self.text_model.setup_cache(max_batch_size, dtype)
def setup_cache(self, max_batch_size: int, device: torch.device, dtype: torch.dtype):
self.text_model.setup_cache(max_batch_size, device, dtype)
def compute_vision_tokens_masks(
self,
batch_images: List[List[PIL_Image.Image]],
batch_masks: List[List[List[int]]],
total_len: int,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
skip_vision_encoder = False
@ -1318,6 +1310,7 @@ class CrossAttentionTransformer(torch.nn.Module):
image_res=self.params.vision_chunk_size,
max_num_images=max_num_images,
)
stacked_images = stacked_images.to(device=device)
if skip_vision_encoder:
vision_tokens = torch.zeros(
@ -1330,7 +1323,7 @@ class CrossAttentionTransformer(torch.nn.Module):
),
)
else:
vision_tokens = self.vision_model(stacked_images, aspect_ratios)
vision_tokens = self.vision_model(stacked_images, aspect_ratios).to(device=device)
bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape)
xattn_caches = torch.stack(

View file

@ -15,7 +15,7 @@ import textwrap
from datetime import datetime
from typing import Any, List, Optional
from llama_stack.models.llama.datatypes import (
from llama_stack.apis.inference import (
BuiltinTool,
ToolDefinition,
ToolParamDefinition,

View file

@ -4,9 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
# type: ignore
import os
from typing import Any, Dict, List, Optional, cast
@ -18,22 +15,15 @@ from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_regi
from torch import Tensor, nn
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
from llama_stack.apis.inference import QuantizationType
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import resolve_model
from ...config import MetaReferenceQuantizedInferenceConfig
from ...datatypes import CheckpointQuantizationFormat
from ...datatypes import QuantizationMode
from ...quantize_impls import (
Fp8ScaledWeights,
ffn_swiglu,
load_fp8,
quantize_fp8,
)
from ..args import ModelArgs
from ..model import Transformer, TransformerBlock
log = get_logger(__name__, category="quantization")
from ..multimodal.model import CrossAttentionTransformer
def swiglu_wrapper(
@ -44,30 +34,34 @@ def swiglu_wrapper(
return reduce_from_model_parallel_region(out)
def convert_to_quantized_model(
model: Transformer | CrossAttentionTransformer,
checkpoint_dir: str,
quantization_mode: Optional[str] = None,
fp8_activation_scale_ub: Optional[float] = 1200.0,
device: Optional[torch.device] = None,
) -> Transformer | CrossAttentionTransformer:
if quantization_mode == QuantizationMode.fp8_mixed:
return convert_to_fp8_quantized_model(model, checkpoint_dir, fp8_activation_scale_ub, device)
elif quantization_mode == QuantizationMode.int4_mixed:
return convert_to_int4_quantized_model(model, checkpoint_dir, device)
else:
raise ValueError(f"Unsupported quantization mode: {quantization_mode}")
def convert_to_fp8_quantized_model(
model: Transformer,
config: MetaReferenceQuantizedInferenceConfig,
checkpoint_dir: str,
fp8_activation_scale_ub: Optional[float] = 1200.0,
device: Optional[torch.device] = None,
) -> Transformer:
if config.quantization.type == QuantizationType.bf16.value:
return model
elif config.quantization.type != QuantizationType.fp8.value:
raise ValueError("Only FP8 quantization is supported")
assert config.model is not None, "Model must be specified for quantized inference"
llama_model = resolve_model(config.model)
assert llama_model is not None, f"Model {config.model} not found"
# Move weights to GPU with quantization
if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
log.info("Loading fp8 scales...")
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
assert os.path.isfile(fp8_scales_path), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
if os.path.isfile(fp8_scales_path):
print("Loading fp8 scales...")
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
for block in model.layers:
for _, block in model.named_modules():
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
@ -81,8 +75,8 @@ def convert_to_fp8_quantized_model(
fp8_activation_scale_ub,
)
else:
log.info("Quantizing fp8 weights from bf16...")
for block in model.layers:
print("Quantizing fp8 weights from bf16...")
for _, block in model.named_modules():
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
@ -92,12 +86,12 @@ def convert_to_fp8_quantized_model(
param.weight = quantize_fp8(
param.weight,
fp8_activation_scale_ub,
output_device=torch.device("cuda"),
output_device=device,
)
for _, parameter in model.named_parameters():
if not isinstance(parameter, Fp8ScaledWeights):
parameter.data = parameter.to(device="cuda")
parameter.data = parameter.to(device=device)
return model
@ -290,11 +284,12 @@ def _prepare_model_int4_weight_int8_dynamic_activation(
def convert_to_int4_quantized_model(
model: Transformer,
model_args: ModelArgs,
) -> Transformer:
model: Transformer | CrossAttentionTransformer,
checkpoint_dir: str,
device: Optional[torch.device] = None,
) -> Transformer | CrossAttentionTransformer:
"""Convert the model to int4 quantized model."""
model_args = model.params
assert model_args.quantization_args is not None, "Quantization args must be specified."
quantization_args = model_args.quantization_args
if quantization_args.scheme is None:
@ -318,5 +313,4 @@ def convert_to_int4_quantized_model(
lora_scale = model_args.lora_args.scale
_prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
return cast(Transformer, model.to(device))
return cast(Transformer | CrossAttentionTransformer, model.to(device=device))

View file

@ -4,16 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import os
from logging import getLogger
from pathlib import Path

View file

@ -3,10 +3,3 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

View file

@ -4,12 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import json
import textwrap

View file

@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import textwrap
from pathlib import Path

View file

@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from enum import Enum
from typing import Optional

View file

@ -12,6 +12,7 @@ from typing import Dict, List, Optional, Tuple
import torch
from PIL import Image as PIL_Image
# TODO: either fork these or move them to the common package
from ..datatypes import (
BuiltinTool,
RawContent,
@ -26,10 +27,7 @@ from ..datatypes import (
from ..llama3.tool_utils import ToolUtils
from .args import VisionArgs
from .datatypes import LLMInput
from .preprocess import (
ResizeNormalizeImageTransform,
VariableSizeImageTransform,
)
from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform
from .tokenizer import Tokenizer
@ -50,7 +48,7 @@ class TransformedImage:
aspect_ratio: Tuple[int, int]
def convert_rgba_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
if image.mode == "RGBA":
image.load() # for png.split()
new_img = PIL_Image.new("RGB", image.size, bg)
@ -167,7 +165,7 @@ class ChatFormat:
bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
image = PIL_Image.open(bytes_io)
image = convert_rgba_to_rgb(image)
image = convert_image_to_rgb(image)
image_tiles, ar = self.dynamic_image_transform(image, max_num_chunks=self.max_num_chunks)
if image_tiles.shape[0] > 1:
@ -212,12 +210,9 @@ class ChatFormat:
content = ToolUtils.encode_tool_call(t, tool_prompt_format)
_process_content(content)
# Tool calls and Tool Response messages should be eom
eom = False
if message.role == "assistant":
eom = message.stop_reason == StopReason.end_of_message or message.tool_calls
elif message.role == "tool":
eom = True
eom = message.stop_reason == StopReason.end_of_message
tokens.append(self.tokenizer.special_tokens["<|eom|>" if eom else "<|eot|>"])
return tokens, images
@ -252,11 +247,6 @@ class ChatFormat:
if content.startswith(header_str):
content = content[len(header_str) :]
ipython = content.startswith("<|python_start|>")
if ipython:
content = content[len("<|python_start|>") :]
content = content.replace("<|python_end|>", "")
if content.endswith("<|eot|>"):
content = content[: -len("<|eot|>")]
stop_reason = StopReason.end_of_turn
@ -287,11 +277,6 @@ class ChatFormat:
}
if tool_name in BuiltinTool.__members__:
tool_name = BuiltinTool[tool_name]
elif ipython:
tool_name = BuiltinTool.code_interpreter
tool_arguments = {
"code": content,
}
tool_calls = []
if tool_name is not None and tool_arguments is not None:

View file

@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from dataclasses import dataclass
from typing import List, Optional, Union

View file

@ -4,32 +4,43 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import codecs
import io
import json
import os
import sys
import time
from enum import Enum
from pathlib import Path
from typing import Callable, Generator, List, Optional
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from termcolor import cprint
from ..common import TokenResult
from ..checkpoint import maybe_reshard_state_dict
from ..datatypes import GenerationResult, QuantizationMode
from .args import ModelArgs
from .chat_format import (
ChatFormat,
RawContent,
RawMessage,
)
from .chat_format import ChatFormat, RawContent, RawMessage
from .datatypes import LLMInput, MaskedEmbedding, TransformerInput
from .model import Transformer
from .tokenizer import Tokenizer
@ -37,12 +48,6 @@ from .tokenizer import Tokenizer
torch.serialization.add_safe_globals([io.BytesIO, codecs.encode])
class QuantizationMode(str, Enum):
none = "none"
fp8_mixed = "fp8_mixed"
int4_mixed = "int4_mixed"
class Llama4:
@staticmethod
def build(
@ -50,7 +55,7 @@ class Llama4:
max_seq_len: int,
max_batch_size: int,
world_size: Optional[int] = None,
quantization_mode: Optional[str] = None,
quantization_mode: Optional[QuantizationMode] = None,
seed: int = 1,
):
if not torch.distributed.is_initialized():
@ -71,11 +76,9 @@ class Llama4:
start_time = time.time()
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert world_size == len(checkpoints), (
f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
)
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
@ -92,10 +95,11 @@ class Llama4:
assert model_args.vocab_size == tokenizer.n_words, f"{model_args.vocab_size=} vs. {tokenizer.n_words=} mismatch"
print("Model args:\n", model_args.model_dump_json(indent=2))
ckpt_path = checkpoints[get_model_parallel_rank()]
print(f"Loading checkpoint from {ckpt_dir}...")
with open(ckpt_path, "rb") as f:
checkpoint = torch.load(f, map_location="cpu", weights_only=True)
state_dict = maybe_reshard_state_dict(
ckpt_paths,
n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
moe_num_experts=model_args.moe_args.num_experts,
)
print("Loaded checkpoint")
if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
from .quantization.loader import convert_to_quantized_model
@ -103,9 +107,9 @@ class Llama4:
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = Transformer(model_args)
print("Loading state dict...")
model.load_state_dict(checkpoint, strict=False)
model.load_state_dict(state_dict, strict=False)
print("Done...")
model = convert_to_quantized_model(model, ckpt_dir)
model = convert_to_quantized_model(model, ckpt_dir, quantization_mode)
else:
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
@ -114,7 +118,7 @@ class Llama4:
model = Transformer(model_args)
print("Loading state dict...")
model.load_state_dict(checkpoint, strict=False)
model.load_state_dict(state_dict, strict=False)
print("Done...")
print(f"Loaded in {time.time() - start_time:.2f} seconds")
@ -129,7 +133,7 @@ class Llama4:
@torch.inference_mode()
def generate(
self,
llm_input: LLMInput,
llm_inputs: List[LLMInput],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
@ -137,22 +141,20 @@ class Llama4:
echo: bool = False,
print_model_input: bool = False,
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> Generator:
) -> Generator[List[GenerationResult], None, None]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.args.max_seq_len:
max_gen_len = self.model.args.max_seq_len - 1
params = self.model.args
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
if print_model_input and get_model_parallel_rank() == 0:
tokens_to_print = list(llm_input.tokens)
cprint(
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
"red",
)
prompt_tokens = [llm_input.tokens]
if print_model_input:
cprint("Input to model:\n", "yellow")
for inp in llm_inputs:
cprint(self.tokenizer.decode(inp.tokens.tolist()), "grey")
prompt_tokens = [inp.tokens for inp in llm_inputs]
bsz = 1
bsz = len(llm_inputs)
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)
@ -175,24 +177,33 @@ class Llama4:
input_text_mask = tokens != pad_id
if echo:
for i, t in enumerate(llm_input.tokens):
yield TokenResult(
token=t,
text=self.tokenizer.decode([t]),
logprobs=(token_logprobs[0, i : i + 1].tolist() if logprobs else None),
)
for i in range(max_prompt_len):
results = []
for j, t in enumerate(tokens[:, i]):
results.append(
GenerationResult(
token=t.item(),
text=self.tokenizer.decode([t.item()]),
source="input",
logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
batch_idx=j,
finished=False,
ignore_token=t.item() == pad_id,
)
)
yield results
stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")
prev_pos = 0
for cur_pos in range(min_prompt_len, total_len):
image_embedding = None
if prev_pos == 0 and llm_input.images is not None and len(llm_input.images) > 0:
if prev_pos == 0 and any(inp.images is not None and len(inp.images) > 0 for inp in llm_inputs):
image_mask = tokens[:, prev_pos:cur_pos] == self.tokenizer.special_tokens["<|patch|>"]
image_mask = image_mask.unsqueeze(-1)
h = self.model.tok_embeddings(tokens[:, prev_pos:cur_pos])
image_batch = [llm_input.images]
image_batch = [inp.images if inp.images is not None else [] for inp in llm_inputs]
image_embedding = MaskedEmbedding(
embedding=self.model.vision_embeddings(image_batch, image_mask, h),
mask=image_mask,
@ -228,11 +239,21 @@ class Llama4:
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
yield TokenResult(
token=next_token[0].item(),
text=self.tokenizer.decode(next_token.tolist()),
logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None),
)
results = []
for idx, t in enumerate(next_token):
results.append(
GenerationResult(
token=t.item(),
text=self.tokenizer.decode([t.item()]),
source="output",
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
batch_idx=idx,
finished=eos_reached[idx],
ignore_token=cur_pos < len(prompt_tokens[idx]),
)
)
yield results
prev_pos = cur_pos
if all(eos_reached):
@ -240,68 +261,47 @@ class Llama4:
def completion(
self,
content: RawContent,
contents: List[RawContent],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
) -> Generator:
llm_input = self.formatter.encode_content(content)
) -> Generator[List[GenerationResult], None, None]:
llm_inputs = [self.formatter.encode_content(c) for c in contents]
for result in self.generate(
llm_input=llm_input,
llm_inputs=llm_inputs,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
if result.token in self.tokenizer.stop_tokens:
break
yield result
if all(r.finished for r in result):
break
def chat_completion(
self,
messages: List[RawMessage],
messages_batch: List[List[RawMessage]],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
) -> Generator:
llm_input = self.formatter.encode_dialog_prompt(messages)
) -> Generator[List[GenerationResult], None, None]:
llm_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
for result in self.generate(
llm_input=llm_input,
llm_inputs=llm_inputs,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
if result.token in self.tokenizer.stop_tokens:
break
yield result
def chat_completion_raw(
self,
messages: List[RawMessage],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
):
llm_input = self.formatter.encode_dialog_prompt(messages)
output_tokens = []
for result in self.generate(
llm_input=llm_input,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
):
output_tokens.append(result.token)
return llm_input.tokens, output_tokens
if all(r.finished for r in result):
break
def sample_top_p(probs, p):

View file

@ -4,16 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import math
from typing import Any, Dict, List, Optional, Tuple
@ -184,7 +174,6 @@ class Attention(nn.Module):
self.head_dim,
)
).cuda()
self.qk_norm = None
if self.use_qk_norm:
self.qk_norm = L2Norm(args.norm_eps)

View file

@ -100,31 +100,21 @@ class Experts(nn.Module):
class MoE(torch.nn.Module):
"""
This EC implementation is modified from the original EC module.
We refactored the token permutation and unpermutation logic and added support to tp and dp2ep sharding.
This module supports 3 sharding methods of the experts:
- tp: each TP rank has n_experts experts. Experts are sharded following the conventional row/column-parallel TP sharding.
- tp2ep: each TP rank has n_experts/tp experts. Experts are not sharded.
- dp2ep: each EP rank has n_experts/ep experts. Experts are sharded following the row/column-parallel TP sharding.
Tensors used in this module are annotated with the suffixes that indicate the shape of the tensor.
Several commonly used annotations include:
- a: bsz*slen
- E: number of experts
- e: number of local experts per ep (n_experts/ep)
- et: number of local experts per tp (n_experts/tp)
- D: hidden dimension
- d: D/tp
- F: model dimension
- f: F/tp (used in column/row-parallel linear)
- G: number of tokens per expert (a * capacity_factor / E)
- g: number of tokens per expert per TP rank (i.e., G/TP)
- GG: G*EP (number of tokens per expert received via inter-EP a2a when ag_along_first_dim=False)
- gg: g*EP (number of tokens per expert received via inter-EP a2a when ag_along_first_dim=True)
Examples:
x_aD [a, D]
routed_in_etG_D [et*G, D]
x_eGGD: [e, GG, D]
x_eGD: [e, G, D]
"""
def __init__(
@ -207,13 +197,13 @@ class MoE(torch.nn.Module):
routed_in_EG_D = routed_in_EG_D * router_scores.reshape(-1, 1)
out_aD = self.shared_expert(x_aD)
routed_out_egg_D = self.experts(routed_in_EG_D.detach())
routed_out_eg_D = self.experts(routed_in_EG_D.detach())
router_indices_EG_D = router_indices.reshape(-1, 1).expand(-1, D)
out_aD.scatter_add_(
dim=0,
index=router_indices_EG_D,
src=routed_out_egg_D.view(-1, D),
src=routed_out_eg_D.view(-1, D),
)
out_aD = reduce_from_model_parallel_region(out_aD)
return out_aD.view(-1, slen, D)

View file

@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import textwrap
from io import BytesIO
from pathlib import Path

View file

@ -6,20 +6,29 @@
import logging
import os
from typing import Optional
from typing import Callable, Optional
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
from torch import Tensor
from torch import Tensor, nn
from torch.nn import functional as F
from ..generation import QuantizationMode
from ...datatypes import QuantizationMode
from ..model import Transformer, TransformerBlock
from ..moe import MoE
log = logging.getLogger(__name__)
def swiglu_wrapper_no_reduce(
self,
x: Tensor,
):
from ...quantize_impls import ffn_swiglu
return ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
def experts_batched_swiglu_wrapper(
self,
x: Tensor, # (e, g, D)
@ -51,24 +60,30 @@ def convert_to_quantized_model(
rank = get_model_parallel_rank()
def should_quantize_block(block: nn.Module) -> bool:
if not isinstance(block, TransformerBlock):
return False
is_moe = isinstance(block.feed_forward, MoE)
if quantization_mode == QuantizationMode.fp8_mixed:
# skip quantization on first and last layers
return is_moe and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1))
return is_moe
use_rich_progress = use_rich_progress and rank == 0
progress, log_status, update_status = logging_callbacks(use_rich_progress, rank, model)
progress, log_status, update_status = logging_callbacks(use_rich_progress, rank, model, should_quantize_block)
if quantization_mode == QuantizationMode.int4_mixed:
int4_scales_path = os.path.join(checkpoint_dir, f"int4_scales_{rank}.pt")
int4_zero_points_path = os.path.join(checkpoint_dir, f"int4_zero_points_{rank}.pt")
if os.path.isfile(int4_scales_path):
log_status(f"Rank {rank}: Loading int4 scales")
int4_scales = torch.load(int4_scales_path, weights_only=True)
int4_zero_points = torch.load(int4_zero_points_path, weights_only=True)
def apply_quantization(key, weight):
scale = int4_scales[key]
zero_point = int4_zero_points[key]
return load_int4(
weight,
scale,
zero_point,
fp8_activation_scale_ub,
output_device=torch.device("cuda"),
)
@ -76,7 +91,8 @@ def convert_to_quantized_model(
log_status(f"Rank {rank}: Quantizing int4 weights from bf16")
def apply_quantization(_, weight):
return quantize_int4(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))
return quantize_int4(weight, output_device=torch.device("cuda"))
else:
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
if os.path.isfile(fp8_scales_path):
@ -104,33 +120,38 @@ def convert_to_quantized_model(
progress.start()
for _, block in model.named_modules():
if isinstance(block, TransformerBlock):
# Skip quantization on first and last layers
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
if not should_quantize_block(block):
continue
# Skip quantization on dense layers
if not isinstance(block.feed_forward, MoE):
continue
update_status(f"Rank {rank} - Layer {block.layer_id}")
update_status(f"Rank {rank} - Layer {block.layer_id}")
# Quantize only routed experts, not shared
prefix = f"layers.{block.layer_id}.feed_forward"
moe = block.feed_forward
moe.experts.batched_swiglu = experts_batched_swiglu_wrapper.__get__(moe.experts)
# Quantize only routed experts, not shared
prefix = f"layers.{block.layer_id}.feed_forward"
moe = block.feed_forward
moe.experts.batched_swiglu = experts_batched_swiglu_wrapper.__get__(moe.experts)
for key in ("w1", "w3", "w2"):
param = getattr(moe.experts, key)
update_status(f"Rank {rank} - Layer {block.layer_id} - MoE {key}")
setattr(
moe.experts,
key,
apply_quantization(
f"{prefix}.experts.{key}",
param.transpose(1, 2).contiguous(),
),
)
if quantization_mode == QuantizationMode.int4_mixed:
# Quantize shared experts
moe.shared_expert.forward = swiglu_wrapper_no_reduce.__get__(moe.shared_expert)
for key in ("w1", "w3", "w2"):
param = getattr(moe.experts, key)
update_status(f"Rank {rank} - Layer {block.layer_id} - MoE {key}")
setattr(
moe.experts,
key,
apply_quantization(f"{prefix}.experts.{key}", param.transpose(1, 2).contiguous()),
)
param = getattr(moe.shared_expert, key)
update_status(f"Rank {rank} - Layer {block.layer_id} - MoE shared expert {key}")
param.weight = apply_quantization(f"{prefix}.shared_expert.{key}", param.weight)
processed_blocks += 1
update_status(message=None, completed=processed_blocks)
processed_blocks += 1
update_status(message=None, completed=processed_blocks)
update_status(f"Rank {rank} - Moving parameters to CUDA")
@ -149,7 +170,12 @@ def convert_to_quantized_model(
# fp8/int4 loading can be very slow so we add progress bars to make life slightly better
def logging_callbacks(use_rich_progress: bool, rank: int, model: Transformer):
def logging_callbacks(
use_rich_progress: bool,
rank: int,
model: Transformer,
should_quantize_block: Callable[[nn.Module], bool],
):
console = None
if use_rich_progress:
from rich.console import Console
@ -162,15 +188,7 @@ def logging_callbacks(use_rich_progress: bool, rank: int, model: Transformer):
elif rank == 0: # Only log from rank 0 for non-rich logging
log.info(message)
total_blocks = sum(
1
for _, block in model.named_modules()
if (
isinstance(block, TransformerBlock)
and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1))
and isinstance(block.feed_forward, MoE)
)
)
total_blocks = sum(1 for _, block in model.named_modules() if should_quantize_block(block))
progress = None
if use_rich_progress:
from rich.progress import (

View file

@ -4,6 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
@ -59,8 +66,6 @@ LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS = [
"<|text_post_train_reserved_special_token_3|>",
"<|text_post_train_reserved_special_token_4|>",
"<|text_post_train_reserved_special_token_5|>",
"<|python_start|>",
"<|python_end|>",
"<|finetune_right_pad|>",
] + get_reserved_special_tokens(
"text_post_train", 61, 6
@ -85,8 +90,23 @@ LLAMA4_VISION_SPECIAL_TOKENS = [
"vision", 1041, 7
) # <|vision_reserved_special_token_7|>, ..., <|vision_reserved_special_token_1047|>
# 201134, ..., 201143
LLAMA4_REASONING_SPECIAL_TOKENS = [
"<|reasoning_reserved_special_token_0|>",
"<|reasoning_reserved_special_token_1|>",
"<|reasoning_reserved_special_token_2|>",
"<|reasoning_reserved_special_token_3|>",
"<|reasoning_reserved_special_token_4|>",
"<|reasoning_reserved_special_token_5|>",
"<|reasoning_reserved_special_token_6|>",
"<|reasoning_reserved_special_token_7|>",
"<|reasoning_thinking_start|>",
"<|reasoning_thinking_end|>",
]
LLAMA4_SPECIAL_TOKENS = LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS + LLAMA4_VISION_SPECIAL_TOKENS
LLAMA4_SPECIAL_TOKENS = (
LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS + LLAMA4_VISION_SPECIAL_TOKENS + LLAMA4_REASONING_SPECIAL_TOKENS
)
BASIC_SPECIAL_TOKENS = [
"<|begin_of_text|>",
@ -155,6 +175,9 @@ class Tokenizer:
self.eot_id: int = self.special_tokens["<|eot|>"]
self.eom_id: int = self.special_tokens["<|eom|>"]
self.thinking_start_id: int = self.special_tokens["<|reasoning_thinking_start|>"]
self.thinking_end_id: int = self.special_tokens["<|reasoning_thinking_end|>"]
self.stop_tokens = [
self.eos_id,
self.special_tokens["<|eom|>"],

View file

@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import math
from typing import Any, Callable, Dict, List

View file

@ -28,9 +28,6 @@ from llama_stack.models.llama.datatypes import (
ToolPromptFormat,
)
from llama_stack.models.llama.llama4.tokenizer import Tokenizer
from llama_stack.providers.inline.inference.meta_reference.llama4.datatypes import (
LLMInput,
)
from .llama3.interface import LLama31Interface
from .llama3.template_data import (
@ -38,6 +35,7 @@ from .llama3.template_data import (
system_message_builtin_tools_only,
system_message_custom_tools_only,
)
from .llama4.datatypes import LLMInput
class TextCompletionContent(BaseModel):

View file

@ -4,24 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional
from .datatypes import (
from .sku_types import (
CheckpointQuantizationFormat,
CoreModelId,
Model,
ModelFamily,
SamplingParams,
TopPSamplingStrategy,
)
LLAMA2_VOCAB_SIZE = 32000
@ -47,15 +38,6 @@ def all_registered_models() -> List[Model]:
)
def recommended_sampling_params() -> SamplingParams:
return SamplingParams(
strategy=TopPSamplingStrategy(
temperature=1.0,
top_p=0.9,
)
)
def llama2_family() -> List[Model]:
return [
*llama2_base_models(),
@ -150,7 +132,6 @@ def llama2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama2_7b,
description="Llama 2 7b model",
huggingface_repo="meta-llama/Llama-2-7b",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@ -169,7 +150,6 @@ def llama2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama2_13b,
description="Llama 2 13b model",
huggingface_repo="meta-llama/Llama-2-13b",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 5120,
"n_layers": 40,
@ -188,7 +168,6 @@ def llama2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama2_70b,
description="Llama 2 70b model",
huggingface_repo="meta-llama/Llama-2-70b",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@ -230,7 +209,6 @@ def llama3_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_70b,
description="Llama 3 70b model",
huggingface_repo="meta-llama/Llama-3-70B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@ -254,7 +232,6 @@ def llama3_1_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_1_8b,
description="Llama 3.1 8b model",
huggingface_repo="meta-llama/Llama-3.1-8B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@ -273,7 +250,6 @@ def llama3_1_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_1_70b,
description="Llama 3.1 70b model",
huggingface_repo="meta-llama/Llama-3.1-70B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@ -293,7 +269,6 @@ def llama3_1_base_models() -> List[Model]:
variant="bf16-mp8",
description="Llama 3.1 405b model (BF16 weights)",
huggingface_repo="meta-llama/Llama-3.1-405B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@ -313,7 +288,6 @@ def llama3_1_base_models() -> List[Model]:
description="Llama 3.1 405b model (FP8 quantized)",
huggingface_repo="meta-llama/Llama-3.1-405B-FP8",
quantization_format=CheckpointQuantizationFormat.fp8_mixed,
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@ -333,7 +307,6 @@ def llama3_1_base_models() -> List[Model]:
variant="bf16-mp16",
description="Llama 3.1 405b model (BF16 weights for mp16)",
huggingface_repo="meta-llama/Llama-3.1-405B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@ -357,7 +330,6 @@ def llama3_2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_1b,
description="Llama 3.2 1b model",
huggingface_repo="meta-llama/Llama-3.2-1B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 2048,
"n_layers": 16,
@ -376,7 +348,6 @@ def llama3_2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_3b,
description="Llama 3.2 3b model",
huggingface_repo="meta-llama/Llama-3.2-3B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 3072,
"n_layers": 28,
@ -395,7 +366,6 @@ def llama3_2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_11b_vision,
description="Llama 3.2 11b vision model",
huggingface_repo="meta-llama/Llama-3.2-11B-Vision",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@ -417,7 +387,6 @@ def llama3_2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_90b_vision,
description="Llama 3.2 90b vision model",
huggingface_repo="meta-llama/Llama-3.2-90B-Vision",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@ -444,7 +413,6 @@ def llama2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama2_7b_chat,
description="Llama 2 7b chat model",
huggingface_repo="meta-llama/Llama-2-7b-chat",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@ -463,7 +431,6 @@ def llama2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama2_13b_chat,
description="Llama 2 13b chat model",
huggingface_repo="meta-llama/Llama-2-13b-chat",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 5120,
"n_layers": 40,
@ -482,7 +449,6 @@ def llama2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama2_70b_chat,
description="Llama 2 70b chat model",
huggingface_repo="meta-llama/Llama-2-70b-chat",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@ -506,7 +472,6 @@ def llama3_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_8b_instruct,
description="Llama 3 8b instruct model",
huggingface_repo="meta-llama/Llama-3-8B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@ -525,7 +490,6 @@ def llama3_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_70b_instruct,
description="Llama 3 70b instruct model",
huggingface_repo="meta-llama/Llama-3-70B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@ -549,7 +513,6 @@ def llama3_1_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_1_8b_instruct,
description="Llama 3.1 8b instruct model",
huggingface_repo="meta-llama/Llama-3.1-8B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@ -568,7 +531,6 @@ def llama3_1_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_1_70b_instruct,
description="Llama 3.1 70b instruct model",
huggingface_repo="meta-llama/Llama-3.1-70B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@ -588,7 +550,6 @@ def llama3_1_instruct_models() -> List[Model]:
variant="bf16-mp8",
description="Llama 3.1 405b instruct model (BF16 weights)",
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@ -608,7 +569,6 @@ def llama3_1_instruct_models() -> List[Model]:
description="Llama 3.1 405b instruct model (FP8 quantized)",
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct-FP8",
quantization_format=CheckpointQuantizationFormat.fp8_mixed,
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@ -628,7 +588,6 @@ def llama3_1_instruct_models() -> List[Model]:
variant="bf16-mp16",
description="Llama 3.1 405b instruct model (BF16 weights for mp16)",
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@ -684,7 +643,6 @@ def llama3_2_quantized_models() -> List[Model]:
quantization_format=CheckpointQuantizationFormat.int4,
description="Llama 3.2 1b INT4 quantized LoRA",
huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
**arch_args_1b(),
"quantization_args": {
@ -703,7 +661,6 @@ def llama3_2_quantized_models() -> List[Model]:
quantization_format=CheckpointQuantizationFormat.int4,
description="Llama 3.2 1b INT4 quantized SpinQuant",
huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
**arch_args_1b(),
"quantization_args": {
@ -718,7 +675,6 @@ def llama3_2_quantized_models() -> List[Model]:
quantization_format=CheckpointQuantizationFormat.int4,
description="Llama 3.2 3b INT4 quantized LoRA",
huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-QLORA_INT4_EO8",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
**arch_args_3b(),
"quantization_args": {
@ -737,7 +693,6 @@ def llama3_2_quantized_models() -> List[Model]:
quantization_format=CheckpointQuantizationFormat.int4,
description="Llama 3.2 3b INT4 quantized SpinQuant",
huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-SpinQuant_INT4_EO8",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
**arch_args_3b(),
"quantization_args": {
@ -755,7 +710,6 @@ def llama3_2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_1b_instruct,
description="Llama 3.2 1b instruct model",
huggingface_repo="meta-llama/Llama-3.2-1B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args=arch_args_1b(),
pth_file_count=1,
),
@ -763,7 +717,6 @@ def llama3_2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_3b_instruct,
description="Llama 3.2 3b instruct model",
huggingface_repo="meta-llama/Llama-3.2-3B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args=arch_args_3b(),
pth_file_count=1,
),
@ -772,7 +725,6 @@ def llama3_2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_11b_vision_instruct,
description="Llama 3.2 11b vision instruct model",
huggingface_repo="meta-llama/Llama-3.2-11B-Vision-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@ -794,7 +746,6 @@ def llama3_2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_90b_vision_instruct,
description="Llama 3.2 90b vision instruct model",
huggingface_repo="meta-llama/Llama-3.2-90B-Vision-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@ -821,7 +772,6 @@ def llama3_3_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_3_70b_instruct,
description="Llama 3.3 70b instruct",
huggingface_repo="meta-llama/Llama-3.3-70B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@ -846,7 +796,6 @@ def safety_models() -> List[Model]:
core_model_id=CoreModelId.llama_guard_3_11b_vision,
description="Llama Guard v3 11b vision system safety model",
huggingface_repo="meta-llama/Llama-Guard-3-11B-Vision",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@ -870,7 +819,6 @@ def safety_models() -> List[Model]:
description="Llama Guard v3 1b 'int4' quantized system safety model",
huggingface_repo="meta-llama/Llama-Guard-3-1B-INT4",
quantization_format=CheckpointQuantizationFormat.int4,
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 2048,
"n_layers": 12,
@ -888,7 +836,6 @@ def safety_models() -> List[Model]:
core_model_id=CoreModelId.llama_guard_3_1b,
description="Llama Guard v3 1b system safety model",
huggingface_repo="meta-llama/Llama-Guard-3-1B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 2048,
"n_layers": 16,

View file

@ -0,0 +1,229 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, Optional
from pydantic import BaseModel, ConfigDict, Field
class CheckpointQuantizationFormat(Enum):
# default format
bf16 = "bf16"
# used for enabling fp8_rowwise inference, some weights are bf16
fp8_mixed = "fp8-mixed"
int8 = "int8"
int4 = "int4"
class ModelFamily(Enum):
llama2 = "llama2"
llama3 = "llama3"
llama3_1 = "llama3_1"
llama3_2 = "llama3_2"
llama3_3 = "llama3_3"
llama4 = "llama4"
safety = "safety"
class CoreModelId(Enum):
"""Each of these models is a unique "SKU". These root models can be served in various garbs (especially by quantizing them)"""
# Llama 2 family
llama2_7b = "Llama-2-7b"
llama2_13b = "Llama-2-13b"
llama2_70b = "Llama-2-70b"
llama2_7b_chat = "Llama-2-7b-chat"
llama2_13b_chat = "Llama-2-13b-chat"
llama2_70b_chat = "Llama-2-70b-chat"
# Llama 3 family
llama3_8b = "Llama-3-8B"
llama3_70b = "Llama-3-70B"
llama3_8b_instruct = "Llama-3-8B-Instruct"
llama3_70b_instruct = "Llama-3-70B-Instruct"
# Llama 3.1 family
llama3_1_8b = "Llama3.1-8B"
llama3_1_70b = "Llama3.1-70B"
llama3_1_405b = "Llama3.1-405B"
llama3_1_8b_instruct = "Llama3.1-8B-Instruct"
llama3_1_70b_instruct = "Llama3.1-70B-Instruct"
llama3_1_405b_instruct = "Llama3.1-405B-Instruct"
# Llama 3.2 family
llama3_2_1b = "Llama3.2-1B"
llama3_2_3b = "Llama3.2-3B"
llama3_2_1b_instruct = "Llama3.2-1B-Instruct"
llama3_2_3b_instruct = "Llama3.2-3B-Instruct"
llama3_2_11b_vision = "Llama3.2-11B-Vision"
llama3_2_90b_vision = "Llama3.2-90B-Vision"
llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct"
llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct"
# Llama 3.3 family
llama3_3_70b_instruct = "Llama3.3-70B-Instruct"
# Llama 4 family
llama4_scout_17b_16e = "Llama-4-Scout-17B-16E"
llama4_scout_17b_16e_instruct = "Llama-4-Scout-17B-16E-Instruct"
llama4_maverick_17b_128e = "Llama-4-Maverick-17B-128E"
llama4_maverick_17b_128e_instruct = "Llama-4-Maverick-17B-128E-Instruct"
# Safety models
llama_guard_3_8b = "Llama-Guard-3-8B"
llama_guard_2_8b = "Llama-Guard-2-8B"
llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision"
llama_guard_3_1b = "Llama-Guard-3-1B"
def is_multimodal(model_id) -> bool:
if model_id in [
CoreModelId.llama3_2_11b_vision,
CoreModelId.llama3_2_90b_vision,
CoreModelId.llama3_2_11b_vision_instruct,
CoreModelId.llama3_2_90b_vision_instruct,
]:
return True
else:
return False
def model_family(model_id) -> ModelFamily:
if model_id in [
CoreModelId.llama2_7b,
CoreModelId.llama2_13b,
CoreModelId.llama2_70b,
CoreModelId.llama2_7b_chat,
CoreModelId.llama2_13b_chat,
CoreModelId.llama2_70b_chat,
]:
return ModelFamily.llama2
elif model_id in [
CoreModelId.llama3_8b,
CoreModelId.llama3_70b,
CoreModelId.llama3_8b_instruct,
CoreModelId.llama3_70b_instruct,
]:
return ModelFamily.llama3
elif model_id in [
CoreModelId.llama3_1_8b,
CoreModelId.llama3_1_70b,
CoreModelId.llama3_1_405b,
CoreModelId.llama3_1_8b_instruct,
CoreModelId.llama3_1_70b_instruct,
CoreModelId.llama3_1_405b_instruct,
]:
return ModelFamily.llama3_1
elif model_id in [
CoreModelId.llama3_2_1b,
CoreModelId.llama3_2_3b,
CoreModelId.llama3_2_1b_instruct,
CoreModelId.llama3_2_3b_instruct,
CoreModelId.llama3_2_11b_vision,
CoreModelId.llama3_2_90b_vision,
CoreModelId.llama3_2_11b_vision_instruct,
CoreModelId.llama3_2_90b_vision_instruct,
]:
return ModelFamily.llama3_2
elif model_id in [
CoreModelId.llama3_3_70b_instruct,
]:
return ModelFamily.llama3_3
elif model_id in [
CoreModelId.llama4_scout_17b_16e,
CoreModelId.llama4_scout_17b_16e_instruct,
CoreModelId.llama4_maverick_17b_128e,
CoreModelId.llama4_maverick_17b_128e_instruct,
]:
return ModelFamily.llama4
elif model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_2_8b,
CoreModelId.llama_guard_3_11b_vision,
CoreModelId.llama_guard_3_1b,
]:
return ModelFamily.safety
else:
raise ValueError(f"Unknown model family for {model_id}")
class Model(BaseModel):
core_model_id: CoreModelId
description: str
huggingface_repo: Optional[str] = None
arch_args: Dict[str, Any]
variant: str = ""
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
pth_file_count: int
metadata: Dict[str, Any] = Field(default_factory=dict)
# silence pydantic until we remove the `model_` fields
model_config = ConfigDict(protected_namespaces=())
@property
def model_family(self) -> ModelFamily:
return model_family(self.core_model_id)
# The SKU is uniquely identified by (model_id, variant) combo
def descriptor(self, shorten_default_variant: bool = True) -> str:
if not self.variant:
return self.core_model_id.value
return f"{self.core_model_id.value}:{self.variant}"
@property
def is_instruct_model(self) -> bool:
return "instruct" in self.core_model_id.value
# Featured models are shown in the non-exhaustive model list
@property
def is_featured(self) -> bool:
return self.model_family in [
ModelFamily.llama3_1,
ModelFamily.llama3_2,
ModelFamily.llama3_3,
ModelFamily.llama4,
ModelFamily.safety,
]
@property
def max_seq_length(self) -> int:
if self.model_family == ModelFamily.llama2:
return 4096
elif self.core_model_id == CoreModelId.llama_guard_2_8b:
return 4096
elif self.model_family == ModelFamily.llama3:
return 8192
elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]:
return 131072
elif self.model_family == ModelFamily.llama3_2:
if self.quantization_format == CheckpointQuantizationFormat.int4:
return 8192
return 131072
elif self.model_family == ModelFamily.llama4:
if self.core_model_id in {
CoreModelId.llama4_scout_17b_16e,
CoreModelId.llama4_maverick_17b_128e,
}:
return 262144
if self.core_model_id == CoreModelId.llama4_scout_17b_16e_instruct:
return 10485760
if self.core_model_id == CoreModelId.llama4_maverick_17b_128e_instruct:
return 1048576
raise AssertionError(f"Unexpected core model id: {self.core_model_id}")
elif self.core_model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_3_11b_vision,
CoreModelId.llama_guard_3_1b,
]:
return 131072
else:
raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")

View file

@ -52,6 +52,7 @@ from llama_stack.apis.inference import (
StopReason,
SystemMessage,
ToolDefinition,
ToolParamDefinition,
ToolResponse,
ToolResponseMessage,
UserMessage,
@ -63,7 +64,6 @@ from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import (
BuiltinTool,
ToolCall,
ToolParamDefinition,
)
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.telemetry import tracing

View file

@ -12,20 +12,19 @@ from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerToken
from llama_stack.apis.inference import (
Fp8QuantizationConfig,
GreedySamplingStrategy,
Int4QuantizationConfig,
JsonSchemaResponseFormat,
ResponseFormat,
)
from llama_stack.models.llama.datatypes import (
GreedySamplingStrategy,
Model,
SamplingParams,
TopPSamplingStrategy,
)
from llama_stack.models.llama.datatypes import QuantizationMode
from llama_stack.models.llama.llama3.generation import Llama3
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.generation import Llama4
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_types import Model
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
@ -136,9 +135,9 @@ class Llama4Generator:
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
if isinstance(config.quantization, Fp8QuantizationConfig):
quantization_mode = "fp8_mixed"
quantization_mode = QuantizationMode.fp8_mixed
elif isinstance(config.quantization, Int4QuantizationConfig):
quantization_mode = "int4_mixed"
quantization_mode = QuantizationMode.int4_mixed
else:
raise ValueError(f"Unsupported quantization mode {config.quantization}")
else:
@ -225,9 +224,9 @@ class Llama3Generator:
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
if isinstance(config.quantization, Fp8QuantizationConfig):
quantization_mode = "fp8_mixed"
quantization_mode = QuantizationMode.fp8_mixed
elif isinstance(config.quantization, Int4QuantizationConfig):
quantization_mode = "int4_mixed"
quantization_mode = QuantizationMode.int4_mixed
else:
raise ValueError(f"Unsupported quantization mode {config.quantization}")
else:
@ -240,6 +239,9 @@ class Llama3Generator:
world_size=llama_model.pth_file_count,
quantization_mode=quantization_mode,
)
self.tokenizer = self.inner_generator.tokenizer
self.args = self.inner_generator.args
self.formatter = self.inner_generator.formatter
def completion(
self,

View file

@ -31,23 +31,21 @@ from llama_stack.apis.inference import (
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
StopReason,
TokenLogProbs,
ToolChoice,
ToolConfig,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import (
ModelFamily,
SamplingParams,
StopReason,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,

View file

@ -14,9 +14,10 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat,
Message,
ToolChoice,
ToolDefinition,
UserMessage,
)
from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition
from llama_stack.models.llama.datatypes import BuiltinTool
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,

View file

@ -46,6 +46,8 @@ from llama_stack.apis.inference import (
TokenLogProbs,
ToolChoice,
ToolConfig,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.apis.models import Model
from llama_stack.log import get_logger
@ -55,8 +57,6 @@ from llama_stack.models.llama.datatypes import (
ToolCall,
ToolDefinition,
ToolPromptFormat,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer

View file

@ -22,8 +22,8 @@ from torchtune.models.llama3_2 import lora_llama3_2_3b
from torchtune.modules.transforms import Transform
from llama_stack.apis.post_training import DatasetFormat
from llama_stack.models.llama.datatypes import Model
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import Model
BuildLoraModelCallable = Callable[..., torch.nn.Module]
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]

View file

@ -23,7 +23,8 @@ from llama_stack.apis.safety import (
)
from llama_stack.apis.shields import Shield
from llama_stack.distribution.datatypes import Api
from llama_stack.models.llama.datatypes import CoreModelId, Role
from llama_stack.models.llama.datatypes import Role
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)

View file

@ -28,8 +28,8 @@ from llama_stack.apis.inference import (
ToolConfig,
ToolDefinition,
ToolPromptFormat,
TopKSamplingStrategy,
)
from llama_stack.models.llama.datatypes import TopKSamplingStrategy
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)

View file

@ -28,7 +28,7 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from llama_stack.apis.models import ModelType
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,

View file

@ -29,15 +29,13 @@ from llama_stack.apis.inference import (
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
)
from llama_stack.models.llama.datatypes import (
SamplingParams,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.models.llama.datatypes import ToolPromptFormat
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)

View file

@ -19,11 +19,9 @@ from llama_stack.apis.inference import (
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
GreedySamplingStrategy,
JsonSchemaResponseFormat,
TokenLogProbs,
)
from llama_stack.models.llama.datatypes import (
GreedySamplingStrategy,
TopKSamplingStrategy,
TopPSamplingStrategy,
)

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)

View file

@ -21,6 +21,7 @@ from llama_stack.apis.inference import (
CompletionMessage,
EmbeddingsResponse,
EmbeddingTaskType,
GreedySamplingStrategy,
Inference,
LogProbConfig,
Message,
@ -35,12 +36,9 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
ToolResponseMessage,
UserMessage,
)
from llama_stack.models.llama.datatypes import (
GreedySamplingStrategy,
TopKSamplingStrategy,
TopPSamplingStrategy,
UserMessage,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from llama_stack.apis.models.models import ModelType
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,

View file

@ -96,7 +96,6 @@ def _convert_to_vllm_tool_calls_in_response(
call_id=call.id,
tool_name=call.function.name,
arguments=json.loads(call.function.arguments),
arguments_json=call.function.arguments,
)
for call in tool_calls
]
@ -176,7 +175,6 @@ async def _process_vllm_chat_completion_stream_response(
call_id=tool_call_buf.call_id,
tool_name=tool_call_buf.tool_name,
arguments=args,
arguments_json=args_str,
),
parse_status=ToolCallParseStatus.succeeded,
),

View file

@ -6,7 +6,7 @@
from typing import List
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,

View file

@ -12,8 +12,8 @@ import pytest
from pytest import ExitCode
from pytest_html.basereport import _process_outcome
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.models.llama.sku_types import CoreModelId
INFERENCE_APIS = ["chat_completion"]
FUNCTIONALITIES = ["streaming", "structured_output", "tool_calling"]

View file

@ -6,8 +6,8 @@
from typing import List
from llama_stack.models.llama.datatypes import * # noqa: F403
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.models.llama.sku_types import * # noqa: F403
def is_supported_safety_model(model: Model) -> bool:

View file

@ -73,21 +73,21 @@ from llama_stack.apis.inference import (
CompletionMessage,
CompletionResponse,
CompletionResponseStreamChunk,
GreedySamplingStrategy,
Message,
SamplingParams,
SystemMessage,
TokenLogProbs,
ToolResponseMessage,
TopKSamplingStrategy,
TopPSamplingStrategy,
UserMessage,
)
from llama_stack.models.llama.datatypes import (
BuiltinTool,
GreedySamplingStrategy,
SamplingParams,
StopReason,
ToolCall,
ToolDefinition,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_content_to_url,

View file

@ -34,7 +34,6 @@ from llama_stack.apis.inference import (
)
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import (
ModelFamily,
RawContent,
RawContentItem,
RawMediaItem,
@ -43,7 +42,6 @@ from llama_stack.models.llama.datatypes import (
Role,
StopReason,
ToolPromptFormat,
is_multimodal,
)
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.prompt_templates import (
@ -55,6 +53,7 @@ from llama_stack.models.llama.llama3.prompt_templates import (
)
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
from llama_stack.providers.utils.inference import supported_inference_models
log = get_logger(name=__name__, category="inference")

View file

@ -224,9 +224,9 @@ exclude = [
"^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
"^llama_stack/providers/inline/inference/meta_reference/config\\.py$",
"^llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
"^llama_stack/providers/inline/inference/meta_reference/llama3/generation\\.py$",
"^llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/model\\.py$",
"^llama_stack/providers/inline/inference/meta_reference/llama4/",
"^llama_stack/models/llama/llama3/generation\\.py$",
"^llama_stack/models/llama/llama3/multimodal/model\\.py$",
"^llama_stack/models/llama/llama4/",
"^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$",
"^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$",
"^llama_stack/providers/inline/inference/meta_reference/quantization/loader\\.py$",

View file

@ -11,7 +11,6 @@ import pytest
from pytest import CollectReport
from termcolor import cprint
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.models.llama.sku_list import (
all_registered_models,
llama3_1_instruct_models,
@ -20,6 +19,7 @@ from llama_stack.models.llama.sku_list import (
llama3_instruct_models,
safety_models,
)
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.datatypes import Api
from .metadata import API_MAPS