refactor: move all llama code to models/llama out of meta reference (#1887)

# What does this PR do?

Move code around. This makes the copies from llama-models _much_ easier
to maintain and ensures we don't entangle meta-reference-specific
tidbits into llama-models code even by accident.

Also, this kills the meta-reference-quantized-gpu distro and rolls the
quantization deps into meta-reference-gpu.

## Test Plan

```
LLAMA_MODELS_DEBUG=1 \
  with-proxy llama stack run meta-reference-gpu \
   --env INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct \
   --env INFERENCE_CHECKPOINT_DIR=<DIR> \
   --env MODEL_PARALLEL_SIZE=4 \
   --env QUANTIZATION_TYPE=fp8_mixed
```

Start a server with and without quantization. Point integration tests to
it using:

```
pytest -s -v  tests/integration/inference/test_text_inference.py \
   --stack-config http://localhost:8321 --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct
```
Ashwin Bharambe 2025-04-07 15:03:58 -07:00 committed by GitHub
parent c52ccc4bbd
commit 530d4bdfe1
85 changed files with 1267 additions and 1683 deletions


@ -25,15 +25,64 @@ from llama_stack.apis.models import Model
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
from llama_stack.models.llama.datatypes import (
BuiltinTool,
SamplingParams,
StopReason,
ToolCall,
ToolDefinition,
ToolParamDefinition,
ToolPromptFormat,
)
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
register_schema(ToolCall)
register_schema(ToolParamDefinition)
register_schema(ToolDefinition)
@json_schema_type
class GreedySamplingStrategy(BaseModel):
type: Literal["greedy"] = "greedy"
@json_schema_type
class TopPSamplingStrategy(BaseModel):
type: Literal["top_p"] = "top_p"
temperature: Optional[float] = Field(..., gt=0.0)
top_p: Optional[float] = 0.95
@json_schema_type
class TopKSamplingStrategy(BaseModel):
type: Literal["top_k"] = "top_k"
top_k: int = Field(..., ge=1)
SamplingStrategy = Annotated[
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
Field(discriminator="type"),
]
register_schema(SamplingStrategy, name="SamplingStrategy")
@json_schema_type
class SamplingParams(BaseModel):
"""Sampling parameters.
:param strategy: The sampling strategy.
:param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
your prompt plus max_tokens cannot exceed the model's context length.
:param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
:param stop: Up to 4 sequences where the API will stop generating further tokens.
The returned text will not contain the stop sequence.
"""
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
max_tokens: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0
stop: Optional[List[str]] = None
class LogProbConfig(BaseModel):
"""
@ -48,18 +97,18 @@ class QuantizationType(Enum):
"""Type of model quantization to run inference with. """Type of model quantization to run inference with.
:cvar bf16: BFloat16 typically this means _no_ quantization :cvar bf16: BFloat16 typically this means _no_ quantization
:cvar fp8: 8-bit floating point quantization :cvar fp8_mixed: 8-bit floating point quantization with mixed precision
:cvar int4: 4-bit integer quantization :cvar int4_mixed: 4-bit integer quantization with mixed precision
""" """
bf16 = "bf16" bf16 = "bf16"
fp8 = "fp8" fp8_mixed = "fp8_mixed"
int4 = "int4" int4_mixed = "int4_mixed"
@json_schema_type @json_schema_type
class Fp8QuantizationConfig(BaseModel): class Fp8QuantizationConfig(BaseModel):
type: Literal["fp8"] = "fp8" type: Literal["fp8_mixed"] = "fp8_mixed"
@json_schema_type @json_schema_type
@ -75,7 +124,7 @@ class Int4QuantizationConfig(BaseModel):
:param scheme: Quantization scheme to use. Defaults to "int4_weight_int8_dynamic_activation"
"""
type: Literal["int4"] = "int4"
type: Literal["int4_mixed"] = "int4_mixed"
scheme: Optional[str] = "int4_weight_int8_dynamic_activation"
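
For context, here is a minimal sketch (not part of this diff) of how the sampling types that now live in the inference API compose; it assumes `llama_stack.apis.inference` re-exports these names as before:

```
from llama_stack.apis.inference import SamplingParams, TopPSamplingStrategy

# "type" is the pydantic discriminator, so serialized params round-trip into
# the right strategy class.
params = SamplingParams(
    strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.9),
    max_tokens=256,
    stop=["</answer>"],
)
assert SamplingParams.model_validate(params.model_dump()).strategy.type == "top_p"
```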


@ -29,8 +29,8 @@ from rich.progress import (
from termcolor import cprint
from llama_stack.cli.subcommand import Subcommand
from llama_stack.models.llama.datatypes import Model
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
from llama_stack.models.llama.sku_types import Model
class Download(Subcommand):


@ -63,17 +63,6 @@ class ModelDescribe(Subcommand):
("Model params.json", json.dumps(model.arch_args, indent=4)), ("Model params.json", json.dumps(model.arch_args, indent=4)),
] ]
if model.recommended_sampling_params is not None:
sampling_params = model.recommended_sampling_params.model_dump()
for k in ("max_tokens", "repetition_penalty"):
del sampling_params[k]
rows.append(
(
"Recommended sampling params",
json.dumps(sampling_params, indent=4),
)
)
print_table(
rows,
headers,


@ -11,7 +11,7 @@ from pathlib import Path
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
from llama_stack.models.llama.sku_types import CoreModelId, ModelFamily, is_multimodal, model_family
ROOT_DIR = Path(__file__).parent.parent.parent


@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from typing import Any, Dict
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat, SamplingParams
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
from llama_stack.models.llama.sku_types import CheckpointQuantizationFormat
class PromptGuardModel(BaseModel):
@ -23,7 +23,6 @@ class PromptGuardModel(BaseModel):
is_instruct_model: bool = False
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
arch_args: Dict[str, Any] = Field(default_factory=dict)
recommended_sampling_params: Optional[SamplingParams] = None
def descriptor(self) -> str:
return self.model_id


@ -0,0 +1,164 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import concurrent.futures
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank, get_model_parallel_world_size
def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[int]:
"""Map a new MP rank to a list of old MP ranks given a change in MP size."""
if new_mp_size % old_mp_size == 0:
# Read old MP shard and split it into smaller ones
return [new_mp_rank * old_mp_size // new_mp_size]
elif old_mp_size % new_mp_size == 0:
# Merge old MP shards into a single one
mp_factor = old_mp_size // new_mp_size
return list(range(new_mp_rank * mp_factor, (new_mp_rank + 1) * mp_factor))
else:
raise ValueError(
f"Either old MP size or new MP size should be a multiple of the other: "
f"{old_mp_size} % {new_mp_size} != 0 and {new_mp_size} % {old_mp_size} != 0"
)
def maybe_reshard_state_dict(
ckpt_paths: List[Path],
n_kv_heads: int,
moe_num_experts: Optional[int] = None,
map_location: Union[str, torch.device] = "cpu",
mmap: bool = True,
) -> Dict[str, torch.Tensor]:
if str(map_location) == "cpu":
torch.set_default_tensor_type(torch.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
ckpt_paths = np.array(sorted(ckpt_paths))
new_mp_size, new_mp_rank = get_model_parallel_world_size(), get_model_parallel_rank()
old_mp_size = len(ckpt_paths)
old_mp_ranks = map_mp_rank(old_mp_size, new_mp_size, new_mp_rank)
print(f"Loading checkpoint shards:\n{str(ckpt_paths[old_mp_ranks])}") # type: ignore
paths = ckpt_paths[old_mp_ranks] # type: ignore
state_dicts = [torch.load(str(p), map_location=map_location, mmap=mmap) for p in paths]
if new_mp_size == old_mp_size:
return state_dicts[0] # type: ignore
if moe_num_experts is not None:
state_dicts = [convert_moe_weights(d, moe_num_experts) for d in state_dicts]
print(f"Resharding {len(state_dicts)} state dicts from MP size {old_mp_size} to MP size {new_mp_size}")
return reshard_mp(
state_dicts,
size=max(new_mp_size // old_mp_size, 1),
rank=new_mp_rank % max(new_mp_size // old_mp_size, 1),
repeat_qk_qv=max(new_mp_size // n_kv_heads, 1),
)
_WEIGHT_ROW_KEY = {
"feed_forward.w2",
"feed_forward.mlp.fc2",
"attention.wo",
"feed_forward.mlp.fc2_weight",
"feed_forward.w_out_shared_DF.weight",
"attn.wo.weight",
"mlp.c_proj.weight",
}
_MOE_WEIGHT_ROW_KEY = {"feed_forward.experts.(moe_w_in_eD_F|moe_w_swiglu_eD_F)"}
_WEIGHT_COLUMN_KEY = {
"output",
"feed_forward.(w1|w3)",
"feed_forward.mlp.(fc1|fc3)",
"feed_forward.mlp.fc1_weight",
"attention.(wk|wq|wv|wqkv).weight",
"feed_forward.(w_in_shared_FD|w_swiglu_FD)",
"attn.(wk|wq|wv).weight",
"attn.(wk|wq|wv).bias",
"mlp.c_fc.weight",
"mlp.c_fc.bias",
"conv1._linear.weight",
"tok_embeddings.weight",
"vision_projection.weight",
}
_MOE_WEIGHT_COLUMN_KEY = {"feed_forward.experts.moe_w_out_eF_D"}
def reshard_mp(
state_dicts: List[Dict[str, torch.Tensor]],
size: int,
rank: int,
repeat_qk_qv: int = 1,
) -> Dict[str, torch.Tensor]:
"""
Reshard a list of state dicts into a single state dict given a change in MP size.
If the list has more than one state dict, we concatenate the values of the same
key across all state dicts. Otherwise, we just slice it for the current MP rank.
"""
def concat_or_chunk(tensors: List[torch.Tensor], dim: int) -> torch.Tensor:
if len(tensors) > 1:
return torch.cat(tensors, dim=dim)
return tensors[0].chunk(size, dim=dim)[rank].clone()
def process_key(key: str) -> torch.Tensor:
if row_regex.search(key):
return concat_or_chunk([s[key] for s in state_dicts], dim=-1)
elif column_regex.search(key):
if "w13" in key or "fc1_weight" in key:
dims = state_dicts[0][key].size()
values = [s[key].view(2, dims[0] // 2, *dims[1:]) for s in state_dicts]
return concat_or_chunk(values, dim=1).flatten(0, 1)
elif "qkv" in key:
q_dim = state_dicts[0][key.replace("qkv", "o")].size(1)
kv_dim = (state_dicts[0][key].size(0) - q_dim) // 2
values = [s[key].split((q_dim, kv_dim, kv_dim)) for s in state_dicts]
return torch.cat([concat_or_chunk(x, dim=0) for x in zip(*values, strict=False)]) # type: ignore
elif "wk.weight" in key or "wv.weight" in key:
# Support MP > #kv_head
return concat_or_chunk([s[key].repeat(repeat_qk_qv, 1) for s in state_dicts], dim=0)
elif key == "output.bias" or key == "fc.weight":
return concat_or_chunk([s[key] for s in state_dicts], dim=0)
elif "w_" in key:
return concat_or_chunk([s[key] for s in state_dicts], dim=-2)
else:
return concat_or_chunk([s[key] for s in state_dicts], dim=0)
else:
return state_dicts[0][key].clone()
row_keys = _WEIGHT_ROW_KEY | _MOE_WEIGHT_ROW_KEY
column_keys = _WEIGHT_COLUMN_KEY | _MOE_WEIGHT_COLUMN_KEY
column_regex = re.compile("|".join(column_keys))
row_regex = re.compile("|".join(row_keys))
output: Dict[str, torch.Tensor] = {}
with concurrent.futures.ThreadPoolExecutor() as executor:
# Note: only processes keys in the first state dict.
# Assumes keys are the same across all state dicts.
mappings = {executor.submit(process_key, key): key for key in state_dicts[0]}
for future in concurrent.futures.as_completed(mappings):
output[mappings[future]] = future.result()
return output
def convert_moe_weights(state_dict: Dict[str, Any], num_experts: int) -> Dict[str, Any]:
routed_keys = _MOE_WEIGHT_ROW_KEY | _MOE_WEIGHT_COLUMN_KEY
routed_regex = re.compile("|".join(routed_keys))
keys = list(state_dict.keys())
for key in keys:
if routed_regex.search(key):
state_dict[key] = state_dict.pop(key).unflatten(0, (num_experts, -1)).squeeze(dim=0)
return state_dict
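
A quick worked example of the rank mapping implemented above (illustration only; the module path is inferred from the `from ..checkpoint import maybe_reshard_state_dict` import used later in this PR):

```
from llama_stack.models.llama.checkpoint import map_mp_rank

# Shrinking MP 8 -> 2: each new rank merges four consecutive old shards.
assert map_mp_rank(old_mp_size=8, new_mp_size=2, new_mp_rank=1) == [4, 5, 6, 7]
# Growing MP 2 -> 8: each new rank reads a slice of a single old shard.
assert map_mp_rank(old_mp_size=2, new_mp_size=8, new_mp_rank=5) == [1]
```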


@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import base64
from enum import Enum
from io import BytesIO
@ -19,8 +12,6 @@ from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
from typing_extensions import Annotated
from llama_stack.schema_utils import json_schema_type, register_schema
# The goal is that these set of types are relevant for all Llama models.
# That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to
# the llama3 series of models.
@ -98,6 +89,29 @@ class StopReason(Enum):
out_of_tokens = "out_of_tokens"
class ToolParamDefinition(BaseModel):
param_type: str
description: Optional[str] = None
required: Optional[bool] = True
default: Optional[Any] = None
class ToolDefinition(BaseModel):
tool_name: Union[BuiltinTool, str]
description: Optional[str] = None
parameters: Optional[Dict[str, ToolParamDefinition]] = None
@field_validator("tool_name", mode="before")
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return BuiltinTool(v)
except ValueError:
return v
return v
class RawMediaItem(BaseModel):
type: Literal["image"] = "image"
data: bytes | BytesIO
@ -140,292 +154,25 @@ class RawMessage(BaseModel):
tool_calls: List[ToolCall] = Field(default_factory=list)
class GenerationResult(BaseModel):
token: int
text: str
logprobs: Optional[List[float]] = None
source: Literal["input"] | Literal["output"]
# index within the batch
batch_idx: int
# whether generation for this item is already finished. note that tokens can
# get returned even afterwards since other items in the batch can still be generating tokens
finished: bool
# because a batch is parallel processed, useful decoding for one item can correspond to processing
# pad tokens or tokens beyond EOS for other items. we could have decided to return None for this case
# but it's more convenient to return a list of GenerationResult and filter out the ignored tokens
ignore_token: bool
class QuantizationMode(str, Enum):
none = "none"
fp8_mixed = "fp8_mixed"
int4_mixed = "int4_mixed"
register_schema(ToolCall)
@json_schema_type
class ToolParamDefinition(BaseModel):
param_type: str
description: Optional[str] = None
required: Optional[bool] = True
default: Optional[Any] = None
@json_schema_type
class ToolDefinition(BaseModel):
tool_name: Union[BuiltinTool, str]
description: Optional[str] = None
parameters: Optional[Dict[str, ToolParamDefinition]] = None
@field_validator("tool_name", mode="before")
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return BuiltinTool(v)
except ValueError:
return v
return v
@json_schema_type
class GreedySamplingStrategy(BaseModel):
type: Literal["greedy"] = "greedy"
@json_schema_type
class TopPSamplingStrategy(BaseModel):
type: Literal["top_p"] = "top_p"
temperature: Optional[float] = Field(..., gt=0.0)
top_p: Optional[float] = 0.95
@json_schema_type
class TopKSamplingStrategy(BaseModel):
type: Literal["top_k"] = "top_k"
top_k: int = Field(..., ge=1)
SamplingStrategy = Annotated[
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
Field(discriminator="type"),
]
register_schema(SamplingStrategy, name="SamplingStrategy")
@json_schema_type
class SamplingParams(BaseModel):
"""Sampling parameters.
:param strategy: The sampling strategy.
:param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
your prompt plus max_tokens cannot exceed the model's context length.
:param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
:param stop: Up to 4 sequences where the API will stop generating further tokens.
The returned text will not contain the stop sequence.
"""
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
max_tokens: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0
stop: Optional[List[str]] = None
class CheckpointQuantizationFormat(Enum):
# default format
bf16 = "bf16"
# used for enabling fp8_rowwise inference, some weights are bf16
fp8_mixed = "fp8-mixed"
int8 = "int8"
int4 = "int4"
class ModelFamily(Enum):
llama2 = "llama2"
llama3 = "llama3"
llama3_1 = "llama3_1"
llama3_2 = "llama3_2"
llama3_3 = "llama3_3"
llama4 = "llama4"
safety = "safety"
class CoreModelId(Enum):
"""Each of these models is a unique "SKU". These root models can be served in various garbs (especially by quantizing them)"""
# Llama 2 family
llama2_7b = "Llama-2-7b"
llama2_13b = "Llama-2-13b"
llama2_70b = "Llama-2-70b"
llama2_7b_chat = "Llama-2-7b-chat"
llama2_13b_chat = "Llama-2-13b-chat"
llama2_70b_chat = "Llama-2-70b-chat"
# Llama 3 family
llama3_8b = "Llama-3-8B"
llama3_70b = "Llama-3-70B"
llama3_8b_instruct = "Llama-3-8B-Instruct"
llama3_70b_instruct = "Llama-3-70B-Instruct"
# Llama 3.1 family
llama3_1_8b = "Llama3.1-8B"
llama3_1_70b = "Llama3.1-70B"
llama3_1_405b = "Llama3.1-405B"
llama3_1_8b_instruct = "Llama3.1-8B-Instruct"
llama3_1_70b_instruct = "Llama3.1-70B-Instruct"
llama3_1_405b_instruct = "Llama3.1-405B-Instruct"
# Llama 3.2 family
llama3_2_1b = "Llama3.2-1B"
llama3_2_3b = "Llama3.2-3B"
llama3_2_1b_instruct = "Llama3.2-1B-Instruct"
llama3_2_3b_instruct = "Llama3.2-3B-Instruct"
llama3_2_11b_vision = "Llama3.2-11B-Vision"
llama3_2_90b_vision = "Llama3.2-90B-Vision"
llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct"
llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct"
# Llama 3.3 family
llama3_3_70b_instruct = "Llama3.3-70B-Instruct"
# Llama 4 family
llama4_scout_17b_16e = "Llama-4-Scout-17B-16E"
llama4_scout_17b_16e_instruct = "Llama-4-Scout-17B-16E-Instruct"
llama4_maverick_17b_128e = "Llama-4-Maverick-17B-128E"
llama4_maverick_17b_128e_instruct = "Llama-4-Maverick-17B-128E-Instruct"
# Safety models
llama_guard_3_8b = "Llama-Guard-3-8B"
llama_guard_2_8b = "Llama-Guard-2-8B"
llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision"
llama_guard_3_1b = "Llama-Guard-3-1B"
def is_multimodal(model_id) -> bool:
if model_id in [
CoreModelId.llama3_2_11b_vision,
CoreModelId.llama3_2_90b_vision,
CoreModelId.llama3_2_11b_vision_instruct,
CoreModelId.llama3_2_90b_vision_instruct,
]:
return True
else:
return False
def model_family(model_id) -> ModelFamily:
if model_id in [
CoreModelId.llama2_7b,
CoreModelId.llama2_13b,
CoreModelId.llama2_70b,
CoreModelId.llama2_7b_chat,
CoreModelId.llama2_13b_chat,
CoreModelId.llama2_70b_chat,
]:
return ModelFamily.llama2
elif model_id in [
CoreModelId.llama3_8b,
CoreModelId.llama3_70b,
CoreModelId.llama3_8b_instruct,
CoreModelId.llama3_70b_instruct,
]:
return ModelFamily.llama3
elif model_id in [
CoreModelId.llama3_1_8b,
CoreModelId.llama3_1_70b,
CoreModelId.llama3_1_405b,
CoreModelId.llama3_1_8b_instruct,
CoreModelId.llama3_1_70b_instruct,
CoreModelId.llama3_1_405b_instruct,
]:
return ModelFamily.llama3_1
elif model_id in [
CoreModelId.llama3_2_1b,
CoreModelId.llama3_2_3b,
CoreModelId.llama3_2_1b_instruct,
CoreModelId.llama3_2_3b_instruct,
CoreModelId.llama3_2_11b_vision,
CoreModelId.llama3_2_90b_vision,
CoreModelId.llama3_2_11b_vision_instruct,
CoreModelId.llama3_2_90b_vision_instruct,
]:
return ModelFamily.llama3_2
elif model_id in [
CoreModelId.llama3_3_70b_instruct,
]:
return ModelFamily.llama3_3
elif model_id in [
CoreModelId.llama4_scout_17b_16e,
CoreModelId.llama4_scout_17b_16e_instruct,
CoreModelId.llama4_maverick_17b_128e,
CoreModelId.llama4_maverick_17b_128e_instruct,
]:
return ModelFamily.llama4
elif model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_2_8b,
CoreModelId.llama_guard_3_11b_vision,
CoreModelId.llama_guard_3_1b,
]:
return ModelFamily.safety
else:
raise ValueError(f"Unknown model family for {model_id}")
class Model(BaseModel):
core_model_id: CoreModelId
description: str
huggingface_repo: Optional[str] = None
recommended_sampling_params: Optional[SamplingParams] = None
arch_args: Dict[str, Any]
variant: str = ""
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
pth_file_count: int
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
# silence pydantic until we remove the `model_` fields
model_config = ConfigDict(protected_namespaces=())
@property
def model_family(self) -> ModelFamily:
return model_family(self.core_model_id)
# The SKU is uniquely identified by (model_id, variant) combo
def descriptor(self, shorten_default_variant: bool = True) -> str:
if not self.variant:
return self.core_model_id.value
return f"{self.core_model_id.value}:{self.variant}"
@property
def is_instruct_model(self) -> bool:
return "instruct" in self.id.name
# Featured models are shown in the non-exhaustive model list
@property
def is_featured(self) -> bool:
return self.model_family in [
ModelFamily.llama3_1,
ModelFamily.llama3_2,
ModelFamily.llama3_3,
ModelFamily.llama4,
ModelFamily.safety,
]
@property
def max_seq_length(self) -> int:
if self.model_family == ModelFamily.llama2:
return 4096
elif self.core_model_id == CoreModelId.llama_guard_2_8b:
return 4096
elif self.model_family == ModelFamily.llama3:
return 8192
elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]:
return 131072
elif self.model_family == ModelFamily.llama3_2:
if self.quantization_format == CheckpointQuantizationFormat.int4:
return 8192
return 131072
elif self.model_family == ModelFamily.llama4:
if self.core_model_id in {
CoreModelId.llama4_scout_17b_16e,
CoreModelId.llama4_maverick_17b_128e,
}:
return 262144
if self.core_model_id == CoreModelId.llama4_scout_17b_16e_instruct:
return 10485760
if self.core_model_id == CoreModelId.llama4_maverick_17b_128e_instruct:
return 1048576
elif self.core_model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_3_11b_vision,
CoreModelId.llama_guard_3_1b,
]:
return 131072
else:
raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")


@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from dataclasses import dataclass
from enum import Enum
from typing import Optional


@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import io
import json
import uuid
@ -19,7 +12,7 @@ from typing import Dict, List, Optional, Tuple
from PIL import Image as PIL_Image
from llama_stack.models.llama.datatypes import (
from ..datatypes import (
BuiltinTool,
RawContent,
RawMediaItem,
@ -30,7 +23,6 @@ from llama_stack.models.llama.datatypes import (
ToolCall,
ToolPromptFormat,
)
from .tokenizer import Tokenizer
from .tool_utils import ToolUtils


@ -0,0 +1,367 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import json
import os
import sys
import time
from pathlib import Path
from typing import Callable, Generator, List, Optional
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
initialize_model_parallel,
model_parallel_is_initialized,
)
from termcolor import cprint
from ..checkpoint import maybe_reshard_state_dict
from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat
from .args import ModelArgs
from .chat_format import ChatFormat, LLMInput
from .model import Transformer
from .multimodal.model import CrossAttentionTransformer
from .tokenizer import Tokenizer
class Llama3:
@staticmethod
def build(
ckpt_dir: str,
max_seq_len: int,
max_batch_size: int,
world_size: Optional[int] = None,
quantization_mode: Optional[QuantizationMode] = None,
seed: int = 1,
device: str = "cuda",
):
device = torch.device(device)
if (
device.type == "cuda"
and not torch.cuda.is_available()
or device.type == "xpu"
and not torch.xpu.is_available()
):
raise RuntimeError(f"PyTorch backend for {device.type} device type is not available")
if not torch.distributed.is_initialized():
if device.type == "cuda":
torch.distributed.init_process_group("nccl")
else:
torch.distributed.init_process_group("gloo")
if not model_parallel_is_initialized():
if world_size is None:
world_size = int(os.environ.get("WORLD_SIZE", 1))
initialize_model_parallel(world_size)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
if device.type == "cuda":
torch.cuda.set_device(local_rank)
elif device.type == "xpu":
torch.xpu.set_device(local_rank)
torch.manual_seed(seed)
if local_rank > 0:
sys.stdout = open(os.devnull, "w")
start_time = time.time()
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
model_args: ModelArgs = ModelArgs(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
**params,
)
tokenizer = Tokenizer.get_instance()
state_dict = maybe_reshard_state_dict(
ckpt_paths,
n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
)
assert model_args.vocab_size == tokenizer.n_words
def build_model():
if model_args.vision_chunk_size > 0:
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, device=device, dtype=torch.get_default_dtype())
else:
model = Transformer(model_args)
return model
if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
from .quantization.loader import convert_to_quantized_model
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = build_model()
print("Loading state dict...")
model.load_state_dict(state_dict, strict=False)
print("Done...")
model = convert_to_quantized_model(model, ckpt_dir, quantization_mode, device=device)
torch.set_default_device(device)
else:
print(f"Setting default device to {device}")
torch.set_default_device(device)
if device.type == "cuda":
if torch.cuda.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.half)
elif device.type == "xpu":
if torch.xpu.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.half)
model = build_model()
print("Loading state dict...")
model.load_state_dict(state_dict, strict=True)
model.to(device)
print("Done...")
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama3(model, tokenizer, model_args)
def __init__(self, model: Transformer | CrossAttentionTransformer, tokenizer: Tokenizer, args: ModelArgs):
self.args = args
self.model = model
self.tokenizer = tokenizer
self.formatter = ChatFormat(tokenizer)
@torch.inference_mode()
def generate(
self,
model_inputs: List[LLMInput],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
print_model_input: bool = False,
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> Generator[List[GenerationResult], None, None]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
params = self.model.params
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
if print_model_input:
for inp in model_inputs:
tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
cprint(
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
"red",
)
prompt_tokens = [inp.tokens for inp in model_inputs]
bsz = len(model_inputs)
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)
max_prompt_len = max(len(t) for t in prompt_tokens)
if max_prompt_len >= params.max_seq_len:
cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red")
return
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
is_vision = not isinstance(self.model, Transformer)
if is_vision:
images = [inp.vision.images if inp.vision is not None else [] for inp in model_inputs]
mask = [inp.vision.mask if inp.vision is not None else [] for inp in model_inputs]
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
batch_images=images,
batch_masks=mask,
total_len=total_len,
device=tokens.device,
)
eos_reached = torch.tensor([False] * bsz)
input_text_mask = tokens != pad_id
if echo:
for i in range(max_prompt_len):
results = []
for j, t in enumerate(tokens[:, i]):
results.append(
GenerationResult(
token=t.item(),
text=self.tokenizer.decode([t.item()]),
source="input",
logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
batch_idx=j,
finished=False,
ignore_token=t.item() == pad_id,
)
)
yield results
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
prev_pos = 0
for cur_pos in range(min_prompt_len, total_len):
if is_vision:
position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
text_only_inference = all(inp.vision is None for inp in model_inputs)
logits = self.model.forward(
position_ids,
tokens,
cross_attention_masks,
full_text_row_masked_out_mask,
xattn_caches,
text_only_inference,
)
else:
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if logits_processor is not None:
logits = logits_processor(tokens[:, :cur_pos], logits)
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
else:
next_token = torch.argmax(logits[:, -1], dim=-1)
next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
target = tokens[:, prev_pos + 1 : cur_pos + 1]
if is_vision:
# the logits space (num_classes) is designed to never contain a media_token
# however our input token stream does contain them. we need to nuke them here
# or else the CUDA kernels will crash with an illegal memory access
vision_tokens = [self.tokenizer.special_tokens["<|image|>"], 128256]
masks = [target.eq(t) for t in vision_tokens]
if len(masks) > 1:
mask = torch.logical_or(*masks)
else:
mask = masks[0]
target[mask] = 0
if logprobs:
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
input=logits.transpose(1, 2),
target=target,
reduction="none",
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
results = []
for idx, t in enumerate(next_token):
results.append(
GenerationResult(
token=t.item(),
text=self.tokenizer.decode([t.item()]),
source="output",
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
batch_idx=idx,
finished=eos_reached[idx],
ignore_token=cur_pos < len(prompt_tokens[idx]),
)
)
yield results
prev_pos = cur_pos
if all(eos_reached):
break
def completion(
self,
contents: List[RawContent],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
) -> Generator[List[GenerationResult], None, None]:
model_inputs = [self.formatter.encode_content(c) for c in contents]
for result in self.generate(
model_inputs=model_inputs,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
yield result
if all(r.finished for r in result):
break
def chat_completion(
self,
messages_batch: List[List[RawMessage]],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
echo: bool = False,
) -> Generator[List[GenerationResult], None, None]:
model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
for result in self.generate(
model_inputs=model_inputs,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
yield result
if all(r.finished for r in result):
break
def sample_top_p(probs, p):
"""
Perform top-p (nucleus) sampling on a probability distribution.
Args:
probs (torch.Tensor): Probability distribution tensor.
p (float): Probability threshold for top-p sampling.
Returns:
torch.Tensor: Sampled token indices.
Note:
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
"""
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token
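
A hedged sketch (not part of this diff) of driving the relocated generator end to end; the module path and checkpoint directory are assumptions, and the per-token `GenerationResult` batches are filtered via the `ignore_token` flag described above:

```
from llama_stack.models.llama.datatypes import QuantizationMode, RawMessage
from llama_stack.models.llama.llama3.generation import Llama3

llama = Llama3.build(
    ckpt_dir="<CHECKPOINT_DIR>",  # placeholder
    max_seq_len=512,
    max_batch_size=2,
    quantization_mode=QuantizationMode.fp8_mixed,  # or None for plain bf16
)
dialogs = [
    [RawMessage(role="user", content="Hello")],
    [RawMessage(role="user", content="Goodbye")],
]
texts = ["" for _ in dialogs]
for batch in llama.chat_completion(dialogs):
    for r in batch:
        if not r.ignore_token:  # skip pad/post-EOS tokens of already-finished items
            texts[r.batch_idx] += r.text
```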


@ -16,7 +16,7 @@ from typing import List, Optional
from termcolor import colored
from llama_stack.models.llama.datatypes import (
from ..datatypes import (
BuiltinTool,
RawMessage,
StopReason,
@ -24,7 +24,6 @@ from llama_stack.models.llama.datatypes import (
ToolDefinition,
ToolPromptFormat,
)
from . import template_data
from .chat_format import ChatFormat
from .prompt_templates import (


@ -4,16 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import math
from typing import Optional, Tuple
@ -29,6 +19,10 @@ from torch import nn
from .args import ModelArgs
# **NOTE**: This code is not runnable without installing `torch` and `fairscale`
# dependencies. These dependencies are not part of the default dependencies
# (requirements.txt) of the `llama-models` package.
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
@ -111,9 +105,9 @@ class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
model_parallel_size = fs_init.get_model_parallel_world_size()
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
world_size = fs_init.get_model_parallel_world_size()
self.n_local_heads = args.n_heads // world_size
self.n_local_kv_heads = self.n_kv_heads // world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads


@ -4,16 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import logging
import math
from functools import partial
@ -180,14 +170,14 @@ class ImageAttention(nn.Module):
n_heads,
):
super().__init__()
model_parallel_size = fs_init.get_model_parallel_world_size()
world_size = fs_init.get_model_parallel_world_size()
qkvo_replication = 1
if model_parallel_size > 16:
qkvo_replication = model_parallel_size // 8
if world_size > 16:
qkvo_replication = world_size // 8
self.n_kv_heads = n_heads
self.n_local_heads = n_heads * qkvo_replication // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads * qkvo_replication // model_parallel_size
self.n_local_heads = n_heads * qkvo_replication // world_size
self.n_local_kv_heads = self.n_kv_heads * qkvo_replication // world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = dim // n_heads
@ -536,16 +526,16 @@ class Attention(nn.Module):
cache_v (torch.Tensor): Cached values for attention.
"""
super().__init__()
model_parallel_size = fs_init.get_model_parallel_world_size()
world_size = fs_init.get_model_parallel_world_size()
replication_factor = 1
if model_parallel_size > 8:
replication_factor = model_parallel_size // MP_SCALE
if world_size > 8:
replication_factor = world_size // MP_SCALE
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
self.n_kv_heads *= replication_factor
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
self.n_local_heads = args.n_heads // world_size
self.n_local_kv_heads = self.n_kv_heads // world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.max_seq_len = args.max_seq_len
@ -587,13 +577,11 @@ class Attention(nn.Module):
self.n_local_kv_heads,
self.head_dim,
)
device = next(self.parameters()).device
self.register_buffer(
"key_cache",
torch.zeros(
cache_shape,
dtype=dtype,
device=device,
),
persistent=False,
)
@ -602,7 +590,6 @@ class Attention(nn.Module):
torch.zeros(
cache_shape,
dtype=dtype,
device=device,
),
persistent=False,
)
@ -614,6 +601,9 @@ class Attention(nn.Module):
freqs_cis: torch.Tensor,
position_ids: torch.LongTensor,
):
self.key_cache = self.key_cache.to(x.device)
self.value_cache = self.value_cache.to(x.device)
xq, xk, xv = [F.linear(x, w) for w in [self.wq.weight, self.wk.weight, self.wv.weight]]
bs, slen, _ = xq.shape
@ -832,10 +822,10 @@ class CrossAttention(torch.nn.Module):
norm_eps: float,
):
super().__init__()
self.model_parallel_size = fs_init.get_model_parallel_world_size()
self.world_size = fs_init.get_model_parallel_world_size()
replication_factor = 1
if self.model_parallel_size > 8:
replication_factor = self.model_parallel_size // MP_SCALE
if self.world_size > 8:
replication_factor = self.world_size // MP_SCALE
n_kv_heads *= replication_factor
assert n_heads % n_kv_heads == 0
@ -889,10 +879,10 @@ class CrossAttention(torch.nn.Module):
# trunk LLM (i.e., group query attention) -- @dubeya
# local heads
assert self.n_heads % self.n_kv_heads == 0
assert self.n_heads % self.model_parallel_size == 0
assert self.n_kv_heads % self.model_parallel_size == 0
self.n_local_heads = self.n_heads // self.model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size
assert self.n_heads % self.world_size == 0
assert self.n_kv_heads % self.world_size == 0
self.n_local_heads = self.n_heads // self.world_size
self.n_local_kv_heads = self.n_kv_heads // self.world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
def _compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor:
@ -1041,7 +1031,7 @@ class CrossAttentionTransformerVision(torch.nn.Module):
self.image_res = args.vision_chunk_size
self.max_num_chunks = args.vision_max_num_chunks
if return_intermediate is not None:
return_intermediate = [int(level) for level in return_intermediate.split(",")]
return_intermediate = [int(layer) for layer in return_intermediate.split(",")]
self.vision_input_dim = (len(return_intermediate) + 1) * self.vision_input_dim
self.patch_size = 14
self.vision_encoder = VisionEncoder(
@ -1076,15 +1066,15 @@ class CrossAttentionTransformerText(torch.nn.Module):
def __init__(self, args: ModelArgs) -> None:
super().__init__()
self.model_parallel_size = fs_init.get_model_parallel_world_size()
self.world_size = fs_init.get_model_parallel_world_size()
assert args.vocab_size > 0
self.vocab_size = args.vocab_size
self.n_layers = args.n_layers
self.dim = args.dim
self.head_dim = args.dim // args.n_heads
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size
assert self.vocab_size % self.model_parallel_size == 0
self.n_local_kv_heads = self.n_kv_heads // self.world_size
assert self.vocab_size % self.world_size == 0
self.tok_embeddings = VocabParallelEmbedding(args.vocab_size, args.dim, init_method=lambda x: x)
self.pos_embeddings = None
# final norm layer (not necessary for post-norm)
@ -1184,6 +1174,8 @@ class CrossAttentionTransformerText(torch.nn.Module):
text_only_inference: bool = False,
):
assert self.cache_is_setup, "Please set up cache before calling forward"
self.mask_cache = self.mask_cache.to(h.device)
self.freqs_cis = self.freqs_cis.to(h.device)
mask = self.mask_cache.index_select(2, position_ids)
freqs_cis = self.freqs_cis.index_select(0, position_ids)
@ -1212,9 +1204,8 @@ class CrossAttentionTransformerText(torch.nn.Module):
output = gather_from_tensor_model_parallel_region(output)
return output.float()
def setup_cache(self, max_batch_size: int, dtype=torch.bfloat16):
def setup_cache(self, max_batch_size: int, device: torch.device, dtype=torch.bfloat16):
# Set up the text kv caches
device = next(self.parameters()).device
ones = torch.ones(
(self.max_seq_len, self.max_seq_len),
dtype=torch.bool,
@ -1265,7 +1256,7 @@ class CrossAttentionTransformerText(torch.nn.Module):
return (
cross_attention_masks.to(device=text_device, dtype=text_dtype),
full_text_row_masked_out_mask,
full_text_row_masked_out_mask.to(device=text_device),
)
@ -1284,14 +1275,15 @@ class CrossAttentionTransformer(torch.nn.Module):
max_num_chunks=args.vision_max_num_chunks,
)
def setup_cache(self, max_batch_size: int, dtype: torch.dtype):
self.text_model.setup_cache(max_batch_size, dtype)
def setup_cache(self, max_batch_size: int, device: torch.device, dtype: torch.dtype):
self.text_model.setup_cache(max_batch_size, device, dtype)
def compute_vision_tokens_masks(
self,
batch_images: List[List[PIL_Image.Image]],
batch_masks: List[List[List[int]]],
total_len: int,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
skip_vision_encoder = False
@ -1318,6 +1310,7 @@ class CrossAttentionTransformer(torch.nn.Module):
image_res=self.params.vision_chunk_size,
max_num_images=max_num_images,
)
stacked_images = stacked_images.to(device=device)
if skip_vision_encoder:
vision_tokens = torch.zeros(
@ -1330,7 +1323,7 @@ class CrossAttentionTransformer(torch.nn.Module):
),
)
else:
vision_tokens = self.vision_model(stacked_images, aspect_ratios)
vision_tokens = self.vision_model(stacked_images, aspect_ratios).to(device=device)
bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape)
xattn_caches = torch.stack(


@ -15,7 +15,7 @@ import textwrap
from datetime import datetime
from typing import Any, List, Optional
from llama_stack.models.llama.datatypes import (
from llama_stack.apis.inference import (
BuiltinTool,
ToolDefinition,
ToolParamDefinition,


@ -3,5 +3,3 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .meta_reference import get_distribution_template # noqa: F401


@ -4,9 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
# type: ignore
import os
from typing import Any, Dict, List, Optional, cast
@ -18,22 +15,15 @@ from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_regi
from torch import Tensor, nn
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
from llama_stack.apis.inference import QuantizationType
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.inline.inference.meta_reference.quantize_impls import (
from ...datatypes import QuantizationMode
from ...quantize_impls import (
Fp8ScaledWeights,
ffn_swiglu,
load_fp8,
quantize_fp8,
)
from ...config import MetaReferenceQuantizedInferenceConfig
from ..args import ModelArgs
from ..model import Transformer, TransformerBlock
from ..multimodal.model import CrossAttentionTransformer
log = get_logger(__name__, category="quantization")
def swiglu_wrapper(
@ -44,30 +34,34 @@ def swiglu_wrapper(
return reduce_from_model_parallel_region(out)
def convert_to_quantized_model(
model: Transformer | CrossAttentionTransformer,
checkpoint_dir: str,
quantization_mode: Optional[str] = None,
fp8_activation_scale_ub: Optional[float] = 1200.0,
device: Optional[torch.device] = None,
) -> Transformer | CrossAttentionTransformer:
if quantization_mode == QuantizationMode.fp8_mixed:
return convert_to_fp8_quantized_model(model, checkpoint_dir, fp8_activation_scale_ub, device)
elif quantization_mode == QuantizationMode.int4_mixed:
return convert_to_int4_quantized_model(model, checkpoint_dir, device)
else:
raise ValueError(f"Unsupported quantization mode: {quantization_mode}")
def convert_to_fp8_quantized_model(
model: Transformer,
config: MetaReferenceQuantizedInferenceConfig,
checkpoint_dir: str,
fp8_activation_scale_ub: Optional[float] = 1200.0,
device: Optional[torch.device] = None,
) -> Transformer:
if config.quantization.type == QuantizationType.bf16.value:
return model
elif config.quantization.type != QuantizationType.fp8.value:
raise ValueError("Only FP8 quantization is supported")
assert config.model is not None, "Model must be specified for quantized inference"
llama_model = resolve_model(config.model)
assert llama_model is not None, f"Model {config.model} not found"
# Move weights to GPU with quantization
if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
log.info("Loading fp8 scales...")
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
assert os.path.isfile(fp8_scales_path), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
if os.path.isfile(fp8_scales_path):
print("Loading fp8 scales...")
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
for block in model.layers:
for _, block in model.named_modules():
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
@ -81,8 +75,8 @@ def convert_to_fp8_quantized_model(
fp8_activation_scale_ub,
)
else:
log.info("Quantizing fp8 weights from bf16...")
for block in model.layers:
print("Quantizing fp8 weights from bf16...")
for _, block in model.named_modules():
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
@ -92,12 +86,12 @@ def convert_to_fp8_quantized_model(
param.weight = quantize_fp8(
param.weight,
fp8_activation_scale_ub,
output_device=torch.device("cuda"),
output_device=device,
)
for _, parameter in model.named_parameters():
if not isinstance(parameter, Fp8ScaledWeights):
parameter.data = parameter.to(device="cuda")
parameter.data = parameter.to(device=device)
return model
@ -290,12 +284,12 @@ def _prepare_model_int4_weight_int8_dynamic_activation(
def convert_to_int4_quantized_model(
model: Transformer,
model_args: ModelArgs,
config: MetaReferenceQuantizedInferenceConfig,
) -> Transformer:
model: Transformer | CrossAttentionTransformer,
checkpoint_dir: str,
device: Optional[torch.device] = None,
) -> Transformer | CrossAttentionTransformer:
"""Convert the model to int4 quantized model."""
model_args = model.params
assert model_args.quantization_args is not None, "Quantization args must be specified."
quantization_args = model_args.quantization_args
if quantization_args.scheme is None:
@ -319,5 +313,4 @@ def convert_to_int4_quantized_model(
lora_scale = model_args.lora_args.scale lora_scale = model_args.lora_args.scale
_prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale) _prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") return cast(Transformer | CrossAttentionTransformer, model.to(device=device))
return cast(Transformer, model.to(device))
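The int4 path keys off `quantization_args` (notably the group size) read from the model params. A simplified sketch of symmetric per-group int4 weight quantization, shown only for intuition; the actual scheme also quantizes activations to int8 dynamically and handles LoRA adapters:

```
import torch


def quantize_int4_groupwise(w: torch.Tensor, group_size: int = 128):
    """Illustrative symmetric per-group int4 quantization: values land in [-8, 7]."""
    out_features, in_features = w.shape
    assert in_features % group_size == 0
    g = w.reshape(out_features, in_features // group_size, group_size)
    scale = g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 7.0
    q = torch.clamp(torch.round(g / scale), -8, 7).to(torch.int8)  # int4 values stored in int8
    return q.reshape(out_features, in_features), scale.squeeze(-1)


w = torch.randn(16, 256)
q, scales = quantize_int4_groupwise(w, group_size=128)
assert q.min() >= -8 and q.max() <= 7
```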


@ -12,8 +12,7 @@
# the top-level of this source tree.
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from ..datatypes import BuiltinTool, StopReason, ToolCall
from .prompt_templates import (
BuiltinToolGenerator,
JsonCustomToolGenerator,


@ -4,16 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import os
from logging import getLogger
from pathlib import Path


@ -16,7 +16,8 @@ import re
from typing import Optional, Tuple
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
logger = get_logger(name=__name__, category="inference")


@ -3,10 +3,3 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.


@ -4,12 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import json
import textwrap


@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import textwrap
from pathlib import Path


@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from enum import Enum
from typing import Optional


@ -13,7 +13,7 @@ import torch
from PIL import Image as PIL_Image
# TODO: either fork these or move them to the common package
from llama_stack.models.llama.datatypes import (
from ..datatypes import (
BuiltinTool,
RawContent,
RawMediaItem,
@ -24,16 +24,10 @@ from llama_stack.models.llama.datatypes import (
ToolCall,
ToolPromptFormat,
)
from llama_stack.models.llama.llama3.tool_utils import ToolUtils
from llama_stack.providers.inline.inference.meta_reference.llama4.args import VisionArgs
from llama_stack.providers.inline.inference.meta_reference.llama4.datatypes import (
LLMInput,
)
from llama_stack.providers.inline.inference.meta_reference.llama4.preprocess import (
ResizeNormalizeImageTransform,
VariableSizeImageTransform,
)
from ..llama3.tool_utils import ToolUtils
from .args import VisionArgs
from .datatypes import LLMInput
from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform
from .tokenizer import Tokenizer
@ -54,7 +48,7 @@ class TransformedImage:
aspect_ratio: Tuple[int, int]
def convert_rgba_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
if image.mode == "RGBA":
image.load() # for png.split()
new_img = PIL_Image.new("RGB", image.size, bg)
@ -171,7 +165,7 @@ class ChatFormat:
bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
image = PIL_Image.open(bytes_io)
image = convert_rgba_to_rgb(image)
image = convert_image_to_rgb(image)
image_tiles, ar = self.dynamic_image_transform(image, max_num_chunks=self.max_num_chunks)
if image_tiles.shape[0] > 1:

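The renamed `convert_image_to_rgb` helper flattens transparent images before tiling. A standalone sketch of that RGBA-to-background flattening with PIL (the names here are illustrative, not the module's):

```
from PIL import Image as PIL_Image


def flatten_rgba(image: PIL_Image.Image, bg=(255, 255, 255)) -> PIL_Image.Image:
    # Composite the alpha channel onto a solid background so the RGB tensor
    # handed to the vision tiles carries no transparency artifacts.
    if image.mode == "RGBA":
        image.load()
        background = PIL_Image.new("RGB", image.size, bg)
        background.paste(image, mask=image.split()[3])  # band 3 = alpha
        return background
    return image.convert("RGB")


img = PIL_Image.new("RGBA", (4, 4), (255, 0, 0, 128))
assert flatten_rgba(img).mode == "RGB"
```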

@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from dataclasses import dataclass
from typing import List, Optional, Union


@ -10,40 +10,28 @@ import json
import os
import sys
import time
from enum import Enum
from pathlib import Path
from typing import Callable, Generator, List, Optional
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from termcolor import cprint
from llama_stack.models.llama.llama4.chat_format import (
ChatFormat,
RawContent,
RawMessage,
)
from llama_stack.models.llama.llama4.tokenizer import Tokenizer
from ..common import TokenResult
from ..checkpoint import maybe_reshard_state_dict
from ..datatypes import GenerationResult, QuantizationMode
from .args import ModelArgs
from .chat_format import ChatFormat, RawContent, RawMessage
from .datatypes import LLMInput, MaskedEmbedding, TransformerInput
from .model import Transformer
from .tokenizer import Tokenizer
torch.serialization.add_safe_globals([io.BytesIO, codecs.encode])
class QuantizationMode(str, Enum):
none = "none"
fp8_mixed = "fp8_mixed"
int4_mixed = "int4_mixed"
class Llama4:
@staticmethod
def build(
@ -51,7 +39,7 @@ class Llama4:
max_seq_len: int,
max_batch_size: int,
world_size: Optional[int] = None,
quantization_mode: Optional[str] = None,
quantization_mode: Optional[QuantizationMode] = None,
seed: int = 1,
):
if not torch.distributed.is_initialized():
@ -72,11 +60,9 @@ class Llama4:
start_time = time.time()
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert world_size == len(checkpoints), (
f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
)
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
@ -93,10 +79,11 @@ class Llama4:
assert model_args.vocab_size == tokenizer.n_words, f"{model_args.vocab_size=} vs. {tokenizer.n_words=} mismatch"
print("Model args:\n", model_args.model_dump_json(indent=2))
ckpt_path = checkpoints[get_model_parallel_rank()]
print(f"Loading checkpoint from {ckpt_dir}...")
with open(ckpt_path, "rb") as f:
checkpoint = torch.load(f, map_location="cpu", weights_only=True)
state_dict = maybe_reshard_state_dict(
ckpt_paths,
n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
moe_num_experts=model_args.moe_args.num_experts,
)
print("Loaded checkpoint")
if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
from .quantization.loader import convert_to_quantized_model
@ -104,9 +91,9 @@ class Llama4:
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = Transformer(model_args)
print("Loading state dict...")
model.load_state_dict(checkpoint, strict=False)
model.load_state_dict(state_dict, strict=False)
print("Done...")
model = convert_to_quantized_model(model, ckpt_dir)
model = convert_to_quantized_model(model, ckpt_dir, quantization_mode)
else:
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
@ -115,7 +102,7 @@ class Llama4:
model = Transformer(model_args)
print("Loading state dict...")
model.load_state_dict(checkpoint, strict=False)
model.load_state_dict(state_dict, strict=False)
print("Done...")
print(f"Loaded in {time.time() - start_time:.2f} seconds")
@ -130,7 +117,7 @@ class Llama4:
@torch.inference_mode()
def generate(
self,
llm_input: LLMInput,
llm_inputs: List[LLMInput],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
@ -138,22 +125,20 @@ class Llama4:
echo: bool = False,
print_model_input: bool = False,
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> Generator:
) -> Generator[List[GenerationResult], None, None]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.args.max_seq_len:
max_gen_len = self.model.args.max_seq_len - 1
params = self.model.args
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
if print_model_input and get_model_parallel_rank() == 0:
tokens_to_print = list(llm_input.tokens)
cprint(
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
"red",
)
prompt_tokens = [llm_input.tokens]
if print_model_input:
cprint("Input to model:\n", "yellow")
for inp in llm_inputs:
cprint(self.tokenizer.decode(inp.tokens), "grey")
prompt_tokens = [inp.tokens for inp in llm_inputs]
bsz = 1
bsz = len(llm_inputs)
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)
@ -176,24 +161,33 @@ class Llama4:
input_text_mask = tokens != pad_id
if echo:
for i, t in enumerate(llm_input.tokens):
yield TokenResult(
token=t,
text=self.tokenizer.decode([t]),
logprobs=(token_logprobs[0, i : i + 1].tolist() if logprobs else None),
)
for i in range(max_prompt_len):
results = []
for j, t in enumerate(tokens[:, i]):
results.append(
GenerationResult(
token=t.item(),
text=self.tokenizer.decode([t.item()]),
source="input",
logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
batch_idx=j,
finished=False,
ignore_token=t.item() == pad_id,
)
)
yield results
stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")
prev_pos = 0
for cur_pos in range(min_prompt_len, total_len):
image_embedding = None
if prev_pos == 0 and llm_input.images is not None and len(llm_input.images) > 0:
if prev_pos == 0 and any(inp.images is not None and len(inp.images) > 0 for inp in llm_inputs):
image_mask = tokens[:, prev_pos:cur_pos] == self.tokenizer.special_tokens["<|patch|>"]
image_mask = image_mask.unsqueeze(-1)
h = self.model.tok_embeddings(tokens[:, prev_pos:cur_pos])
image_batch = [llm_input.images]
image_batch = [inp.images if inp.images is not None else [] for inp in llm_inputs]
image_embedding = MaskedEmbedding(
embedding=self.model.vision_embeddings(image_batch, image_mask, h),
mask=image_mask,
@ -229,11 +223,21 @@ class Llama4:
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
yield TokenResult(
token=next_token[0].item(),
text=self.tokenizer.decode(next_token.tolist()),
logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None),
)
results = []
for idx, t in enumerate(next_token):
results.append(
GenerationResult(
token=t.item(),
text=self.tokenizer.decode([t.item()]),
source="output",
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
batch_idx=idx,
finished=eos_reached[idx],
ignore_token=cur_pos < len(prompt_tokens[idx]),
)
)
yield results
prev_pos = cur_pos
if all(eos_reached):
@ -241,68 +245,47 @@ class Llama4:
def completion(
self,
content: RawContent,
contents: List[RawContent],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
) -> Generator:
) -> Generator[List[GenerationResult], None, None]:
llm_input = self.formatter.encode_content(content)
llm_inputs = [self.formatter.encode_content(c) for c in contents]
for result in self.generate(
llm_input=llm_input,
llm_inputs=llm_inputs,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
if result.token in self.tokenizer.stop_tokens:
break
yield result
if all(r.finished for r in result):
break
def chat_completion(
self,
messages: List[RawMessage],
messages_batch: List[List[RawMessage]],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
) -> Generator:
) -> Generator[List[GenerationResult], None, None]:
llm_input = self.formatter.encode_dialog_prompt(messages)
llm_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
for result in self.generate(
llm_input=llm_input,
llm_inputs=llm_inputs,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
if result.token in self.tokenizer.stop_tokens:
break
yield result
if all(r.finished for r in result):
break
def chat_completion_raw(
self,
messages: List[RawMessage],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
):
llm_input = self.formatter.encode_dialog_prompt(messages)
output_tokens = []
for result in self.generate(
llm_input=llm_input,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
):
output_tokens.append(result.token)
return llm_input.tokens, output_tokens
def sample_top_p(probs, p):

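The body of `sample_top_p` is elided by this hunk; the conventional nucleus-sampling implementation it refers to looks roughly like the sketch below (sort, keep the smallest prefix carrying probability mass `p`, renormalize, sample):

```
import torch


def sample_top_p_sketch(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Nucleus sampling sketch for a softmaxed [bsz, vocab] tensor."""
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p  # tokens outside the nucleus
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, -1, next_token)


logits = torch.randn(2, 32)
tokens = sample_top_p_sketch(torch.softmax(logits / 0.6, dim=-1), p=0.9)
assert tokens.shape == (2, 1)
```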

@ -4,16 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import math
from typing import Any, Dict, List, Optional, Tuple
@ -184,7 +174,6 @@ class Attention(nn.Module):
self.head_dim,
)
).cuda()
self.qk_norm = None
if self.use_qk_norm:
self.qk_norm = L2Norm(args.norm_eps)


@ -100,31 +100,21 @@ class Experts(nn.Module):
class MoE(torch.nn.Module):
"""
This EC implementation is modified from the original EC module.
We refactored the token permutation and unpermutation logic and added support to tp and dp2ep sharding.
This module supports 3 sharding methods of the experts:
- tp: each TP rank has n_experts experts. Experts are sharded following the conventional row/column-parallel TP sharding.
- tp2ep: each TP rank has n_experts/tp experts. Experts are not sharded.
- dp2ep: each EP rank has n_experts/ep experts. Experts are sharded following the row/column-parallel TP sharding.
Tensors used in this module are annotated with the suffixes that indicate the shape of the tensor.
Several commonly used annotations include:
- a: bsz*slen
- E: number of experts
- e: number of local experts per ep (n_experts/ep)
- et: number of local experts per tp (n_experts/tp)
- D: hidden dimension
- d: D/tp
- F: model dimension
- f: F/tp (used in column/row-parallel linear)
- G: number of tokens per expert (a * capacity_factor / E)
- g: number of tokens per expert per TP rank (i.e., G/TP)
- GG: G*EP (number of tokens per expert received via inter-EP a2a when ag_along_first_dim=False)
- gg: g*EP (number of tokens per expert received via inter-EP a2a when ag_along_first_dim=True)
Examples:
x_aD [a, D]
routed_in_etG_D [et*G, D]
x_eGGD: [e, GG, D]
x_eGD: [e, G, D]
"""
def __init__(
@ -207,13 +197,13 @@ class MoE(torch.nn.Module):
routed_in_EG_D = routed_in_EG_D * router_scores.reshape(-1, 1)
out_aD = self.shared_expert(x_aD)
routed_out_egg_D = self.experts(routed_in_EG_D.detach())
routed_out_eg_D = self.experts(routed_in_EG_D.detach())
router_indices_EG_D = router_indices.reshape(-1, 1).expand(-1, D)
out_aD.scatter_add_(
dim=0,
index=router_indices_EG_D,
src=routed_out_egg_D.view(-1, D),
src=routed_out_eg_D.view(-1, D),
)
out_aD = reduce_from_model_parallel_region(out_aD)
return out_aD.view(-1, slen, D)

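A tiny, self-contained illustration of the `scatter_add_` combine used in `MoE.forward`: each routed slot's output is added back onto the token that produced it (token 1 below receives two slots). The numbers are made up:

```
import torch

# Four tokens (a=4) with hidden size D=2; three routed slots whose outputs
# must be summed back onto their source tokens.
D = 2
out_aD = torch.zeros(4, D)
routed_out = torch.tensor([[1.0, 1.0], [2.0, 2.0], [4.0, 4.0]])
router_indices = torch.tensor([1, 2, 1])  # slot -> token index

index = router_indices.reshape(-1, 1).expand(-1, D)
out_aD.scatter_add_(dim=0, index=index, src=routed_out)

assert torch.equal(out_aD, torch.tensor([[0.0, 0.0], [5.0, 5.0], [2.0, 2.0], [0.0, 0.0]]))
```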

@ -4,20 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import textwrap
from io import BytesIO
from pathlib import Path
from typing import List
from llama_stack.models.llama.datatypes import RawMediaItem, RawMessage, RawTextItem
from llama_stack.models.llama.prompt_format import (
from ..datatypes import RawMediaItem, RawMessage, RawTextItem
from ..prompt_format import (
Llama4UseCase,
TextCompletionContent,
UseCase,


@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


@ -6,20 +6,29 @@
import logging
import os
from typing import Optional
from typing import Callable, Optional
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
from torch import Tensor
from torch import Tensor, nn
from torch.nn import functional as F
from ..generation import QuantizationMode
from ...datatypes import QuantizationMode
from ..model import Transformer, TransformerBlock
from ..moe import MoE
log = logging.getLogger(__name__)
def swiglu_wrapper_no_reduce(
self,
x: Tensor,
):
from ...quantize_impls import ffn_swiglu
return ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
def experts_batched_swiglu_wrapper(
self,
x: Tensor, # (e, g, D)
@ -51,24 +60,30 @@ def convert_to_quantized_model(
rank = get_model_parallel_rank()
def should_quantize_block(block: nn.Module) -> bool:
if not isinstance(block, TransformerBlock):
return False
is_moe = isinstance(block.feed_forward, MoE)
if quantization_mode == QuantizationMode.fp8_mixed:
# skip quantization on first and last layers
return is_moe and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1))
return is_moe
use_rich_progress = use_rich_progress and rank == 0
progress, log_status, update_status = logging_callbacks(use_rich_progress, rank, model)
progress, log_status, update_status = logging_callbacks(use_rich_progress, rank, model, should_quantize_block)
if quantization_mode == QuantizationMode.int4_mixed:
int4_scales_path = os.path.join(checkpoint_dir, f"int4_scales_{rank}.pt")
int4_zero_points_path = os.path.join(checkpoint_dir, f"int4_zero_points_{rank}.pt")
if os.path.isfile(int4_scales_path):
log_status(f"Rank {rank}: Loading int4 scales")
int4_scales = torch.load(int4_scales_path, weights_only=True)
int4_zero_points = torch.load(int4_zero_points_path, weights_only=True)
def apply_quantization(key, weight):
scale = int4_scales[key]
zero_point = int4_zero_points[key]
return load_int4(
weight,
scale,
zero_point,
fp8_activation_scale_ub,
output_device=torch.device("cuda"),
)
@ -77,6 +92,7 @@ def convert_to_quantized_model(
def apply_quantization(_, weight):
return quantize_int4(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))
else:
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
if os.path.isfile(fp8_scales_path):
@ -104,33 +120,38 @@ def convert_to_quantized_model(
progress.start()
for _, block in model.named_modules():
if isinstance(block, TransformerBlock):
# Skip quantization on first and last layers
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
# Skip quantization on dense layers
if not isinstance(block.feed_forward, MoE):
continue
update_status(f"Rank {rank} - Layer {block.layer_id}")
# Quantize only routed experts, not shared
prefix = f"layers.{block.layer_id}.feed_forward"
moe = block.feed_forward
moe.experts.batched_swiglu = experts_batched_swiglu_wrapper.__get__(moe.experts)
for key in ("w1", "w3", "w2"):
param = getattr(moe.experts, key)
update_status(f"Rank {rank} - Layer {block.layer_id} - MoE {key}")
setattr(
moe.experts,
key,
apply_quantization(f"{prefix}.experts.{key}", param.transpose(1, 2).contiguous()),
)
if not should_quantize_block(block):
continue
update_status(f"Rank {rank} - Layer {block.layer_id}")
# Quantize only routed experts, not shared
prefix = f"layers.{block.layer_id}.feed_forward"
moe = block.feed_forward
moe.experts.batched_swiglu = experts_batched_swiglu_wrapper.__get__(moe.experts)
for key in ("w1", "w3", "w2"):
param = getattr(moe.experts, key)
update_status(f"Rank {rank} - Layer {block.layer_id} - MoE {key}")
setattr(
moe.experts,
key,
apply_quantization(
f"{prefix}.experts.{key}",
param.transpose(1, 2).contiguous(),
),
)
if quantization_mode == QuantizationMode.int4_mixed:
# Quantize shared experts
moe.shared_expert.forward = swiglu_wrapper_no_reduce.__get__(moe.shared_expert)
for key in ("w1", "w3", "w2"):
param = getattr(moe.shared_expert, key)
update_status(f"Rank {rank} - Layer {block.layer_id} - MoE shared expert {key}")
param.weight = apply_quantization(f"{prefix}.shared_expert.{key}", param.weight)
processed_blocks += 1
update_status(message=None, completed=processed_blocks)
update_status(f"Rank {rank} - Moving parameters to CUDA")
@ -149,7 +170,12 @@ def convert_to_quantized_model(
# fp8/int4 loading can be very slow so we add progress bars to make life slightly better
def logging_callbacks(use_rich_progress: bool, rank: int, model: Transformer):
def logging_callbacks(
use_rich_progress: bool,
rank: int,
model: Transformer,
should_quantize_block: Callable[[nn.Module], bool],
):
console = None
if use_rich_progress:
from rich.console import Console
@ -162,15 +188,7 @@ def logging_callbacks(use_rich_progress: bool, rank: int, model: Transformer):
elif rank == 0: # Only log from rank 0 for non-rich logging
log.info(message)
total_blocks = sum(
1
for _, block in model.named_modules()
if (
isinstance(block, TransformerBlock)
and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1))
and isinstance(block.feed_forward, MoE)
)
)
total_blocks = sum(1 for _, block in model.named_modules() if should_quantize_block(block))
progress = None
if use_rich_progress:
from rich.progress import (

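The loader patches quantized kernels onto existing modules by binding plain functions with `__get__`, as in `experts_batched_swiglu_wrapper.__get__(moe.experts)`. A minimal illustration of that binding trick on a toy class (not the real modules):

```
class Expert:
    def __init__(self, name: str):
        self.name = name


def patched_forward(self, x):
    # Once bound via __get__, `self` is the specific Expert instance being patched.
    return f"{self.name}: {x}"


e = Expert("moe.experts")
e.forward = patched_forward.__get__(e)  # bind the free function to this instance only
assert e.forward("x") == "moe.experts: x"
```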

@ -4,9 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import os
from logging import getLogger
from pathlib import Path
@ -59,8 +56,6 @@ LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS = [
"<|text_post_train_reserved_special_token_3|>",
"<|text_post_train_reserved_special_token_4|>",
"<|text_post_train_reserved_special_token_5|>",
"<|python_start|>",
"<|python_end|>",
"<|finetune_right_pad|>",
] + get_reserved_special_tokens(
"text_post_train", 61, 6
@ -85,8 +80,23 @@ LLAMA4_VISION_SPECIAL_TOKENS = [
"vision", 1041, 7
) # <|vision_reserved_special_token_7|>, ..., <|vision_reserved_special_token_1047|>
# 201134, ..., 201143
LLAMA4_REASONING_SPECIAL_TOKENS = [
"<|reasoning_reserved_special_token_0|>",
"<|reasoning_reserved_special_token_1|>",
"<|reasoning_reserved_special_token_2|>",
"<|reasoning_reserved_special_token_3|>",
"<|reasoning_reserved_special_token_4|>",
"<|reasoning_reserved_special_token_5|>",
"<|reasoning_reserved_special_token_6|>",
"<|reasoning_reserved_special_token_7|>",
"<|reasoning_thinking_start|>",
"<|reasoning_thinking_end|>",
]
LLAMA4_SPECIAL_TOKENS = LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS + LLAMA4_VISION_SPECIAL_TOKENS
LLAMA4_SPECIAL_TOKENS = (
LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS + LLAMA4_VISION_SPECIAL_TOKENS + LLAMA4_REASONING_SPECIAL_TOKENS
)
BASIC_SPECIAL_TOKENS = [
"<|begin_of_text|>",
@ -155,6 +165,9 @@ class Tokenizer:
self.eot_id: int = self.special_tokens["<|eot|>"]
self.eom_id: int = self.special_tokens["<|eom|>"]
self.thinking_start_id: int = self.special_tokens["<|reasoning_thinking_start|>"]
self.thinking_end_id: int = self.special_tokens["<|reasoning_thinking_end|>"]
self.stop_tokens = [
self.eos_id,
self.special_tokens["<|eom|>"],

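`get_reserved_special_tokens` itself is not shown in this hunk; judging from the inline range comments, it expands a name into a numbered run of placeholder tokens, roughly like this hypothetical version:

```
from typing import List


def reserved_special_tokens(name: str, count: int, start: int = 0) -> List[str]:
    # Produces <|{name}_reserved_special_token_{start}|> ... _{start + count - 1}|>,
    # matching the ranges noted in the comments above (signature is an assumption).
    return [f"<|{name}_reserved_special_token_{i}|>" for i in range(start, start + count)]


toks = reserved_special_tokens("vision", 1041, 7)
assert toks[0] == "<|vision_reserved_special_token_7|>"
assert toks[-1] == "<|vision_reserved_special_token_1047|>"
assert len(toks) == 1041
```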

@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import math
from typing import Any, Callable, Dict, List


@ -28,9 +28,6 @@ from llama_stack.models.llama.datatypes import (
ToolPromptFormat,
)
from llama_stack.models.llama.llama4.tokenizer import Tokenizer
from llama_stack.providers.inline.inference.meta_reference.llama4.datatypes import (
LLMInput,
)
from .llama3.interface import LLama31Interface
from .llama3.template_data import (
@ -76,21 +73,22 @@ class UseCase(BaseModel):
text += dialog
text += "\n\n"
continue
elif isinstance(dialog, TextCompletionContent):
input_tokens, output_tokens = generator.text_completion_raw(
dialog.content,
temperature=0.1,
top_p=0.95,
max_gen_len=64,
)
else:
input_tokens, output_tokens = generator.chat_completion_raw(
dialog,
temperature=0.0,
top_p=0.95,
max_gen_len=self.max_gen_len,
)
batch = [dialog]
method = (
generator.completion if isinstance(dialog, TextCompletionContent) else generator.chat_completion
)
input_tokens = []
output_tokens = []
for token_results in method(batch, echo=True, temperature=0.1, top_p=0.95):
result = token_results[0]
if result.source == "input":
input_tokens.append(result.token)
else:
output_tokens.append(result.token)
if result.finished:
break
text += "##### Input Prompt Format\n"
# FIXME: This is added to undo the hack in chat_formatter where
@ -126,27 +124,27 @@ class Llama4UseCase(UseCase):
text = ""
tokenizer = Tokenizer.get_instance()
temperature = 0.0
for dialog in self.dialogs:
if isinstance(dialog, str):
text += dialog
text += "\n\n"
continue
elif isinstance(dialog, TextCompletionContent):
# TODO pass the raw input and do the encoding in the text completion function
input_tokens = tokenizer.encode(dialog.content, bos=True, eos=False)
llm_input = LLMInput(tokens=input_tokens)
output_tokens, decoded_tokens, token_logprobs = generator.text_completion_raw(
llm_input, temperature=temperature, max_gen_len=self.max_gen_len
)
else:
input_tokens, output_tokens = generator.chat_completion_raw(
dialog,
temperature=temperature,
max_gen_len=self.max_gen_len,
)
batch = [dialog]
method = (
generator.completion if isinstance(dialog, TextCompletionContent) else generator.chat_completion
)
input_tokens = []
output_tokens = []
for token_results in method(batch, echo=True, temperature=0.0):
result = token_results[0]
if result.source == "input":
input_tokens.append(result.token)
else:
output_tokens.append(result.token)
if result.finished:
break
text += "##### Input Prompt Format\n"
text += _code_block(tokenizer.decode(input_tokens))


@ -4,24 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional
from .datatypes import (
from .sku_types import (
CheckpointQuantizationFormat,
CoreModelId,
Model,
ModelFamily,
SamplingParams,
TopPSamplingStrategy,
)
LLAMA2_VOCAB_SIZE = 32000
@ -47,15 +38,6 @@ def all_registered_models() -> List[Model]:
) )
def recommended_sampling_params() -> SamplingParams:
return SamplingParams(
strategy=TopPSamplingStrategy(
temperature=1.0,
top_p=0.9,
)
)
def llama2_family() -> List[Model]: def llama2_family() -> List[Model]:
return [ return [
*llama2_base_models(), *llama2_base_models(),
@ -150,7 +132,6 @@ def llama2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama2_7b, core_model_id=CoreModelId.llama2_7b,
description="Llama 2 7b model", description="Llama 2 7b model",
huggingface_repo="meta-llama/Llama-2-7b", huggingface_repo="meta-llama/Llama-2-7b",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 4096, "dim": 4096,
"n_layers": 32, "n_layers": 32,
@ -169,7 +150,6 @@ def llama2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama2_13b, core_model_id=CoreModelId.llama2_13b,
description="Llama 2 13b model", description="Llama 2 13b model",
huggingface_repo="meta-llama/Llama-2-13b", huggingface_repo="meta-llama/Llama-2-13b",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 5120, "dim": 5120,
"n_layers": 40, "n_layers": 40,
@ -188,7 +168,6 @@ def llama2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama2_70b, core_model_id=CoreModelId.llama2_70b,
description="Llama 2 70b model", description="Llama 2 70b model",
huggingface_repo="meta-llama/Llama-2-70b", huggingface_repo="meta-llama/Llama-2-70b",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 8192, "dim": 8192,
"n_layers": 80, "n_layers": 80,
@ -230,7 +209,6 @@ def llama3_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_70b, core_model_id=CoreModelId.llama3_70b,
description="Llama 3 70b model", description="Llama 3 70b model",
huggingface_repo="meta-llama/Llama-3-70B", huggingface_repo="meta-llama/Llama-3-70B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 8192, "dim": 8192,
"n_layers": 80, "n_layers": 80,
@ -254,7 +232,6 @@ def llama3_1_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_1_8b, core_model_id=CoreModelId.llama3_1_8b,
description="Llama 3.1 8b model", description="Llama 3.1 8b model",
huggingface_repo="meta-llama/Llama-3.1-8B", huggingface_repo="meta-llama/Llama-3.1-8B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 4096, "dim": 4096,
"n_layers": 32, "n_layers": 32,
@ -273,7 +250,6 @@ def llama3_1_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_1_70b, core_model_id=CoreModelId.llama3_1_70b,
description="Llama 3.1 70b model", description="Llama 3.1 70b model",
huggingface_repo="meta-llama/Llama-3.1-70B", huggingface_repo="meta-llama/Llama-3.1-70B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 8192, "dim": 8192,
"n_layers": 80, "n_layers": 80,
@ -293,7 +269,6 @@ def llama3_1_base_models() -> List[Model]:
variant="bf16-mp8", variant="bf16-mp8",
description="Llama 3.1 405b model (BF16 weights)", description="Llama 3.1 405b model (BF16 weights)",
huggingface_repo="meta-llama/Llama-3.1-405B", huggingface_repo="meta-llama/Llama-3.1-405B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 16384, "dim": 16384,
"n_layers": 126, "n_layers": 126,
@ -313,7 +288,6 @@ def llama3_1_base_models() -> List[Model]:
description="Llama 3.1 405b model (FP8 quantized)", description="Llama 3.1 405b model (FP8 quantized)",
huggingface_repo="meta-llama/Llama-3.1-405B-FP8", huggingface_repo="meta-llama/Llama-3.1-405B-FP8",
quantization_format=CheckpointQuantizationFormat.fp8_mixed, quantization_format=CheckpointQuantizationFormat.fp8_mixed,
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 16384, "dim": 16384,
"n_layers": 126, "n_layers": 126,
@ -333,7 +307,6 @@ def llama3_1_base_models() -> List[Model]:
variant="bf16-mp16", variant="bf16-mp16",
description="Llama 3.1 405b model (BF16 weights for mp16)", description="Llama 3.1 405b model (BF16 weights for mp16)",
huggingface_repo="meta-llama/Llama-3.1-405B", huggingface_repo="meta-llama/Llama-3.1-405B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 16384, "dim": 16384,
"n_layers": 126, "n_layers": 126,
@ -357,7 +330,6 @@ def llama3_2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_1b, core_model_id=CoreModelId.llama3_2_1b,
description="Llama 3.2 1b model", description="Llama 3.2 1b model",
huggingface_repo="meta-llama/Llama-3.2-1B", huggingface_repo="meta-llama/Llama-3.2-1B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 2048, "dim": 2048,
"n_layers": 16, "n_layers": 16,
@ -376,7 +348,6 @@ def llama3_2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_3b, core_model_id=CoreModelId.llama3_2_3b,
description="Llama 3.2 3b model", description="Llama 3.2 3b model",
huggingface_repo="meta-llama/Llama-3.2-3B", huggingface_repo="meta-llama/Llama-3.2-3B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 3072, "dim": 3072,
"n_layers": 28, "n_layers": 28,
@ -395,7 +366,6 @@ def llama3_2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_11b_vision, core_model_id=CoreModelId.llama3_2_11b_vision,
description="Llama 3.2 11b vision model", description="Llama 3.2 11b vision model",
huggingface_repo="meta-llama/Llama-3.2-11B-Vision", huggingface_repo="meta-llama/Llama-3.2-11B-Vision",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 4096, "dim": 4096,
"n_layers": 32, "n_layers": 32,
@ -417,7 +387,6 @@ def llama3_2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_90b_vision, core_model_id=CoreModelId.llama3_2_90b_vision,
description="Llama 3.2 90b vision model", description="Llama 3.2 90b vision model",
huggingface_repo="meta-llama/Llama-3.2-90B-Vision", huggingface_repo="meta-llama/Llama-3.2-90B-Vision",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 8192, "dim": 8192,
"n_layers": 80, "n_layers": 80,
@ -444,7 +413,6 @@ def llama2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama2_7b_chat, core_model_id=CoreModelId.llama2_7b_chat,
description="Llama 2 7b chat model", description="Llama 2 7b chat model",
huggingface_repo="meta-llama/Llama-2-7b-chat", huggingface_repo="meta-llama/Llama-2-7b-chat",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 4096, "dim": 4096,
"n_layers": 32, "n_layers": 32,
@ -463,7 +431,6 @@ def llama2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama2_13b_chat, core_model_id=CoreModelId.llama2_13b_chat,
description="Llama 2 13b chat model", description="Llama 2 13b chat model",
huggingface_repo="meta-llama/Llama-2-13b-chat", huggingface_repo="meta-llama/Llama-2-13b-chat",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 5120, "dim": 5120,
"n_layers": 40, "n_layers": 40,
@ -482,7 +449,6 @@ def llama2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama2_70b_chat, core_model_id=CoreModelId.llama2_70b_chat,
description="Llama 2 70b chat model", description="Llama 2 70b chat model",
huggingface_repo="meta-llama/Llama-2-70b-chat", huggingface_repo="meta-llama/Llama-2-70b-chat",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 8192, "dim": 8192,
"n_layers": 80, "n_layers": 80,
@ -506,7 +472,6 @@ def llama3_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_8b_instruct, core_model_id=CoreModelId.llama3_8b_instruct,
description="Llama 3 8b instruct model", description="Llama 3 8b instruct model",
huggingface_repo="meta-llama/Llama-3-8B-Instruct", huggingface_repo="meta-llama/Llama-3-8B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 4096, "dim": 4096,
"n_layers": 32, "n_layers": 32,
@ -525,7 +490,6 @@ def llama3_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_70b_instruct, core_model_id=CoreModelId.llama3_70b_instruct,
description="Llama 3 70b instruct model", description="Llama 3 70b instruct model",
huggingface_repo="meta-llama/Llama-3-70B-Instruct", huggingface_repo="meta-llama/Llama-3-70B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 8192, "dim": 8192,
"n_layers": 80, "n_layers": 80,
@ -549,7 +513,6 @@ def llama3_1_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_1_8b_instruct, core_model_id=CoreModelId.llama3_1_8b_instruct,
description="Llama 3.1 8b instruct model", description="Llama 3.1 8b instruct model",
huggingface_repo="meta-llama/Llama-3.1-8B-Instruct", huggingface_repo="meta-llama/Llama-3.1-8B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 4096, "dim": 4096,
"n_layers": 32, "n_layers": 32,
@ -568,7 +531,6 @@ def llama3_1_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_1_70b_instruct, core_model_id=CoreModelId.llama3_1_70b_instruct,
description="Llama 3.1 70b instruct model", description="Llama 3.1 70b instruct model",
huggingface_repo="meta-llama/Llama-3.1-70B-Instruct", huggingface_repo="meta-llama/Llama-3.1-70B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 8192, "dim": 8192,
"n_layers": 80, "n_layers": 80,
@ -588,7 +550,6 @@ def llama3_1_instruct_models() -> List[Model]:
variant="bf16-mp8", variant="bf16-mp8",
description="Llama 3.1 405b instruct model (BF16 weights)", description="Llama 3.1 405b instruct model (BF16 weights)",
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct", huggingface_repo="meta-llama/Llama-3.1-405B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 16384, "dim": 16384,
"n_layers": 126, "n_layers": 126,
@ -608,7 +569,6 @@ def llama3_1_instruct_models() -> List[Model]:
description="Llama 3.1 405b instruct model (FP8 quantized)", description="Llama 3.1 405b instruct model (FP8 quantized)",
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct-FP8", huggingface_repo="meta-llama/Llama-3.1-405B-Instruct-FP8",
quantization_format=CheckpointQuantizationFormat.fp8_mixed, quantization_format=CheckpointQuantizationFormat.fp8_mixed,
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 16384, "dim": 16384,
"n_layers": 126, "n_layers": 126,
@ -628,7 +588,6 @@ def llama3_1_instruct_models() -> List[Model]:
variant="bf16-mp16", variant="bf16-mp16",
description="Llama 3.1 405b instruct model (BF16 weights for mp16)", description="Llama 3.1 405b instruct model (BF16 weights for mp16)",
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct", huggingface_repo="meta-llama/Llama-3.1-405B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 16384, "dim": 16384,
"n_layers": 126, "n_layers": 126,
@ -684,7 +643,6 @@ def llama3_2_quantized_models() -> List[Model]:
quantization_format=CheckpointQuantizationFormat.int4, quantization_format=CheckpointQuantizationFormat.int4,
description="Llama 3.2 1b INT4 quantized LoRA", description="Llama 3.2 1b INT4 quantized LoRA",
huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8", huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
**arch_args_1b(), **arch_args_1b(),
"quantization_args": { "quantization_args": {
@ -703,7 +661,6 @@ def llama3_2_quantized_models() -> List[Model]:
quantization_format=CheckpointQuantizationFormat.int4, quantization_format=CheckpointQuantizationFormat.int4,
description="Llama 3.2 1b INT4 quantized SpinQuant", description="Llama 3.2 1b INT4 quantized SpinQuant",
huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8", huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
**arch_args_1b(), **arch_args_1b(),
"quantization_args": { "quantization_args": {
@ -718,7 +675,6 @@ def llama3_2_quantized_models() -> List[Model]:
quantization_format=CheckpointQuantizationFormat.int4, quantization_format=CheckpointQuantizationFormat.int4,
description="Llama 3.2 3b INT4 quantized LoRA", description="Llama 3.2 3b INT4 quantized LoRA",
huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-QLORA_INT4_EO8", huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-QLORA_INT4_EO8",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
**arch_args_3b(), **arch_args_3b(),
"quantization_args": { "quantization_args": {
@ -737,7 +693,6 @@ def llama3_2_quantized_models() -> List[Model]:
quantization_format=CheckpointQuantizationFormat.int4, quantization_format=CheckpointQuantizationFormat.int4,
description="Llama 3.2 3b INT4 quantized SpinQuant", description="Llama 3.2 3b INT4 quantized SpinQuant",
huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-SpinQuant_INT4_EO8", huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-SpinQuant_INT4_EO8",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
**arch_args_3b(), **arch_args_3b(),
"quantization_args": { "quantization_args": {
@ -755,7 +710,6 @@ def llama3_2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_1b_instruct, core_model_id=CoreModelId.llama3_2_1b_instruct,
description="Llama 3.2 1b instruct model", description="Llama 3.2 1b instruct model",
huggingface_repo="meta-llama/Llama-3.2-1B-Instruct", huggingface_repo="meta-llama/Llama-3.2-1B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args=arch_args_1b(), arch_args=arch_args_1b(),
pth_file_count=1, pth_file_count=1,
), ),
@ -763,7 +717,6 @@ def llama3_2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_3b_instruct, core_model_id=CoreModelId.llama3_2_3b_instruct,
description="Llama 3.2 3b instruct model", description="Llama 3.2 3b instruct model",
huggingface_repo="meta-llama/Llama-3.2-3B-Instruct", huggingface_repo="meta-llama/Llama-3.2-3B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args=arch_args_3b(), arch_args=arch_args_3b(),
pth_file_count=1, pth_file_count=1,
), ),
@ -772,7 +725,6 @@ def llama3_2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_11b_vision_instruct, core_model_id=CoreModelId.llama3_2_11b_vision_instruct,
description="Llama 3.2 11b vision instruct model", description="Llama 3.2 11b vision instruct model",
huggingface_repo="meta-llama/Llama-3.2-11B-Vision-Instruct", huggingface_repo="meta-llama/Llama-3.2-11B-Vision-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 4096, "dim": 4096,
"n_layers": 32, "n_layers": 32,
@ -794,7 +746,6 @@ def llama3_2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_90b_vision_instruct, core_model_id=CoreModelId.llama3_2_90b_vision_instruct,
description="Llama 3.2 90b vision instruct model", description="Llama 3.2 90b vision instruct model",
huggingface_repo="meta-llama/Llama-3.2-90B-Vision-Instruct", huggingface_repo="meta-llama/Llama-3.2-90B-Vision-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 8192, "dim": 8192,
"n_layers": 80, "n_layers": 80,
@ -821,7 +772,6 @@ def llama3_3_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_3_70b_instruct, core_model_id=CoreModelId.llama3_3_70b_instruct,
description="Llama 3.3 70b instruct", description="Llama 3.3 70b instruct",
huggingface_repo="meta-llama/Llama-3.3-70B-Instruct", huggingface_repo="meta-llama/Llama-3.3-70B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 8192, "dim": 8192,
"n_layers": 80, "n_layers": 80,
@ -846,7 +796,6 @@ def safety_models() -> List[Model]:
core_model_id=CoreModelId.llama_guard_3_11b_vision, core_model_id=CoreModelId.llama_guard_3_11b_vision,
description="Llama Guard v3 11b vision system safety model", description="Llama Guard v3 11b vision system safety model",
huggingface_repo="meta-llama/Llama-Guard-3-11B-Vision", huggingface_repo="meta-llama/Llama-Guard-3-11B-Vision",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 4096, "dim": 4096,
"n_layers": 32, "n_layers": 32,
@ -870,7 +819,6 @@ def safety_models() -> List[Model]:
description="Llama Guard v3 1b 'int4' quantized system safety model", description="Llama Guard v3 1b 'int4' quantized system safety model",
huggingface_repo="meta-llama/Llama-Guard-3-1B-INT4", huggingface_repo="meta-llama/Llama-Guard-3-1B-INT4",
quantization_format=CheckpointQuantizationFormat.int4, quantization_format=CheckpointQuantizationFormat.int4,
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 2048, "dim": 2048,
"n_layers": 12, "n_layers": 12,
@ -888,7 +836,6 @@ def safety_models() -> List[Model]:
core_model_id=CoreModelId.llama_guard_3_1b, core_model_id=CoreModelId.llama_guard_3_1b,
description="Llama Guard v3 1b system safety model", description="Llama Guard v3 1b system safety model",
huggingface_repo="meta-llama/Llama-Guard-3-1B", huggingface_repo="meta-llama/Llama-Guard-3-1B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 2048, "dim": 2048,
"n_layers": 16, "n_layers": 16,

View file

@@ -0,0 +1,229 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, Optional
from pydantic import BaseModel, ConfigDict, Field
class CheckpointQuantizationFormat(Enum):
# default format
bf16 = "bf16"
# used for enabling fp8_rowwise inference, some weights are bf16
fp8_mixed = "fp8-mixed"
int8 = "int8"
int4 = "int4"
class ModelFamily(Enum):
llama2 = "llama2"
llama3 = "llama3"
llama3_1 = "llama3_1"
llama3_2 = "llama3_2"
llama3_3 = "llama3_3"
llama4 = "llama4"
safety = "safety"
class CoreModelId(Enum):
"""Each of these models is a unique "SKU". These root models can be served in various garbs (especially by quantizing them)"""
# Llama 2 family
llama2_7b = "Llama-2-7b"
llama2_13b = "Llama-2-13b"
llama2_70b = "Llama-2-70b"
llama2_7b_chat = "Llama-2-7b-chat"
llama2_13b_chat = "Llama-2-13b-chat"
llama2_70b_chat = "Llama-2-70b-chat"
# Llama 3 family
llama3_8b = "Llama-3-8B"
llama3_70b = "Llama-3-70B"
llama3_8b_instruct = "Llama-3-8B-Instruct"
llama3_70b_instruct = "Llama-3-70B-Instruct"
# Llama 3.1 family
llama3_1_8b = "Llama3.1-8B"
llama3_1_70b = "Llama3.1-70B"
llama3_1_405b = "Llama3.1-405B"
llama3_1_8b_instruct = "Llama3.1-8B-Instruct"
llama3_1_70b_instruct = "Llama3.1-70B-Instruct"
llama3_1_405b_instruct = "Llama3.1-405B-Instruct"
# Llama 3.2 family
llama3_2_1b = "Llama3.2-1B"
llama3_2_3b = "Llama3.2-3B"
llama3_2_1b_instruct = "Llama3.2-1B-Instruct"
llama3_2_3b_instruct = "Llama3.2-3B-Instruct"
llama3_2_11b_vision = "Llama3.2-11B-Vision"
llama3_2_90b_vision = "Llama3.2-90B-Vision"
llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct"
llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct"
# Llama 3.3 family
llama3_3_70b_instruct = "Llama3.3-70B-Instruct"
# Llama 4 family
llama4_scout_17b_16e = "Llama-4-Scout-17B-16E"
llama4_scout_17b_16e_instruct = "Llama-4-Scout-17B-16E-Instruct"
llama4_maverick_17b_128e = "Llama-4-Maverick-17B-128E"
llama4_maverick_17b_128e_instruct = "Llama-4-Maverick-17B-128E-Instruct"
# Safety models
llama_guard_3_8b = "Llama-Guard-3-8B"
llama_guard_2_8b = "Llama-Guard-2-8B"
llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision"
llama_guard_3_1b = "Llama-Guard-3-1B"
def is_multimodal(model_id) -> bool:
if model_id in [
CoreModelId.llama3_2_11b_vision,
CoreModelId.llama3_2_90b_vision,
CoreModelId.llama3_2_11b_vision_instruct,
CoreModelId.llama3_2_90b_vision_instruct,
]:
return True
else:
return False
def model_family(model_id) -> ModelFamily:
if model_id in [
CoreModelId.llama2_7b,
CoreModelId.llama2_13b,
CoreModelId.llama2_70b,
CoreModelId.llama2_7b_chat,
CoreModelId.llama2_13b_chat,
CoreModelId.llama2_70b_chat,
]:
return ModelFamily.llama2
elif model_id in [
CoreModelId.llama3_8b,
CoreModelId.llama3_70b,
CoreModelId.llama3_8b_instruct,
CoreModelId.llama3_70b_instruct,
]:
return ModelFamily.llama3
elif model_id in [
CoreModelId.llama3_1_8b,
CoreModelId.llama3_1_70b,
CoreModelId.llama3_1_405b,
CoreModelId.llama3_1_8b_instruct,
CoreModelId.llama3_1_70b_instruct,
CoreModelId.llama3_1_405b_instruct,
]:
return ModelFamily.llama3_1
elif model_id in [
CoreModelId.llama3_2_1b,
CoreModelId.llama3_2_3b,
CoreModelId.llama3_2_1b_instruct,
CoreModelId.llama3_2_3b_instruct,
CoreModelId.llama3_2_11b_vision,
CoreModelId.llama3_2_90b_vision,
CoreModelId.llama3_2_11b_vision_instruct,
CoreModelId.llama3_2_90b_vision_instruct,
]:
return ModelFamily.llama3_2
elif model_id in [
CoreModelId.llama3_3_70b_instruct,
]:
return ModelFamily.llama3_3
elif model_id in [
CoreModelId.llama4_scout_17b_16e,
CoreModelId.llama4_scout_17b_16e_instruct,
CoreModelId.llama4_maverick_17b_128e,
CoreModelId.llama4_maverick_17b_128e_instruct,
]:
return ModelFamily.llama4
elif model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_2_8b,
CoreModelId.llama_guard_3_11b_vision,
CoreModelId.llama_guard_3_1b,
]:
return ModelFamily.safety
else:
raise ValueError(f"Unknown model family for {model_id}")
class Model(BaseModel):
core_model_id: CoreModelId
description: str
huggingface_repo: Optional[str] = None
arch_args: Dict[str, Any]
variant: str = ""
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
pth_file_count: int
metadata: Dict[str, Any] = Field(default_factory=dict)
# silence pydantic until we remove the `model_` fields
model_config = ConfigDict(protected_namespaces=())
@property
def model_family(self) -> ModelFamily:
return model_family(self.core_model_id)
# The SKU is uniquely identified by (model_id, variant) combo
def descriptor(self, shorten_default_variant: bool = True) -> str:
if not self.variant:
return self.core_model_id.value
return f"{self.core_model_id.value}:{self.variant}"
@property
def is_instruct_model(self) -> bool:
return "instruct" in self.core_model_id.value
# Featured models are shown in the non-exhaustive model list
@property
def is_featured(self) -> bool:
return self.model_family in [
ModelFamily.llama3_1,
ModelFamily.llama3_2,
ModelFamily.llama3_3,
ModelFamily.llama4,
ModelFamily.safety,
]
@property
def max_seq_length(self) -> int:
if self.model_family == ModelFamily.llama2:
return 4096
elif self.core_model_id == CoreModelId.llama_guard_2_8b:
return 4096
elif self.model_family == ModelFamily.llama3:
return 8192
elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]:
return 131072
elif self.model_family == ModelFamily.llama3_2:
if self.quantization_format == CheckpointQuantizationFormat.int4:
return 8192
return 131072
elif self.model_family == ModelFamily.llama4:
if self.core_model_id in {
CoreModelId.llama4_scout_17b_16e,
CoreModelId.llama4_maverick_17b_128e,
}:
return 262144
if self.core_model_id == CoreModelId.llama4_scout_17b_16e_instruct:
return 10485760
if self.core_model_id == CoreModelId.llama4_maverick_17b_128e_instruct:
return 1048576
raise AssertionError(f"Unexpected core model id: {self.core_model_id}")
elif self.core_model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_3_11b_vision,
CoreModelId.llama_guard_3_1b,
]:
return 131072
else:
raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")

View file

@@ -52,6 +52,7 @@ from llama_stack.apis.inference import (
     StopReason,
     SystemMessage,
     ToolDefinition,
+    ToolParamDefinition,
     ToolResponse,
     ToolResponseMessage,
     UserMessage,
@@ -63,7 +64,6 @@ from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     ToolCall,
-    ToolParamDefinition,
 )
 from llama_stack.providers.utils.kvstore import KVStore
 from llama_stack.providers.utils.telemetry import tracing

View file

@@ -4,13 +4,13 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict, Union
+from typing import Any, Dict
-from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
+from .config import MetaReferenceInferenceConfig
 async def get_provider_impl(
-    config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
+    config: MetaReferenceInferenceConfig,
     _deps: Dict[str, Any],
 ):
     from .inference import MetaReferenceInferenceImpl

View file

@@ -5,19 +5,10 @@
 # the root directory of this source tree.
 from pathlib import Path
-from typing import List, Optional
-from pydantic import BaseModel
 from llama_stack.distribution.utils.model_utils import model_local_dir
-class TokenResult(BaseModel):
-    token: int
-    text: str
-    logprobs: Optional[List[float]] = None
 def model_checkpoint_dir(model_id) -> str:
     checkpoint_dir = Path(model_local_dir(model_id))

View file

@@ -21,6 +21,7 @@ class MetaReferenceInferenceConfig(BaseModel):
     torch_seed: Optional[int] = None
     max_seq_len: int = 4096
     max_batch_size: int = 1
+    model_parallel_size: Optional[int] = None
     # when this is False, we assume that the distributed process group is setup by someone
     # outside of this code (e.g., when run inside `torchrun`). that is useful for clients
@@ -31,6 +32,8 @@ class MetaReferenceInferenceConfig(BaseModel):
     # can override by specifying the directory explicitly
     checkpoint_dir: Optional[str] = None
+    quantization: Optional[QuantizationConfig] = None
     @field_validator("model")
     @classmethod
     def validate_model(cls, model: str) -> str:
@@ -47,27 +50,16 @@ class MetaReferenceInferenceConfig(BaseModel):
         cls,
         model: str = "Llama3.2-3B-Instruct",
         checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
+        quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
+        model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:0}",
         **kwargs,
     ) -> Dict[str, Any]:
         return {
             "model": model,
             "max_seq_len": 4096,
             "checkpoint_dir": checkpoint_dir,
+            "quantization": {
+                "type": quantization_type,
+            },
+            "model_parallel_size": model_parallel_size,
         }
-class MetaReferenceQuantizedInferenceConfig(MetaReferenceInferenceConfig):
-    quantization: QuantizationConfig
-    @classmethod
-    def sample_run_config(
-        cls,
-        model: str = "Llama3.2-3B-Instruct",
-        checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
-        **kwargs,
-    ) -> Dict[str, Any]:
-        config = super().sample_run_config(model, checkpoint_dir, **kwargs)
-        config["quantization"] = {
-            "type": "fp8",
-        }
-        return config
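For reference, a sketch of the dict the merged `sample_run_config` now emits, with hypothetical values substituted for the `${env.…}` placeholders:

```python
# Hedged sketch: the shape of sample_run_config() output once QUANTIZATION_TYPE=fp8_mixed
# and MODEL_PARALLEL_SIZE=4 are resolved from the environment (values are examples only).
sample_run_config = {
    "model": "Llama3.2-3B-Instruct",
    "max_seq_len": 4096,
    "checkpoint_dir": None,  # ${env.CHECKPOINT_DIR:null}
    "quantization": {"type": "fp8_mixed"},
    "model_parallel_size": 4,
}
```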

View file

@@ -11,19 +11,18 @@ import torch
 from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
 from llama_stack.apis.inference import (
-    Fp8QuantizationConfig,
-    Int4QuantizationConfig,
+    GreedySamplingStrategy,
     JsonSchemaResponseFormat,
     ResponseFormat,
-)
-from llama_stack.models.llama.datatypes import (
-    GreedySamplingStrategy,
-    Model,
     SamplingParams,
     TopPSamplingStrategy,
 )
+from llama_stack.models.llama.datatypes import QuantizationMode
+from llama_stack.models.llama.llama3.generation import Llama3
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
+from llama_stack.models.llama.llama4.generation import Llama4
 from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
+from llama_stack.models.llama.sku_types import Model
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
     CompletionRequestWithRawContent,
@@ -31,10 +30,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 )
 from .common import model_checkpoint_dir
-from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
+from .config import MetaReferenceInferenceConfig
 from .inference import resolve_model
-from .llama3.generation import Llama3
-from .llama4.generation import Llama4
 Tokenizer = Llama4Tokenizer | Llama3Tokenizer
@@ -116,10 +113,11 @@ def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
     return get_default_tool_prompt_format(request.model)
+# TODO: combine Llama3 and Llama4 generators since they are almost identical now
 class Llama4Generator:
     def __init__(
         self,
-        config: MetaReferenceInferenceConfig | MetaReferenceQuantizedInferenceConfig,
+        config: MetaReferenceInferenceConfig,
         model_id: str,
         llama_model: Model,
     ):
@@ -134,11 +132,13 @@ class Llama4Generator:
             # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
             ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
-        if isinstance(config, MetaReferenceQuantizedInferenceConfig):
-            if isinstance(config.quantization, Fp8QuantizationConfig):
-                quantization_mode = "fp8_mixed"
-            elif isinstance(config.quantization, Int4QuantizationConfig):
-                quantization_mode = "int4_mixed"
+        if config.quantization:
+            if config.quantization.type == "fp8_mixed":
+                quantization_mode = QuantizationMode.fp8_mixed
+            elif config.quantization.type == "int4_mixed":
+                quantization_mode = QuantizationMode.int4_mixed
+            elif config.quantization.type == "bf16":
+                quantization_mode = None
             else:
                 raise ValueError(f"Unsupported quantization mode {config.quantization}")
         else:
@@ -148,7 +148,7 @@ class Llama4Generator:
             ckpt_dir=ckpt_dir,
             max_seq_len=config.max_seq_len,
             max_batch_size=config.max_batch_size,
-            world_size=llama_model.pth_file_count,
+            world_size=config.model_parallel_size or llama_model.pth_file_count,
             quantization_mode=quantization_mode,
         )
@@ -166,8 +166,8 @@ class Llama4Generator:
         max_gen_len = self.args.max_seq_len - 1
         temperature, top_p = _infer_sampling_params(sampling_params)
-        yield from self.inner_generator.generate(
-            llm_input=self.formatter.encode_content(request.content),
+        for result in self.inner_generator.generate(
+            llm_inputs=[self.formatter.encode_content(request.content)],
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,
@@ -178,7 +178,8 @@ class Llama4Generator:
                 self.args.vocab_size,
                 request.response_format,
             ),
-        )
+        ):
+            yield result[0]
     def chat_completion(
         self,
@@ -190,8 +191,8 @@ class Llama4Generator:
         max_gen_len = self.args.max_seq_len - 1
         temperature, top_p = _infer_sampling_params(sampling_params)
-        yield from self.inner_generator.generate(
-            llm_input=self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)),
+        for result in self.inner_generator.generate(
+            llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,
@@ -202,20 +203,46 @@ class Llama4Generator:
                 self.args.vocab_size,
                 request.response_format,
             ),
-        )
+        ):
+            yield result[0]
 class Llama3Generator:
     def __init__(
         self,
-        config: MetaReferenceInferenceConfig | MetaReferenceQuantizedInferenceConfig,
+        config: MetaReferenceInferenceConfig,
         model_id: str,
         llama_model: Model,
     ):
+        if config.checkpoint_dir and config.checkpoint_dir != "null":
+            ckpt_dir = config.checkpoint_dir
+        else:
+            resolved_model = resolve_model(model_id)
+            if resolved_model is None:
+                # if the model is not a native llama model, get the default checkpoint_dir based on model id
+                ckpt_dir = model_checkpoint_dir(model_id)
+            else:
+                # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
+                ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
+        if config.quantization:
+            if config.quantization.type == "fp8_mixed":
+                quantization_mode = QuantizationMode.fp8_mixed
+            elif config.quantization.type == "int4_mixed":
+                quantization_mode = QuantizationMode.int4_mixed
+            elif config.quantization.type == "bf16":
+                quantization_mode = None
+            else:
+                raise ValueError(f"Unsupported quantization mode {config.quantization}")
+        else:
+            quantization_mode = None
         self.inner_generator = Llama3.build(
-            config=config,
-            model_id=model_id,
-            llama_model=llama_model,
+            ckpt_dir=ckpt_dir,
+            max_seq_len=config.max_seq_len,
+            max_batch_size=config.max_batch_size,
+            world_size=config.model_parallel_size or llama_model.pth_file_count,
+            quantization_mode=quantization_mode,
         )
         self.tokenizer = self.inner_generator.tokenizer
         self.args = self.inner_generator.args
@@ -231,8 +258,8 @@ class Llama3Generator:
         max_gen_len = self.args.max_seq_len - 1
         temperature, top_p = _infer_sampling_params(sampling_params)
-        yield from self.inner_generator.generate(
-            model_input=self.formatter.encode_content(request.content),
+        for result in self.inner_generator.generate(
+            llm_inputs=[self.formatter.encode_content(request.content)],
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,
@@ -243,7 +270,8 @@ class Llama3Generator:
                 self.args.vocab_size,
                 request.response_format,
             ),
-        )
+        ):
+            yield result[0]
     def chat_completion(
         self,
@@ -255,8 +283,8 @@ class Llama3Generator:
         max_gen_len = self.args.max_seq_len - 1
         temperature, top_p = _infer_sampling_params(sampling_params)
-        yield from self.inner_generator.generate(
-            model_input=self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)),
+        for result in self.inner_generator.generate(
+            llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,
@@ -267,4 +295,5 @@ class Llama3Generator:
                 self.args.vocab_size,
                 request.response_format,
             ),
-        )
+        ):
+            yield result[0]

View file

@@ -31,23 +31,21 @@ from llama_stack.apis.inference import (
     LogProbConfig,
     Message,
     ResponseFormat,
+    SamplingParams,
+    StopReason,
     TokenLogProbs,
     ToolChoice,
     ToolConfig,
-)
-from llama_stack.apis.models import Model, ModelType
-from llama_stack.models.llama.datatypes import (
-    ModelFamily,
-    SamplingParams,
-    StopReason,
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.apis.models import Model, ModelType
 from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
 from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
 from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
 from llama_stack.models.llama.sku_list import resolve_model
+from llama_stack.models.llama.sku_types import ModelFamily
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
@@ -151,7 +149,7 @@ class MetaReferenceInferenceImpl(
         if self.config.create_distributed_process_group:
             self.generator = LlamaModelParallelGenerator(
-                model_parallel_size=llama_model.pth_file_count,
+                model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
                 builder_fn=builder_fn,
                 builder_params=builder_params,
                 formatter=(
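The same fallback appears here as in the generators: an explicit `MODEL_PARALLEL_SIZE` wins, otherwise the checkpoint shard count sets the world size. A tiny sketch of that rule (hypothetical helper, not part of the diff):

```python
from typing import Optional


def pick_world_size(model_parallel_size: Optional[int], pth_file_count: int) -> int:
    # a non-zero explicit setting wins; 0/None falls back to the number of .pth shards
    return model_parallel_size or pth_file_count


assert pick_world_size(None, 8) == 8
assert pick_world_size(4, 8) == 4
```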

View file

@@ -1,346 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import os
import sys
import time
from pathlib import Path
from typing import Callable, Generator, Optional, Union
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from llama_stack.apis.inference import (
Fp8QuantizationConfig,
Int4QuantizationConfig,
)
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import Model
from llama_stack.models.llama.llama3.chat_format import ChatFormat, LLMInput
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from ..common import TokenResult, model_checkpoint_dir
from ..config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
from .args import ModelArgs
from .model import Transformer
from .multimodal.model import CrossAttentionTransformer
log = get_logger(__name__, category="inference")
class Llama3:
@staticmethod
def build(
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
model_id: str,
llama_model: Model,
):
"""
Build a Llama instance by initializing and loading a model checkpoint.
Note:
This method initializes the distributed process group, sets the device to CUDA,
and loads the pre-trained model and tokenizer.
"""
if "DEVICE" in os.environ:
device = os.environ.get("DEVICE")
if device == "cuda":
assert torch.cuda.is_available(), "PyTorch CUDA backend not available"
if device == "xpu":
assert torch.xpu.is_available(), "PyTorch XPU backend not available"
else:
if torch.cuda.is_available():
device = "cuda"
elif torch.xpu.is_available():
device = "xpu"
else:
device = "cpu"
log.info(f"Using {device} device")
llama_model_id = llama_model.core_model_id.value
if not torch.distributed.is_initialized():
if device == "cuda":
torch.distributed.init_process_group("nccl")
else:
torch.distributed.init_process_group("gloo")
model_parallel_size = llama_model.pth_file_count
if not model_parallel_is_initialized():
initialize_model_parallel(model_parallel_size)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
if device == "cuda":
torch.cuda.set_device(local_rank)
elif device == "xpu":
torch.xpu.set_device(local_rank)
# seed must be the same in all processes
if config.torch_seed is not None:
torch.manual_seed(config.torch_seed)
if local_rank > 0:
sys.stdout = open(os.devnull, "w")
start_time = time.time()
if config.checkpoint_dir and config.checkpoint_dir != "null":
ckpt_dir = config.checkpoint_dir
else:
resolved_model = resolve_model(model_id)
if resolved_model is None:
# if the model is not a native llama model, get the default checkpoint_dir based on model id
ckpt_dir = model_checkpoint_dir(model_id)
else:
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert model_parallel_size == len(checkpoints), (
f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
)
ckpt_path = checkpoints[get_model_parallel_rank()]
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
if "model" in params:
params = params["model"]
model_args: ModelArgs = ModelArgs(
max_seq_len=config.max_seq_len,
max_batch_size=config.max_batch_size,
**params,
)
tokenizer = Tokenizer.get_instance()
assert model_args.vocab_size == tokenizer.n_words, (
f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
)
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
if isinstance(config.quantization, Fp8QuantizationConfig):
from .quantization.loader import convert_to_fp8_quantized_model
# load on CPU in bf16 so that fp8 conversion does not find an
# unexpected (fp32, e.g.) datatype
torch.set_default_tensor_type(torch.BFloat16Tensor)
if model_args.vision_chunk_size > 0:
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, torch.bfloat16)
else:
model = Transformer(model_args)
model.load_state_dict(state_dict, strict=False)
model = convert_to_fp8_quantized_model(model, config, ckpt_dir)
elif isinstance(config.quantization, Int4QuantizationConfig):
from .quantization.loader import convert_to_int4_quantized_model
model = Transformer(model_args)
model = convert_to_int4_quantized_model(model, model_args, config)
model.load_state_dict(state_dict, strict=True)
if model_args.quantization_args is not None and model_args.quantization_args.spinquant:
# Add a wrapper for adding hadamard transform for spinquant.
# This needs to be done after loading the state dict otherwise an error will be raised while
# loading the state dict.
from ..hadamard_utils import (
add_hadamard_transform_for_spinquant,
)
add_hadamard_transform_for_spinquant(model)
else:
raise NotImplementedError("Currently int4 and fp8 are the only supported quantization methods.")
else:
if device == "cuda":
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
else:
torch.set_default_device(device)
if device == "xpu" and torch.xpu.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.half)
if model_args.vision_chunk_size > 0:
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, torch.bfloat16)
else:
model = Transformer(model_args)
model.load_state_dict(state_dict, strict=False)
model.to(device)
log.info(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama3(model, tokenizer, model_args, llama_model_id)
def __init__(
self,
model: Transformer,
tokenizer: Tokenizer,
args: ModelArgs,
llama_model: str,
):
self.args = args
self.model = model
self.tokenizer = tokenizer
self.formatter = ChatFormat(tokenizer)
self.llama_model = llama_model
@torch.inference_mode()
def generate(
self,
model_input: LLMInput,
max_gen_len: int,
temperature: float = 0.6,
top_p: float = 0.9,
logprobs: bool = False,
echo: bool = False,
print_input_tokens: bool = False,
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> Generator:
params = self.model.params
if print_input_tokens:
input_tokens = [self.formatter.vision_token if t == 128256 else t for t in model_input.tokens]
log.info("Input to model -> " + self.tokenizer.decode(input_tokens))
prompt_tokens = [model_input.tokens]
bsz = 1
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)
max_prompt_len = max(len(t) for t in prompt_tokens)
if max_prompt_len >= params.max_seq_len:
log.error(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}")
return
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
is_vision = isinstance(self.model, CrossAttentionTransformer)
if is_vision:
images = model_input.vision.images if model_input.vision is not None else []
mask = model_input.vision.mask if model_input.vision is not None else []
# the method works for bsz > 1 so add a batch dimension
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
batch_images=[images],
batch_masks=[mask],
total_len=total_len,
)
pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)
if logprobs:
token_logprobs = torch.zeros_like(tokens)
prev_pos = 0
eos_reached = torch.tensor([False] * bsz)
input_text_mask = tokens != pad_id
if min_prompt_len == total_len:
# TODO(ashwin): unify this branch with the one below and figure out multimodal crap
logits = self.model.forward(tokens, prev_pos)
token_logprobs = -F.cross_entropy(
input=logits.transpose(1, 2),
target=tokens,
reduction="none",
ignore_index=pad_id,
)
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
for cur_pos in range(min_prompt_len, total_len):
if is_vision:
position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
logits = self.model.forward(
position_ids,
tokens,
cross_attention_masks,
full_text_row_masked_out_mask,
xattn_caches,
)
else:
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if logits_processor is not None:
logits = logits_processor(tokens[:, :cur_pos], logits)
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
else:
next_token = torch.argmax(logits[:, -1], dim=-1)
next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
target = tokens[:, prev_pos + 1 : cur_pos + 1]
if is_vision:
# the logits space (num_classes) is designed to never contain a media_token
# however our input token stream does contain them. we need to nuke them here
# or else the CUDA kernels will crash with an illegal memory access
vision_tokens = [self.tokenizer.special_tokens["<|image|>"], 128256]
masks = [target.eq(t) for t in vision_tokens]
if len(masks) > 1:
mask = torch.logical_or(*masks)
else:
mask = masks[0]
target[mask] = 0
if logprobs:
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
input=logits.transpose(1, 2),
target=tokens[:, prev_pos + 1 : cur_pos + 1],
reduction="none",
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
yield TokenResult(
token=next_token[0].item(),
text=self.tokenizer.decode(next_token.tolist()),
logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None),
)
prev_pos = cur_pos
if all(eos_reached):
break
def sample_top_p(probs, p):
"""
Perform top-p (nucleus) sampling on a probability distribution.
Args:
probs (torch.Tensor): Probability distribution tensor.
p (float): Probability threshold for top-p sampling.
Returns:
torch.Tensor: Sampled token indices.
Note:
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
"""
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token
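The `sample_top_p` helper deleted here with the rest of the old meta-reference generation code keeps the smallest set of sorted tokens whose cumulative mass exceeds `p` and renormalizes over it; a tiny self-contained check of that behaviour with illustrative numbers:

```python
import torch

# Illustrative only: reproduce the deleted sample_top_p() on a toy distribution.
probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
p = 0.7

probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
mask = torch.cumsum(probs_sort, dim=-1) - probs_sort > p  # True outside the nucleus
probs_sort[mask] = 0.0                                    # keep only {0.5, 0.3}
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))     # renormalize to {0.625, 0.375}
next_token = torch.gather(probs_idx, -1, torch.multinomial(probs_sort, num_samples=1))
assert next_token.item() in (0, 1)  # only the two top-probability tokens can be drawn
```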

View file

@@ -32,13 +32,12 @@ from pydantic import BaseModel, Field
 from torch.distributed.launcher.api import LaunchConfig, elastic_launch
 from typing_extensions import Annotated
+from llama_stack.models.llama.datatypes import GenerationResult
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
     CompletionRequestWithRawContent,
 )
-from .common import TokenResult
 log = logging.getLogger(__name__)
@@ -75,7 +74,7 @@ class TaskRequest(BaseModel):
 class TaskResponse(BaseModel):
     type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
-    result: TokenResult
+    result: GenerationResult
 class ExceptionResponse(BaseModel):

View file

@ -14,9 +14,10 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat, JsonSchemaResponseFormat,
Message, Message,
ToolChoice, ToolChoice,
ToolDefinition,
UserMessage, UserMessage,
) )
from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition from llama_stack.models.llama.datatypes import BuiltinTool
from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict, convert_message_to_openai_dict,
get_sampling_options, get_sampling_options,

View file

@ -46,6 +46,8 @@ from llama_stack.apis.inference import (
TokenLogProbs, TokenLogProbs,
ToolChoice, ToolChoice,
ToolConfig, ToolConfig,
TopKSamplingStrategy,
TopPSamplingStrategy,
) )
from llama_stack.apis.models import Model from llama_stack.apis.models import Model
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -55,8 +57,6 @@ from llama_stack.models.llama.datatypes import (
ToolCall, ToolCall,
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
TopKSamplingStrategy,
TopPSamplingStrategy,
) )
from llama_stack.models.llama.llama3.chat_format import ChatFormat from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.models.llama.llama3.tokenizer import Tokenizer

View file

@ -22,8 +22,8 @@ from torchtune.models.llama3_2 import lora_llama3_2_3b
from torchtune.modules.transforms import Transform from torchtune.modules.transforms import Transform
from llama_stack.apis.post_training import DatasetFormat from llama_stack.apis.post_training import DatasetFormat
from llama_stack.models.llama.datatypes import Model
from llama_stack.models.llama.sku_list import resolve_model from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import Model
BuildLoraModelCallable = Callable[..., torch.nn.Module] BuildLoraModelCallable = Callable[..., torch.nn.Module]
BuildTokenizerCallable = Callable[..., Llama3Tokenizer] BuildTokenizerCallable = Callable[..., Llama3Tokenizer]

View file

@ -23,7 +23,8 @@ from llama_stack.apis.safety import (
) )
from llama_stack.apis.shields import Shield from llama_stack.apis.shields import Shield
from llama_stack.distribution.datatypes import Api from llama_stack.distribution.datatypes import Api
from llama_stack.models.llama.datatypes import CoreModelId, Role from llama_stack.models.llama.datatypes import Role
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str, interleaved_content_as_str,

View file

@@ -24,6 +24,8 @@ META_REFERENCE_DEPS = [
     "zmq",
     "lm-format-enforcer",
     "sentence-transformers",
+    "torchao==0.5.0",
+    "fbgemm-gpu-genai==1.1.2",
 ]
@@ -36,13 +38,6 @@ def available_providers() -> List[ProviderSpec]:
             module="llama_stack.providers.inline.inference.meta_reference",
             config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceInferenceConfig",
         ),
-        InlineProviderSpec(
-            api=Api.inference,
-            provider_type="inline::meta-reference-quantized",
-            pip_packages=META_REFERENCE_DEPS + ["fbgemm-gpu", "torchao==0.5.0"],
-            module="llama_stack.providers.inline.inference.meta_reference",
-            config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceQuantizedInferenceConfig",
-        ),
         InlineProviderSpec(
             api=Api.inference,
             provider_type="inline::vllm",

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.models.llama.datatypes import CoreModelId from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry, build_hf_repo_model_entry,
) )

View file

@ -28,8 +28,8 @@ from llama_stack.apis.inference import (
ToolConfig, ToolConfig,
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
TopKSamplingStrategy,
) )
from llama_stack.models.llama.datatypes import TopKSamplingStrategy
from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper, ModelRegistryHelper,
) )

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.models.llama.datatypes import CoreModelId from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry, build_hf_repo_model_entry,
) )

View file

@ -28,7 +28,7 @@ from llama_stack.apis.inference import (
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.models.llama.datatypes import CoreModelId from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper, ModelRegistryHelper,
build_hf_repo_model_entry, build_hf_repo_model_entry,

View file

@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.apis.models.models import ModelType from llama_stack.apis.models.models import ModelType
from llama_stack.models.llama.datatypes import CoreModelId from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry, ProviderModelEntry,
build_hf_repo_model_entry, build_hf_repo_model_entry,

View file

@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.apis.models import ModelType from llama_stack.apis.models import ModelType
from llama_stack.models.llama.datatypes import CoreModelId from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry, ProviderModelEntry,
build_hf_repo_model_entry, build_hf_repo_model_entry,

View file

@ -29,15 +29,13 @@ from llama_stack.apis.inference import (
LogProbConfig, LogProbConfig,
Message, Message,
ResponseFormat, ResponseFormat,
SamplingParams,
TextTruncation, TextTruncation,
ToolChoice, ToolChoice,
ToolConfig, ToolConfig,
)
from llama_stack.models.llama.datatypes import (
SamplingParams,
ToolDefinition, ToolDefinition,
ToolPromptFormat,
) )
from llama_stack.models.llama.datatypes import ToolPromptFormat
from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper, ModelRegistryHelper,
) )

View file

@ -19,11 +19,9 @@ from llama_stack.apis.inference import (
CompletionRequest, CompletionRequest,
CompletionResponse, CompletionResponse,
CompletionResponseStreamChunk, CompletionResponseStreamChunk,
GreedySamplingStrategy,
JsonSchemaResponseFormat, JsonSchemaResponseFormat,
TokenLogProbs, TokenLogProbs,
)
from llama_stack.models.llama.datatypes import (
GreedySamplingStrategy,
TopKSamplingStrategy, TopKSamplingStrategy,
TopPSamplingStrategy, TopPSamplingStrategy,
) )

View file

@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.apis.models.models import ModelType from llama_stack.apis.models.models import ModelType
from llama_stack.models.llama.datatypes import CoreModelId from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry, ProviderModelEntry,
build_hf_repo_model_entry, build_hf_repo_model_entry,

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.models.llama.datatypes import CoreModelId from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry, build_hf_repo_model_entry,
) )

View file

@ -21,6 +21,7 @@ from llama_stack.apis.inference import (
CompletionMessage, CompletionMessage,
EmbeddingsResponse, EmbeddingsResponse,
EmbeddingTaskType, EmbeddingTaskType,
GreedySamplingStrategy,
Inference, Inference,
LogProbConfig, LogProbConfig,
Message, Message,
@ -35,12 +36,9 @@ from llama_stack.apis.inference import (
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
ToolResponseMessage, ToolResponseMessage,
UserMessage,
)
from llama_stack.models.llama.datatypes import (
GreedySamplingStrategy,
TopKSamplingStrategy, TopKSamplingStrategy,
TopPSamplingStrategy, TopPSamplingStrategy,
UserMessage,
) )
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.openai_compat import (

View file

@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.apis.models.models import ModelType from llama_stack.apis.models.models import ModelType
from llama_stack.models.llama.datatypes import CoreModelId from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry, ProviderModelEntry,
build_hf_repo_model_entry, build_hf_repo_model_entry,

View file

@ -6,7 +6,7 @@
from typing import List from typing import List
from llama_stack.models.llama.datatypes import CoreModelId from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry, ProviderModelEntry,
build_hf_repo_model_entry, build_hf_repo_model_entry,

View file

@ -12,8 +12,8 @@ import pytest
from pytest import ExitCode from pytest import ExitCode
from pytest_html.basereport import _process_outcome from pytest_html.basereport import _process_outcome
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.models.llama.sku_types import CoreModelId
INFERENCE_APIS = ["chat_completion"] INFERENCE_APIS = ["chat_completion"]
FUNCTIONALITIES = ["streaming", "structured_output", "tool_calling"] FUNCTIONALITIES = ["streaming", "structured_output", "tool_calling"]

View file

@ -6,8 +6,8 @@
from typing import List from typing import List
from llama_stack.models.llama.datatypes import * # noqa: F403
from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.models.llama.sku_types import * # noqa: F403
def is_supported_safety_model(model: Model) -> bool: def is_supported_safety_model(model: Model) -> bool:

View file

@ -73,21 +73,21 @@ from llama_stack.apis.inference import (
CompletionMessage, CompletionMessage,
CompletionResponse, CompletionResponse,
CompletionResponseStreamChunk, CompletionResponseStreamChunk,
GreedySamplingStrategy,
Message, Message,
SamplingParams,
SystemMessage, SystemMessage,
TokenLogProbs, TokenLogProbs,
ToolResponseMessage, ToolResponseMessage,
TopKSamplingStrategy,
TopPSamplingStrategy,
UserMessage, UserMessage,
) )
from llama_stack.models.llama.datatypes import ( from llama_stack.models.llama.datatypes import (
BuiltinTool, BuiltinTool,
GreedySamplingStrategy,
SamplingParams,
StopReason, StopReason,
ToolCall, ToolCall,
ToolDefinition, ToolDefinition,
TopKSamplingStrategy,
TopPSamplingStrategy,
) )
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_content_to_url, convert_image_content_to_url,

View file

@ -34,7 +34,6 @@ from llama_stack.apis.inference import (
) )
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ( from llama_stack.models.llama.datatypes import (
ModelFamily,
RawContent, RawContent,
RawContentItem, RawContentItem,
RawMediaItem, RawMediaItem,
@ -43,7 +42,6 @@ from llama_stack.models.llama.datatypes import (
Role, Role,
StopReason, StopReason,
ToolPromptFormat, ToolPromptFormat,
is_multimodal,
) )
from llama_stack.models.llama.llama3.chat_format import ChatFormat from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.prompt_templates import ( from llama_stack.models.llama.llama3.prompt_templates import (
@ -55,6 +53,7 @@ from llama_stack.models.llama.llama3.prompt_templates import (
) )
from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.sku_list import resolve_model from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
from llama_stack.providers.utils.inference import supported_inference_models from llama_stack.providers.utils.inference import supported_inference_models
log = get_logger(name=__name__, category="inference") log = get_logger(name=__name__, category="inference")

View file

@ -356,50 +356,7 @@
"fairscale", "fairscale",
"faiss-cpu", "faiss-cpu",
"fastapi", "fastapi",
"fire", "fbgemm-gpu-genai==1.1.2",
"httpx",
"langdetect",
"lm-format-enforcer",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentence-transformers",
"sentencepiece",
"torch",
"torchvision",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"zmq"
],
"meta-reference-quantized-gpu": [
"accelerate",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fairscale",
"faiss-cpu",
"fastapi",
"fbgemm-gpu",
"fire", "fire",
"httpx", "httpx",
"langdetect", "langdetect",

View file

@@ -18,6 +18,9 @@ providers:
       model: ${env.INFERENCE_MODEL}
       max_seq_len: 4096
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
+      quantization:
+        type: ${env.QUANTIZATION_TYPE:bf16}
+      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}
@@ -27,6 +30,9 @@ providers:
       model: ${env.SAFETY_MODEL}
      max_seq_len: 4096
       checkpoint_dir: ${env.SAFETY_CHECKPOINT_DIR:null}
+      quantization:
+        type: ${env.QUANTIZATION_TYPE:bf16}
+      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
   vector_io:
   - provider_id: faiss
     provider_type: inline::faiss

View file

@@ -18,6 +18,9 @@ providers:
       model: ${env.INFERENCE_MODEL}
       max_seq_len: 4096
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
+      quantization:
+        type: ${env.QUANTIZATION_TYPE:bf16}
+      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}

View file

@@ -1,32 +0,0 @@
version: '2'
distribution_spec:
description: Use Meta Reference with fp8, int4 quantization for running LLM inference
providers:
inference:
- inline::meta-reference-quantized
vector_io:
- inline::faiss
- remote::chromadb
- remote::pgvector
safety:
- inline::llama-guard
agents:
- inline::meta-reference
telemetry:
- inline::meta-reference
eval:
- inline::meta-reference
datasetio:
- remote::huggingface
- inline::localfs
scoring:
- inline::basic
- inline::llm-as-judge
- inline::braintrust
tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::code-interpreter
- inline::rag-runtime
- remote::model-context-protocol
image_type: conda

View file

@@ -1,113 +0,0 @@
---
orphan: true
---
# Meta Reference Quantized Distribution
```{toctree}
:maxdepth: 2
:hidden:
self
```
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations:
{{ providers_table }}
The only difference vs. the `meta-reference-gpu` distribution is that it has support for more efficient inference -- with fp8, int4 quantization, etc.
Note that you need access to nvidia GPUs to run this distribution. This distribution is not compatible with CPU-only machines or machines with AMD GPUs.
{% if run_config_env_vars %}
### Environment Variables
The following environment variables can be configured:
{% for var, (default_value, description) in run_config_env_vars.items() %}
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
{% endfor %}
{% endif %}
## Prerequisite: Downloading Models
Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
```
$ llama model list --downloaded
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
```
## Running the Distribution
You can do this via Conda (build code) or Docker which has a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
llamastack/distribution-{{ name }} \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
```
If you are using Llama Stack Safety / Shield APIs, use:
```bash
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
llamastack/distribution-{{ name }} \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
--env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
```
### Via Conda
Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available.
```bash
llama stack build --template {{ name }} --image-type conda
llama stack run distributions/{{ name }}/run.yaml \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
```
If you are using Llama Stack Safety / Shield APIs, use:
```bash
llama stack run distributions/{{ name }}/run-with-safety.yaml \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
--env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
```

View file

@@ -1,115 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import ModelInput, Provider, ToolGroupInput
from llama_stack.providers.inline.inference.meta_reference import (
MetaReferenceQuantizedInferenceConfig,
)
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["inline::meta-reference-quantized"],
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
"eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
"tool_runtime": [
"remote::brave-search",
"remote::tavily-search",
"inline::code-interpreter",
"inline::rag-runtime",
"remote::model-context-protocol",
],
}
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
ToolGroupInput(
toolgroup_id="builtin::code_interpreter",
provider_id="code-interpreter",
),
]
name = "meta-reference-quantized-gpu"
inference_provider = Provider(
provider_id="meta-reference-inference",
provider_type="inline::meta-reference-quantized",
config=MetaReferenceQuantizedInferenceConfig.sample_run_config(
model="${env.INFERENCE_MODEL}",
checkpoint_dir="${env.INFERENCE_CHECKPOINT_DIR:null}",
),
)
embedding_provider = Provider(
provider_id="sentence-transformers",
provider_type="inline::sentence-transformers",
config=SentenceTransformersInferenceConfig.sample_run_config(),
)
vector_io_provider = Provider(
provider_id="faiss",
provider_type="inline::faiss",
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
inference_model = ModelInput(
model_id="${env.INFERENCE_MODEL}",
provider_id="meta-reference-inference",
)
embedding_model = ModelInput(
model_id="all-MiniLM-L6-v2",
provider_id="sentence-transformers",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 384,
},
)
return DistributionTemplate(
name=name,
distro_type="self_hosted",
description="Use Meta Reference with fp8, int4 quantization for running LLM inference",
template_path=Path(__file__).parent / "doc_template.md",
providers=providers,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider, embedding_provider],
"vector_io": [vector_io_provider],
},
default_models=[inference_model, embedding_model],
default_tool_groups=default_tool_groups,
),
},
run_config_env_vars={
"LLAMA_STACK_PORT": (
"8321",
"Port for the Llama Stack distribution server",
),
"INFERENCE_MODEL": (
"meta-llama/Llama-3.2-3B-Instruct",
"Inference model loaded into the Meta Reference server",
),
"INFERENCE_CHECKPOINT_DIR": (
"null",
"Directory containing the Meta Reference model checkpoint",
),
},
)

View file

@@ -1,134 +0,0 @@
version: '2'
image_name: meta-reference-quantized-gpu
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: meta-reference-inference
provider_type: inline::meta-reference-quantized
config:
model: ${env.INFERENCE_MODEL}
max_seq_len: 4096
checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
quantization:
type: fp8
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-quantized-gpu}/faiss_store.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-quantized-gpu}/agents_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/meta-reference-quantized-gpu/trace_store.db}
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-quantized-gpu}/meta_reference_eval.db
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-quantized-gpu}/huggingface_datasetio.db
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-quantized-gpu}/localfs_datasetio.db
scoring:
- provider_id: basic
provider_type: inline::basic
config: {}
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
config: {}
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-quantized-gpu}/registry.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: meta-reference-inference
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
model_type: embedding
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter
server:
port: 8321

View file

@@ -224,9 +224,9 @@ exclude = [
     "^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
     "^llama_stack/providers/inline/inference/meta_reference/config\\.py$",
     "^llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
-    "^llama_stack/providers/inline/inference/meta_reference/llama3/generation\\.py$",
-    "^llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/model\\.py$",
-    "^llama_stack/providers/inline/inference/meta_reference/llama4/",
+    "^llama_stack/models/llama/llama3/generation\\.py$",
+    "^llama_stack/models/llama/llama3/multimodal/model\\.py$",
+    "^llama_stack/models/llama/llama4/",
     "^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$",
     "^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$",
     "^llama_stack/providers/inline/inference/meta_reference/quantization/loader\\.py$",

View file

@@ -5,13 +5,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
 # Run this script:
 # torchrun --nproc_per_node=8 scripts/generate_prompt_format.py meta-llama/Llama-4-17B-Omni-Instruct-BF16-16E ~/.llama/checkpoints/Llama-4-17B-Omni-Instruct-BF16-16E/ llama_stack.models.llama.llama4.prompts llama_stack/models/llama/llama4/prompt_format.md
@@ -22,16 +15,9 @@ from pathlib import Path
 import fire
+from llama_stack.models.llama.llama3.generation import Llama3
+from llama_stack.models.llama.llama4.generation import Llama4
 from llama_stack.models.llama.sku_list import resolve_model
-from llama_stack.providers.inline.inference.meta_reference.config import (
-    MetaReferenceInferenceConfig,
-)
-from llama_stack.providers.inline.inference.meta_reference.llama3.generation import (
-    Llama3,
-)
-from llama_stack.providers.inline.inference.meta_reference.llama4.generation import (
-    Llama4,
-)
 THIS_DIR = Path(__file__).parent.resolve()
@@ -50,24 +36,12 @@ def run_main(
     if not llama_model:
         raise ValueError(f"Model {model_id} not found")
-    if not llama4:
-        config = MetaReferenceInferenceConfig(
-            model=model_id,
-            max_seq_len=4096,
-            max_batch_size=1,
-            checkpoint_dir=checkpoint_dir,
-        )
-        generator = Llama3.build(
-            config=config,
-            model_id=model_id,
-            llama_model=llama_model,
-        )
-    else:
-        generator = Llama4.build(
-            ckpt_dir=checkpoint_dir,
-            max_seq_len=4096,
-            max_batch_size=1,
-        )
+    cls = Llama4 if llama4 else Llama3
+    generator = cls.build(
+        ckpt_dir=checkpoint_dir,
+        max_seq_len=4096,
+        max_batch_size=1,
+    )
     use_cases = module.usecases()
     text = ""

View file

@@ -11,7 +11,6 @@ import pytest
 from pytest import CollectReport
 from termcolor import cprint
-from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.models.llama.sku_list import (
     all_registered_models,
     llama3_1_instruct_models,
@@ -20,6 +19,7 @@ from llama_stack.models.llama.sku_list import (
     llama3_instruct_models,
     safety_models,
 )
+from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.datatypes import Api
 from .metadata import API_MAPS