several fixes

Ashwin Bharambe 2025-04-07 10:31:20 -07:00
parent e2e2820c9a
commit 53a8086e37
60 changed files with 1006 additions and 1078 deletions

View file

@ -4163,80 +4163,70 @@
           ]
         },
         "arguments": {
-          "oneOf": [
-            { "type": "string" },
-            {
-              "type": "object",
-              "additionalProperties": {
-                "oneOf": [
-                  { "type": "string" },
-                  { "type": "integer" },
-                  { "type": "number" },
-                  { "type": "boolean" },
-                  { "type": "null" },
-                  {
-                    "type": "array",
-                    "items": {
-                      "oneOf": [
-                        { "type": "string" },
-                        { "type": "integer" },
-                        { "type": "number" },
-                        { "type": "boolean" },
-                        { "type": "null" }
-                      ]
-                    }
-                  },
-                  {
-                    "type": "object",
-                    "additionalProperties": {
-                      "oneOf": [
-                        { "type": "string" },
-                        { "type": "integer" },
-                        { "type": "number" },
-                        { "type": "boolean" },
-                        { "type": "null" }
-                      ]
-                    }
-                  }
-                ]
-              }
-            }
-          ]
-        },
-        "arguments_json": {
-          "type": "string"
-        }
+          "type": "object",
+          "additionalProperties": {
+            "oneOf": [
+              { "type": "string" },
+              { "type": "integer" },
+              { "type": "number" },
+              { "type": "boolean" },
+              { "type": "null" },
+              {
+                "type": "array",
+                "items": {
+                  "oneOf": [
+                    { "type": "string" },
+                    { "type": "integer" },
+                    { "type": "number" },
+                    { "type": "boolean" },
+                    { "type": "null" }
+                  ]
+                }
+              },
+              {
+                "type": "object",
+                "additionalProperties": {
+                  "oneOf": [
+                    { "type": "string" },
+                    { "type": "integer" },
+                    { "type": "number" },
+                    { "type": "boolean" },
+                    { "type": "null" }
+                  ]
+                }
+              }
+            ]
+          }
+        }
       },
       "additionalProperties": false,

View file

@ -2890,34 +2890,30 @@ components:
             title: BuiltinTool
           - type: string
         arguments:
-          oneOf:
-          - type: string
-          - type: object
-            additionalProperties:
-              oneOf:
-              - type: string
-              - type: integer
-              - type: number
-              - type: boolean
-              - type: 'null'
-              - type: array
-                items:
-                  oneOf:
-                  - type: string
-                  - type: integer
-                  - type: number
-                  - type: boolean
-                  - type: 'null'
-              - type: object
-                additionalProperties:
-                  oneOf:
-                  - type: string
-                  - type: integer
-                  - type: number
-                  - type: boolean
-                  - type: 'null'
-        arguments_json:
-          type: string
+          type: object
+          additionalProperties:
+            oneOf:
+            - type: string
+            - type: integer
+            - type: number
+            - type: boolean
+            - type: 'null'
+            - type: array
+              items:
+                oneOf:
+                - type: string
+                - type: integer
+                - type: number
+                - type: boolean
+                - type: 'null'
+            - type: object
+              additionalProperties:
+                oneOf:
+                - type: string
+                - type: integer
+                - type: number
+                - type: boolean
+                - type: 'null'
       additionalProperties: false
       required:
       - call_id
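Net effect of the two spec hunks above: `arguments` is always a JSON object keyed by parameter name (primitive values, lists of primitives, or one level of nested primitive maps), and both the transitional string form and the separate `arguments_json` field are dropped. A small illustrative check, not part of the commit, using the third-party `jsonschema` package:

# Illustrative only. Mirrors the new "arguments" schema shown above.
from jsonschema import validate

PRIMITIVE = [{"type": t} for t in ("string", "integer", "number", "boolean", "null")]
ARGUMENTS_SCHEMA = {
    "type": "object",
    "additionalProperties": {
        "oneOf": PRIMITIVE
        + [
            {"type": "array", "items": {"oneOf": PRIMITIVE}},
            {"type": "object", "additionalProperties": {"oneOf": PRIMITIVE}},
        ]
    },
}

# A structured payload validates under the new schema.
validate({"query": "weather in SF", "days": 3}, ARGUMENTS_SCHEMA)

# A bare JSON string (previously allowed via the removed oneOf/arguments_json branch)
# no longer validates:
# validate('{"query": "weather in SF"}', ARGUMENTS_SCHEMA)  # -> ValidationError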

View file

@ -25,15 +25,64 @@ from llama_stack.apis.models import Model
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
from llama_stack.models.llama.datatypes import ( from llama_stack.models.llama.datatypes import (
BuiltinTool, BuiltinTool,
SamplingParams,
StopReason, StopReason,
ToolCall, ToolCall,
ToolDefinition, ToolDefinition,
ToolParamDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
register_schema(ToolCall)
register_schema(ToolParamDefinition)
register_schema(ToolDefinition)
@json_schema_type
class GreedySamplingStrategy(BaseModel):
type: Literal["greedy"] = "greedy"
@json_schema_type
class TopPSamplingStrategy(BaseModel):
type: Literal["top_p"] = "top_p"
temperature: Optional[float] = Field(..., gt=0.0)
top_p: Optional[float] = 0.95
@json_schema_type
class TopKSamplingStrategy(BaseModel):
type: Literal["top_k"] = "top_k"
top_k: int = Field(..., ge=1)
SamplingStrategy = Annotated[
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
Field(discriminator="type"),
]
register_schema(SamplingStrategy, name="SamplingStrategy")
@json_schema_type
class SamplingParams(BaseModel):
"""Sampling parameters.
:param strategy: The sampling strategy.
:param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
your prompt plus max_tokens cannot exceed the model's context length.
:param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
:param stop: Up to 4 sequences where the API will stop generating further tokens.
The returned text will not contain the stop sequence.
"""
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
max_tokens: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0
stop: Optional[List[str]] = None
class LogProbConfig(BaseModel): class LogProbConfig(BaseModel):
""" """

View file

@ -29,8 +29,8 @@ from rich.progress import (
from termcolor import cprint from termcolor import cprint
from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
from llama_stack.models.llama.datatypes import Model
from llama_stack.models.llama.sku_list import LlamaDownloadInfo from llama_stack.models.llama.sku_list import LlamaDownloadInfo
from llama_stack.models.llama.sku_types import Model
class Download(Subcommand): class Download(Subcommand):

View file

@ -63,17 +63,6 @@ class ModelDescribe(Subcommand):
("Model params.json", json.dumps(model.arch_args, indent=4)), ("Model params.json", json.dumps(model.arch_args, indent=4)),
] ]
if model.recommended_sampling_params is not None:
sampling_params = model.recommended_sampling_params.model_dump()
for k in ("max_tokens", "repetition_penalty"):
del sampling_params[k]
rows.append(
(
"Recommended sampling params",
json.dumps(sampling_params, indent=4),
)
)
print_table( print_table(
rows, rows,
headers, headers,

View file

@ -11,7 +11,7 @@ from pathlib import Path
from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table from llama_stack.cli.table import print_table
from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family from llama_stack.models.llama.sku_types import CoreModelId, ModelFamily, is_multimodal, model_family
ROOT_DIR = Path(__file__).parent.parent.parent ROOT_DIR = Path(__file__).parent.parent.parent

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict, Optional from typing import Any, Dict
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat, SamplingParams
from llama_stack.models.llama.sku_list import LlamaDownloadInfo from llama_stack.models.llama.sku_list import LlamaDownloadInfo
from llama_stack.models.llama.sku_types import CheckpointQuantizationFormat
class PromptGuardModel(BaseModel): class PromptGuardModel(BaseModel):
@ -23,7 +23,6 @@ class PromptGuardModel(BaseModel):
is_instruct_model: bool = False is_instruct_model: bool = False
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16 quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
arch_args: Dict[str, Any] = Field(default_factory=dict) arch_args: Dict[str, Any] = Field(default_factory=dict)
recommended_sampling_params: Optional[SamplingParams] = None
def descriptor(self) -> str: def descriptor(self) -> str:
return self.model_id return self.model_id

View file

@ -0,0 +1,164 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import concurrent.futures
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank, get_model_parallel_world_size
def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[int]:
"""Map a new MP rank to a list of old MP ranks given a change in MP size."""
if new_mp_size % old_mp_size == 0:
# Read old MP shard and split it into smaller ones
return [new_mp_rank * old_mp_size // new_mp_size]
elif old_mp_size % new_mp_size == 0:
# Merge old MP shards into a single one
mp_factor = old_mp_size // new_mp_size
return list(range(new_mp_rank * mp_factor, (new_mp_rank + 1) * mp_factor))
else:
raise ValueError(
f"Either old MP size or new MP size should be a multiple of the other: "
f"{old_mp_size} % {new_mp_size} != 0 and {new_mp_size} % {old_mp_size} != 0"
)
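The rank mapping is easiest to see with concrete sizes. An illustrative check, not part of the new file; the import path is assumed from the `from ..checkpoint import ...` usage later in this commit:

# Illustrative only.
from llama_stack.models.llama.checkpoint import map_mp_rank  # path assumed

# Splitting 2 old shards across 8 new ranks: each new rank reads one old shard
# and slices it further (new ranks 0-3 read old shard 0, ranks 4-7 read shard 1).
assert [map_mp_rank(2, 8, r) for r in range(8)] == [[0], [0], [0], [0], [1], [1], [1], [1]]

# Merging 8 old shards down to 2 new ranks: each new rank concatenates 4 old shards.
assert map_mp_rank(8, 2, 0) == [0, 1, 2, 3]
assert map_mp_rank(8, 2, 1) == [4, 5, 6, 7]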
def maybe_reshard_state_dict(
ckpt_paths: List[Path],
n_kv_heads: int,
moe_num_experts: Optional[int] = None,
map_location: Union[str, torch.device] = "cpu",
mmap: bool = True,
) -> Dict[str, torch.Tensor]:
if str(map_location) == "cpu":
torch.set_default_tensor_type(torch.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
ckpt_paths = np.array(sorted(ckpt_paths))
new_mp_size, new_mp_rank = get_model_parallel_world_size(), get_model_parallel_rank()
old_mp_size = len(ckpt_paths)
old_mp_ranks = map_mp_rank(old_mp_size, new_mp_size, new_mp_rank)
print(f"Loading checkpoint shards:\n{str(ckpt_paths[old_mp_ranks])}") # type: ignore
paths = ckpt_paths[old_mp_ranks] # type: ignore
state_dicts = [torch.load(str(p), map_location=map_location, mmap=mmap) for p in paths]
if new_mp_size == old_mp_size:
return state_dicts[0] # type: ignore
if moe_num_experts is not None:
state_dicts = [convert_moe_weights(d, moe_num_experts) for d in state_dicts]
print(f"Resharding {len(state_dicts)} state dicts from MP size {old_mp_size} to MP size {new_mp_size}")
return reshard_mp(
state_dicts,
size=max(new_mp_size // old_mp_size, 1),
rank=new_mp_rank % max(new_mp_size // old_mp_size, 1),
repeat_qk_qv=max(new_mp_size // n_kv_heads, 1),
)
_WEIGHT_ROW_KEY = {
"feed_forward.w2",
"feed_forward.mlp.fc2",
"attention.wo",
"feed_forward.mlp.fc2_weight",
"feed_forward.w_out_shared_DF.weight",
"attn.wo.weight",
"mlp.c_proj.weight",
}
_MOE_WEIGHT_ROW_KEY = {"feed_forward.experts.(moe_w_in_eD_F|moe_w_swiglu_eD_F)"}
_WEIGHT_COLUMN_KEY = {
"output",
"feed_forward.(w1|w3)",
"feed_forward.mlp.(fc1|fc3)",
"feed_forward.mlp.fc1_weight",
"attention.(wk|wq|wv|wqkv).weight",
"feed_forward.(w_in_shared_FD|w_swiglu_FD)",
"attn.(wk|wq|wv).weight",
"attn.(wk|wq|wv).bias",
"mlp.c_fc.weight",
"mlp.c_fc.bias",
"conv1._linear.weight",
"tok_embeddings.weight",
"vision_projection.weight",
}
_MOE_WEIGHT_COLUMN_KEY = {"feed_forward.experts.moe_w_out_eF_D"}
def reshard_mp(
state_dicts: List[Dict[str, torch.Tensor]],
size: int,
rank: int,
repeat_qk_qv: int = 1,
) -> Dict[str, torch.Tensor]:
"""
Reshard a list of state dicts into a single state dict given a change in MP size.
If the list has more than one state dict, we concatenate the values of the same
key across all state dicts. Otherwise, we just slice it for the current MP rank.
"""
def concat_or_chunk(tensors: List[torch.Tensor], dim: int) -> torch.Tensor:
if len(tensors) > 1:
return torch.cat(tensors, dim=dim)
return tensors[0].chunk(size, dim=dim)[rank].clone()
def process_key(key: str) -> torch.Tensor:
if row_regex.search(key):
return concat_or_chunk([s[key] for s in state_dicts], dim=-1)
elif column_regex.search(key):
if "w13" in key or "fc1_weight" in key:
dims = state_dicts[0][key].size()
values = [s[key].view(2, dims[0] // 2, *dims[1:]) for s in state_dicts]
return concat_or_chunk(values, dim=1).flatten(0, 1)
elif "qkv" in key:
q_dim = state_dicts[0][key.replace("qkv", "o")].size(1)
kv_dim = (state_dicts[0][key].size(0) - q_dim) // 2
values = [s[key].split((q_dim, kv_dim, kv_dim)) for s in state_dicts]
return torch.cat([concat_or_chunk(x, dim=0) for x in zip(*values, strict=False)]) # type: ignore
elif "wk.weight" in key or "wv.weight" in key:
# Support MP > #kv_head
return concat_or_chunk([s[key].repeat(repeat_qk_qv, 1) for s in state_dicts], dim=0)
elif key == "output.bias" or key == "fc.weight":
return concat_or_chunk([s[key] for s in state_dicts], dim=0)
elif "w_" in key:
return concat_or_chunk([s[key] for s in state_dicts], dim=-2)
else:
return concat_or_chunk([s[key] for s in state_dicts], dim=0)
else:
return state_dicts[0][key].clone()
row_keys = _WEIGHT_ROW_KEY | _MOE_WEIGHT_ROW_KEY
column_keys = _WEIGHT_COLUMN_KEY | _MOE_WEIGHT_COLUMN_KEY
column_regex = re.compile("|".join(column_keys))
row_regex = re.compile("|".join(row_keys))
output: Dict[str, torch.Tensor] = {}
with concurrent.futures.ThreadPoolExecutor() as executor:
# Note: only processes keys in the first state dict.
# Assumes keys are the same across all state dicts.
mappings = {executor.submit(process_key, key): key for key in state_dicts[0]}
for future in concurrent.futures.as_completed(mappings):
output[mappings[future]] = future.result()
return output
def convert_moe_weights(state_dict: Dict[str, Any], num_experts: int) -> Dict[str, Any]:
routed_keys = _MOE_WEIGHT_ROW_KEY | _MOE_WEIGHT_COLUMN_KEY
routed_regex = re.compile("|".join(routed_keys))
keys = list(state_dict.keys())
for key in keys:
if routed_regex.search(key):
state_dict[key] = state_dict.pop(key).unflatten(0, (num_experts, -1)).squeeze(dim=0)
return state_dict
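A small toy check of the resharding directions encoded by the key sets above: row-parallel weights such as `attention.wo` concatenate along the last dimension, while column-parallel weights such as `output` fall through to the default branch and concatenate along dim 0. Illustrative only; the import path is assumed as in the earlier sketch:

# Illustrative only.
import torch
from llama_stack.models.llama.checkpoint import reshard_mp  # path assumed

# Two MP shards merged into one (size=1, rank=0 means "keep everything").
shard_a = {"layers.0.attention.wo.weight": torch.zeros(4, 2), "output.weight": torch.zeros(3, 4)}
shard_b = {"layers.0.attention.wo.weight": torch.ones(4, 2), "output.weight": torch.ones(3, 4)}
merged = reshard_mp([shard_a, shard_b], size=1, rank=0)

# Row-parallel weight: concatenated along the last dim -> (4, 4).
assert merged["layers.0.attention.wo.weight"].shape == (4, 4)
# Column-parallel weight: concatenated along dim 0 -> (6, 4).
assert merged["output.weight"].shape == (6, 4)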

View file

@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import base64 import base64
from enum import Enum from enum import Enum
from io import BytesIO from io import BytesIO
@ -19,8 +12,6 @@ from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
from typing_extensions import Annotated from typing_extensions import Annotated
from llama_stack.schema_utils import json_schema_type, register_schema
# The goal is that these set of types are relevant for all Llama models. # The goal is that these set of types are relevant for all Llama models.
# That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to # That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to
# the llama3 series of models. # the llama3 series of models.
@ -47,14 +38,7 @@ RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
class ToolCall(BaseModel): class ToolCall(BaseModel):
call_id: str call_id: str
tool_name: Union[BuiltinTool, str] tool_name: Union[BuiltinTool, str]
# Plan is to deprecate the Dict in favor of a JSON string arguments: Dict[str, RecursiveType]
# that is parsed on the client side instead of trying to manage
# the recursive type here.
# Making this a union so that client side can start prepping for this change.
# Eventually, we will remove both the Dict and arguments_json field,
# and arguments will just be a str
arguments: Union[str, Dict[str, RecursiveType]]
arguments_json: Optional[str] = None
@field_validator("tool_name", mode="before") @field_validator("tool_name", mode="before")
@classmethod @classmethod
@ -98,6 +82,29 @@ class StopReason(Enum):
out_of_tokens = "out_of_tokens" out_of_tokens = "out_of_tokens"
class ToolParamDefinition(BaseModel):
param_type: str
description: Optional[str] = None
required: Optional[bool] = True
default: Optional[Any] = None
class ToolDefinition(BaseModel):
tool_name: Union[BuiltinTool, str]
description: Optional[str] = None
parameters: Optional[Dict[str, ToolParamDefinition]] = None
@field_validator("tool_name", mode="before")
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return BuiltinTool(v)
except ValueError:
return v
return v
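The validator added above coerces known built-in tool names to the BuiltinTool enum and leaves custom names as plain strings. A short illustrative check, assuming these classes remain importable from llama_stack.models.llama.datatypes:

# Illustrative only.
from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition

# "brave_search" matches a BuiltinTool value, so the validator returns the enum member...
assert ToolDefinition(tool_name="brave_search").tool_name is BuiltinTool.brave_search
# ...while unknown names pass through untouched (user-defined tools).
assert ToolDefinition(tool_name="get_weather").tool_name == "get_weather"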
class RawMediaItem(BaseModel): class RawMediaItem(BaseModel):
type: Literal["image"] = "image" type: Literal["image"] = "image"
data: bytes | BytesIO data: bytes | BytesIO
@ -140,292 +147,25 @@ class RawMessage(BaseModel):
tool_calls: List[ToolCall] = Field(default_factory=list) tool_calls: List[ToolCall] = Field(default_factory=list)
register_schema(ToolCall) class GenerationResult(BaseModel):
token: int
text: str
logprobs: Optional[List[float]] = None
source: Literal["input"] | Literal["output"]
# index within the batch
batch_idx: int
# whether generation for this item is already finished. note that tokens can
# get returned even afterwards since other items in the batch can still be generating tokens
finished: bool
# because a batch is parallel processed, useful decoding for one item can correspond to processing
# pad tokens or tokens beyond EOS for other items. we could have decided to return None for this case
# but it's more convenient to return a list of GenerationResult and filter out the ignored tokens
ignore_token: bool
@json_schema_type class QuantizationMode(str, Enum):
class ToolParamDefinition(BaseModel): none = "none"
param_type: str fp8_mixed = "fp8_mixed"
description: Optional[str] = None int4_mixed = "int4_mixed"
required: Optional[bool] = True
default: Optional[Any] = None
@json_schema_type
class ToolDefinition(BaseModel):
tool_name: Union[BuiltinTool, str]
description: Optional[str] = None
parameters: Optional[Dict[str, ToolParamDefinition]] = None
@field_validator("tool_name", mode="before")
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return BuiltinTool(v)
except ValueError:
return v
return v
@json_schema_type
class GreedySamplingStrategy(BaseModel):
type: Literal["greedy"] = "greedy"
@json_schema_type
class TopPSamplingStrategy(BaseModel):
type: Literal["top_p"] = "top_p"
temperature: Optional[float] = Field(..., gt=0.0)
top_p: Optional[float] = 0.95
@json_schema_type
class TopKSamplingStrategy(BaseModel):
type: Literal["top_k"] = "top_k"
top_k: int = Field(..., ge=1)
SamplingStrategy = Annotated[
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
Field(discriminator="type"),
]
register_schema(SamplingStrategy, name="SamplingStrategy")
@json_schema_type
class SamplingParams(BaseModel):
"""Sampling parameters.
:param strategy: The sampling strategy.
:param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
your prompt plus max_tokens cannot exceed the model's context length.
:param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
:param stop: Up to 4 sequences where the API will stop generating further tokens.
The returned text will not contain the stop sequence.
"""
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
max_tokens: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0
stop: Optional[List[str]] = None
class CheckpointQuantizationFormat(Enum):
# default format
bf16 = "bf16"
# used for enabling fp8_rowwise inference, some weights are bf16
fp8_mixed = "fp8-mixed"
int8 = "int8"
int4 = "int4"
class ModelFamily(Enum):
llama2 = "llama2"
llama3 = "llama3"
llama3_1 = "llama3_1"
llama3_2 = "llama3_2"
llama3_3 = "llama3_3"
llama4 = "llama4"
safety = "safety"
class CoreModelId(Enum):
"""Each of these models is a unique "SKU". These root models can be served in various garbs (especially by quantizing them)"""
# Llama 2 family
llama2_7b = "Llama-2-7b"
llama2_13b = "Llama-2-13b"
llama2_70b = "Llama-2-70b"
llama2_7b_chat = "Llama-2-7b-chat"
llama2_13b_chat = "Llama-2-13b-chat"
llama2_70b_chat = "Llama-2-70b-chat"
# Llama 3 family
llama3_8b = "Llama-3-8B"
llama3_70b = "Llama-3-70B"
llama3_8b_instruct = "Llama-3-8B-Instruct"
llama3_70b_instruct = "Llama-3-70B-Instruct"
# Llama 3.1 family
llama3_1_8b = "Llama3.1-8B"
llama3_1_70b = "Llama3.1-70B"
llama3_1_405b = "Llama3.1-405B"
llama3_1_8b_instruct = "Llama3.1-8B-Instruct"
llama3_1_70b_instruct = "Llama3.1-70B-Instruct"
llama3_1_405b_instruct = "Llama3.1-405B-Instruct"
# Llama 3.2 family
llama3_2_1b = "Llama3.2-1B"
llama3_2_3b = "Llama3.2-3B"
llama3_2_1b_instruct = "Llama3.2-1B-Instruct"
llama3_2_3b_instruct = "Llama3.2-3B-Instruct"
llama3_2_11b_vision = "Llama3.2-11B-Vision"
llama3_2_90b_vision = "Llama3.2-90B-Vision"
llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct"
llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct"
# Llama 3.3 family
llama3_3_70b_instruct = "Llama3.3-70B-Instruct"
# Llama 4 family
llama4_scout_17b_16e = "Llama-4-Scout-17B-16E"
llama4_scout_17b_16e_instruct = "Llama-4-Scout-17B-16E-Instruct"
llama4_maverick_17b_128e = "Llama-4-Maverick-17B-128E"
llama4_maverick_17b_128e_instruct = "Llama-4-Maverick-17B-128E-Instruct"
# Safety models
llama_guard_3_8b = "Llama-Guard-3-8B"
llama_guard_2_8b = "Llama-Guard-2-8B"
llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision"
llama_guard_3_1b = "Llama-Guard-3-1B"
def is_multimodal(model_id) -> bool:
if model_id in [
CoreModelId.llama3_2_11b_vision,
CoreModelId.llama3_2_90b_vision,
CoreModelId.llama3_2_11b_vision_instruct,
CoreModelId.llama3_2_90b_vision_instruct,
]:
return True
else:
return False
def model_family(model_id) -> ModelFamily:
if model_id in [
CoreModelId.llama2_7b,
CoreModelId.llama2_13b,
CoreModelId.llama2_70b,
CoreModelId.llama2_7b_chat,
CoreModelId.llama2_13b_chat,
CoreModelId.llama2_70b_chat,
]:
return ModelFamily.llama2
elif model_id in [
CoreModelId.llama3_8b,
CoreModelId.llama3_70b,
CoreModelId.llama3_8b_instruct,
CoreModelId.llama3_70b_instruct,
]:
return ModelFamily.llama3
elif model_id in [
CoreModelId.llama3_1_8b,
CoreModelId.llama3_1_70b,
CoreModelId.llama3_1_405b,
CoreModelId.llama3_1_8b_instruct,
CoreModelId.llama3_1_70b_instruct,
CoreModelId.llama3_1_405b_instruct,
]:
return ModelFamily.llama3_1
elif model_id in [
CoreModelId.llama3_2_1b,
CoreModelId.llama3_2_3b,
CoreModelId.llama3_2_1b_instruct,
CoreModelId.llama3_2_3b_instruct,
CoreModelId.llama3_2_11b_vision,
CoreModelId.llama3_2_90b_vision,
CoreModelId.llama3_2_11b_vision_instruct,
CoreModelId.llama3_2_90b_vision_instruct,
]:
return ModelFamily.llama3_2
elif model_id in [
CoreModelId.llama3_3_70b_instruct,
]:
return ModelFamily.llama3_3
elif model_id in [
CoreModelId.llama4_scout_17b_16e,
CoreModelId.llama4_scout_17b_16e_instruct,
CoreModelId.llama4_maverick_17b_128e,
CoreModelId.llama4_maverick_17b_128e_instruct,
]:
return ModelFamily.llama4
elif model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_2_8b,
CoreModelId.llama_guard_3_11b_vision,
CoreModelId.llama_guard_3_1b,
]:
return ModelFamily.safety
else:
raise ValueError(f"Unknown model family for {model_id}")
class Model(BaseModel):
core_model_id: CoreModelId
description: str
huggingface_repo: Optional[str] = None
recommended_sampling_params: Optional[SamplingParams] = None
arch_args: Dict[str, Any]
variant: str = ""
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
pth_file_count: int
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
# silence pydantic until we remove the `model_` fields
model_config = ConfigDict(protected_namespaces=())
@property
def model_family(self) -> ModelFamily:
return model_family(self.core_model_id)
# The SKU is uniquely identified by (model_id, variant) combo
def descriptor(self, shorten_default_variant: bool = True) -> str:
if not self.variant:
return self.core_model_id.value
return f"{self.core_model_id.value}:{self.variant}"
@property
def is_instruct_model(self) -> bool:
return "instruct" in self.id.name
# Featured models are shown in the non-exhaustive model list
@property
def is_featured(self) -> bool:
return self.model_family in [
ModelFamily.llama3_1,
ModelFamily.llama3_2,
ModelFamily.llama3_3,
ModelFamily.llama4,
ModelFamily.safety,
]
@property
def max_seq_length(self) -> int:
if self.model_family == ModelFamily.llama2:
return 4096
elif self.core_model_id == CoreModelId.llama_guard_2_8b:
return 4096
elif self.model_family == ModelFamily.llama3:
return 8192
elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]:
return 131072
elif self.model_family == ModelFamily.llama3_2:
if self.quantization_format == CheckpointQuantizationFormat.int4:
return 8192
return 131072
elif self.model_family == ModelFamily.llama4:
if self.core_model_id in {
CoreModelId.llama4_scout_17b_16e,
CoreModelId.llama4_maverick_17b_128e,
}:
return 262144
if self.core_model_id == CoreModelId.llama4_scout_17b_16e_instruct:
return 10485760
if self.core_model_id == CoreModelId.llama4_maverick_17b_128e_instruct:
return 1048576
elif self.core_model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_3_11b_vision,
CoreModelId.llama_guard_3_1b,
]:
return 131072
else:
raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")

View file

@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional

View file

@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import io import io
import json import json
import uuid import uuid

View file

@ -4,59 +4,37 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. # All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import json import json
import os import os
import sys import sys
import time import time
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Callable, Generator, List, Optional from typing import Callable, Generator, List, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import ( from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
initialize_model_parallel, initialize_model_parallel,
model_parallel_is_initialized, model_parallel_is_initialized,
) )
from termcolor import cprint from termcolor import cprint
from ..datatypes import RawContent, RawMessage, StopReason, ToolPromptFormat from ..checkpoint import maybe_reshard_state_dict
from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat
from .args import ModelArgs from .args import ModelArgs
from .chat_format import ChatFormat, LLMInput from .chat_format import ChatFormat, LLMInput
from .model import Transformer from .model import Transformer
from .multimodal.model import CrossAttentionTransformer
from .tokenizer import Tokenizer from .tokenizer import Tokenizer
@dataclass
class CompletionPrediction:
generation: str
decoded_tokens: Optional[List[str]] = None
logprobs: Optional[List[List[float]]] = None
@dataclass
class ChatPrediction:
generation: RawMessage
decoded_tokens: Optional[List[str]] = None
logprobs: Optional[List[List[float]]] = None
@dataclass
class TokenResult:
token: int
text: str
logprobs: Optional[List[float]] = None
# TODO: make this completely parallel to the llama4 generation.py file and share common code
# from llama-models also
class Llama3: class Llama3:
@staticmethod @staticmethod
def build( def build(
@ -64,7 +42,7 @@ class Llama3:
max_seq_len: int, max_seq_len: int,
max_batch_size: int, max_batch_size: int,
world_size: Optional[int] = None, world_size: Optional[int] = None,
tokenizer_path: Optional[str] = None, quantization_mode: Optional[QuantizationMode] = None,
seed: int = 1, seed: int = 1,
device: str = "cuda", device: str = "cuda",
): ):
@ -101,13 +79,9 @@ class Llama3:
         start_time = time.time()
-        checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
-        assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
-        assert world_size == len(checkpoints), (
-            f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
-        )
-        ckpt_path = checkpoints[get_model_parallel_rank()]
-        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
+        ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
+        assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
+        print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
         with open(Path(ckpt_dir) / "params.json", "r") as f:
             params = json.loads(f.read())
@ -116,40 +90,58 @@ class Llama3:
max_batch_size=max_batch_size, max_batch_size=max_batch_size,
**params, **params,
) )
if tokenizer_path: tokenizer = Tokenizer.get_instance()
tokenizer = Tokenizer(model_path=tokenizer_path)
else: state_dict = maybe_reshard_state_dict(
tokenizer = Tokenizer.get_instance() ckpt_paths,
n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
)
assert model_args.vocab_size == tokenizer.n_words assert model_args.vocab_size == tokenizer.n_words
torch.set_default_device(device)
if device.type == "cuda":
if torch.cuda.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.half)
elif device.type == "xpu":
if torch.xpu.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.half)
else:
torch.set_default_dtype(torch.half)
if model_args.vision_chunk_size > 0: def build_model():
from .multimodal.model import CrossAttentionTransformer if model_args.vision_chunk_size > 0:
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, device=device, dtype=torch.get_default_dtype())
else:
model = Transformer(model_args)
return model
model = CrossAttentionTransformer(model_args) if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
model.setup_cache(model_args.max_batch_size, torch.get_default_dtype()) from .quantization.loader import convert_to_quantized_model
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = build_model()
print("Loading state dict...")
model.load_state_dict(state_dict, strict=False)
print("Done...")
model = convert_to_quantized_model(model, ckpt_dir, quantization_mode, device=device)
torch.set_default_device(device)
else: else:
model = Transformer(model_args) print(f"Setting default device to {device}")
model.load_state_dict(checkpoint, strict=True) torch.set_default_device(device)
model.to(device) if device.type == "cuda":
if torch.cuda.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.half)
elif device.type == "xpu":
if torch.xpu.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.half)
model = build_model()
print("Loading state dict...")
model.load_state_dict(state_dict, strict=True)
model.to(device)
print("Done...")
print(f"Loaded in {time.time() - start_time:.2f} seconds") print(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama(model, tokenizer, model_args) return Llama3(model, tokenizer, model_args)
def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs): def __init__(self, model: Transformer | CrossAttentionTransformer, tokenizer: Tokenizer, args: ModelArgs):
self.args = args self.args = args
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
@ -158,26 +150,30 @@ class Llama3:
@torch.inference_mode() @torch.inference_mode()
def generate( def generate(
self, self,
model_input: LLMInput, model_inputs: List[LLMInput],
max_gen_len: int,
temperature: float = 0.6, temperature: float = 0.6,
top_p: float = 0.9, top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False, logprobs: bool = False,
echo: bool = False, echo: bool = False,
print_model_input: bool = False, print_model_input: bool = False,
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> Generator: ) -> Generator[List[GenerationResult], None, None]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
params = self.model.params params = self.model.params
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
if print_model_input: if print_model_input:
tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in model_input.tokens] for inp in model_inputs:
cprint( tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n", cprint(
"red", "Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
) "red",
prompt_tokens = [model_input.tokens] )
prompt_tokens = [inp.tokens for inp in model_inputs]
bsz = 1 bsz = len(model_inputs)
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens) min_prompt_len = min(len(t) for t in prompt_tokens)
@ -189,18 +185,6 @@ class Llama3:
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len) total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
is_vision = not isinstance(self.model, Transformer)
if is_vision:
images = model_input.vision.images if model_input.vision is not None else []
mask = model_input.vision.mask if model_input.vision is not None else []
# the method works for bsz > 1 so add a batch dimension
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
batch_images=[images],
batch_masks=[mask],
total_len=total_len,
)
pad_id = self.tokenizer.pad_id pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long) tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
for k, t in enumerate(prompt_tokens): for k, t in enumerate(prompt_tokens):
@ -208,23 +192,45 @@ class Llama3:
if logprobs: if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float) token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
prev_pos = 0 is_vision = not isinstance(self.model, Transformer)
if is_vision:
images = [inp.vision.images if inp.vision is not None else [] for inp in model_inputs]
mask = [inp.vision.mask if inp.vision is not None else [] for inp in model_inputs]
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
batch_images=images,
batch_masks=mask,
total_len=total_len,
device=tokens.device,
)
eos_reached = torch.tensor([False] * bsz) eos_reached = torch.tensor([False] * bsz)
input_text_mask = tokens != pad_id input_text_mask = tokens != pad_id
if echo: if echo:
for i, t in enumerate(model_input.tokens): for i in range(max_prompt_len):
yield TokenResult( results = []
token=t, for j, t in enumerate(tokens[:, i]):
text=self.tokenizer.decode([t]), results.append(
logprobs=(token_logprobs[0, i : i + 1].tolist() if logprobs else None), GenerationResult(
) token=t.item(),
text=self.tokenizer.decode([t.item()]),
source="input",
logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
batch_idx=j,
finished=False,
ignore_token=t.item() == pad_id,
)
)
yield results
stop_tokens = torch.tensor(self.tokenizer.stop_tokens) stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
prev_pos = 0
for cur_pos in range(min_prompt_len, total_len): for cur_pos in range(min_prompt_len, total_len):
if is_vision: if is_vision:
position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long) position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
text_only_inference = model_input.vision is None text_only_inference = all(inp.vision is None for inp in model_inputs)
logits = self.model.forward( logits = self.model.forward(
position_ids, position_ids,
tokens, tokens,
@ -271,155 +277,69 @@ class Llama3:
ignore_index=pad_id, ignore_index=pad_id,
) )
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens)) eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
yield TokenResult( results = []
token=next_token[0].item(), for idx, t in enumerate(next_token):
text=self.tokenizer.decode(next_token.tolist()), results.append(
logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None), GenerationResult(
) token=t.item(),
text=self.tokenizer.decode([t.item()]),
source="output",
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
batch_idx=idx,
finished=eos_reached[idx],
ignore_token=cur_pos < len(prompt_tokens[idx]),
)
)
yield results
prev_pos = cur_pos prev_pos = cur_pos
if all(eos_reached): if all(eos_reached):
break break
def text_completion( def completion(
self, self,
content: RawContent, contents: List[RawContent],
temperature: float = 0.6, temperature: float = 0.6,
top_p: float = 0.9, top_p: float = 0.9,
max_gen_len: Optional[int] = None, max_gen_len: Optional[int] = None,
logprobs: bool = False, logprobs: bool = False,
echo: bool = False, echo: bool = False,
) -> CompletionPrediction: ) -> Generator[List[GenerationResult], None, None]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len: model_inputs = [self.formatter.encode_content(c) for c in contents]
max_gen_len = self.model.params.max_seq_len - 1
model_input = self.formatter.encode_content(content)
tokens = []
token_logprobs = []
decoded_tokens = []
for result in self.generate( for result in self.generate(
model_input=model_input, model_inputs=model_inputs,
max_gen_len=max_gen_len,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs, logprobs=logprobs,
echo=echo, echo=echo,
): ):
tokens.append(result.token) yield result
if logprobs: if all(r.finished for r in result):
decoded_tokens.append(result.text) break
token_logprobs.append(result.logprobs)
generation = self.tokenizer.decode(tokens)
if logprobs:
return CompletionPrediction(
generation=generation,
logprobs=token_logprobs,
decoded_tokens=decoded_tokens,
)
return CompletionPrediction(generation=generation)
def chat_completion( def chat_completion(
self, self,
messages: List[RawMessage], messages_batch: List[List[RawMessage]],
temperature: float = 0.6, temperature: float = 0.6,
top_p: float = 0.9, top_p: float = 0.9,
max_gen_len: Optional[int] = None, max_gen_len: Optional[int] = None,
logprobs: bool = False, logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json, tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
echo: bool = False, echo: bool = False,
) -> ChatPrediction: ) -> Generator[List[GenerationResult], None, None]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len: model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
max_gen_len = self.model.params.max_seq_len - 1
tokens = []
token_logprobs = []
decoded_tokens = []
stop_reason = None
for result in self.generate( for result in self.generate(
model_input=self.formatter.encode_dialog_prompt(messages, tool_prompt_format), model_inputs=model_inputs,
max_gen_len=max_gen_len,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs, logprobs=logprobs,
echo=echo, echo=echo,
): ):
tokens.append(result.token) yield result
if result.text == "<|eot_id|>": if all(r.finished for r in result):
stop_reason = StopReason.end_of_turn break
elif result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
if logprobs:
decoded_tokens.append(result.text)
token_logprobs.append(result.logprobs)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
message = self.formatter.decode_assistant_message(tokens, stop_reason)
if logprobs:
return ChatPrediction(
generation=message,
logprobs=token_logprobs,
decoded_tokens=decoded_tokens,
)
return ChatPrediction(generation=message)
def chat_completion_raw(
self,
messages: List[RawMessage],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
) -> List[int]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
max_gen_len = self.model.params.max_seq_len - 1
output_tokens = []
model_input = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
input_tokens = model_input.tokens
for result in self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=False,
):
output_tokens.append(result.token)
return input_tokens, output_tokens
def text_completion_raw(
self,
content: RawContent,
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
):
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
max_gen_len = self.model.params.max_seq_len - 1
model_input = self.formatter.encode_content(content)
input_tokens = model_input.tokens
output_tokens = []
for result in self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=False,
):
output_tokens.append(result.token)
return input_tokens, output_tokens
def sample_top_p(probs, p): def sample_top_p(probs, p):
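completion() and chat_completion() now stream one List[GenerationResult] per decoding step, with one entry per batch item, instead of returning a single prediction object. A hedged consumption sketch; the module path, checkpoint path, and sizes are assumptions/placeholders:

# Illustrative only; build() needs a real checkpoint directory.
from llama_stack.models.llama.llama3.generation import Llama3  # path assumed

llama = Llama3.build(ckpt_dir="/path/to/ckpt", max_seq_len=512, max_batch_size=2)

outputs = ["", ""]
for step in llama.completion(["Tell me a joke.", "Write a haiku about the sea."]):
    for r in step:
        # Skip prompt/pad positions for shorter batch items and the final stop token.
        if r.source == "output" and not r.ignore_token and not r.finished:
            outputs[r.batch_idx] += r.text
print(outputs)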

View file

@ -4,16 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import math import math
from typing import Optional, Tuple from typing import Optional, Tuple
@ -29,6 +19,10 @@ from torch import nn
from .args import ModelArgs from .args import ModelArgs
# **NOTE**: This code is not runnable without installing `torch` and `fairscale`
# dependencies. These dependencies are not part of the default dependencies
# (requirements.txt) of the `llama-models` package.
class RMSNorm(torch.nn.Module): class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6): def __init__(self, dim: int, eps: float = 1e-6):
@ -111,9 +105,9 @@ class Attention(nn.Module):
def __init__(self, args: ModelArgs): def __init__(self, args: ModelArgs):
super().__init__() super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
model_parallel_size = fs_init.get_model_parallel_world_size() world_size = fs_init.get_model_parallel_world_size()
self.n_local_heads = args.n_heads // model_parallel_size self.n_local_heads = args.n_heads // world_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size self.n_local_kv_heads = self.n_kv_heads // world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads self.head_dim = args.dim // args.n_heads

View file

@ -4,16 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import logging import logging
import math import math
from functools import partial from functools import partial
@ -180,14 +170,14 @@ class ImageAttention(nn.Module):
n_heads, n_heads,
): ):
super().__init__() super().__init__()
model_parallel_size = fs_init.get_model_parallel_world_size() world_size = fs_init.get_model_parallel_world_size()
qkvo_replication = 1 qkvo_replication = 1
if model_parallel_size > 16: if world_size > 16:
qkvo_replication = model_parallel_size // 8 qkvo_replication = world_size // 8
self.n_kv_heads = n_heads self.n_kv_heads = n_heads
self.n_local_heads = n_heads * qkvo_replication // model_parallel_size self.n_local_heads = n_heads * qkvo_replication // world_size
self.n_local_kv_heads = self.n_kv_heads * qkvo_replication // model_parallel_size self.n_local_kv_heads = self.n_kv_heads * qkvo_replication // world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = dim // n_heads self.head_dim = dim // n_heads
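Renaming model_parallel_size to world_size does not change the replication arithmetic: qkvo_replication exists so head counts stay divisible when the model-parallel world size exceeds the head count. A standalone numeric check with illustrative values:

# Illustrative numbers only: 16 image-attention heads sharded over 32 MP ranks.
n_heads, world_size = 16, 32
qkvo_replication = world_size // 8 if world_size > 16 else 1  # -> 4
n_local_heads = n_heads * qkvo_replication // world_size      # -> 2 heads per rank
assert (qkvo_replication, n_local_heads) == (4, 2)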
@ -536,16 +526,16 @@ class Attention(nn.Module):
cache_v (torch.Tensor): Cached values for attention. cache_v (torch.Tensor): Cached values for attention.
""" """
super().__init__() super().__init__()
model_parallel_size = fs_init.get_model_parallel_world_size() world_size = fs_init.get_model_parallel_world_size()
replication_factor = 1 replication_factor = 1
if model_parallel_size > 8: if world_size > 8:
replication_factor = model_parallel_size // MP_SCALE replication_factor = world_size // MP_SCALE
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
self.n_kv_heads *= replication_factor self.n_kv_heads *= replication_factor
self.n_local_heads = args.n_heads // model_parallel_size self.n_local_heads = args.n_heads // world_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size self.n_local_kv_heads = self.n_kv_heads // world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads self.head_dim = args.dim // args.n_heads
self.max_seq_len = args.max_seq_len self.max_seq_len = args.max_seq_len
@ -587,13 +577,11 @@ class Attention(nn.Module):
self.n_local_kv_heads, self.n_local_kv_heads,
self.head_dim, self.head_dim,
) )
device = next(self.parameters()).device
self.register_buffer( self.register_buffer(
"key_cache", "key_cache",
torch.zeros( torch.zeros(
cache_shape, cache_shape,
dtype=dtype, dtype=dtype,
device=device,
), ),
persistent=False, persistent=False,
) )
@ -602,7 +590,6 @@ class Attention(nn.Module):
torch.zeros( torch.zeros(
cache_shape, cache_shape,
dtype=dtype, dtype=dtype,
device=device,
), ),
persistent=False, persistent=False,
) )
@ -614,6 +601,9 @@ class Attention(nn.Module):
freqs_cis: torch.Tensor, freqs_cis: torch.Tensor,
position_ids: torch.LongTensor, position_ids: torch.LongTensor,
): ):
self.key_cache = self.key_cache.to(x.device)
self.value_cache = self.value_cache.to(x.device)
xq, xk, xv = [F.linear(x, w) for w in [self.wq.weight, self.wk.weight, self.wv.weight]] xq, xk, xv = [F.linear(x, w) for w in [self.wq.weight, self.wk.weight, self.wv.weight]]
bs, slen, _ = xq.shape bs, slen, _ = xq.shape
@ -832,10 +822,10 @@ class CrossAttention(torch.nn.Module):
norm_eps: float, norm_eps: float,
): ):
super().__init__() super().__init__()
self.model_parallel_size = fs_init.get_model_parallel_world_size() self.world_size = fs_init.get_model_parallel_world_size()
replication_factor = 1 replication_factor = 1
if self.model_parallel_size > 8: if self.world_size > 8:
replication_factor = self.model_parallel_size // MP_SCALE replication_factor = self.world_size // MP_SCALE
n_kv_heads *= replication_factor n_kv_heads *= replication_factor
assert n_heads % n_kv_heads == 0 assert n_heads % n_kv_heads == 0
@ -889,10 +879,10 @@ class CrossAttention(torch.nn.Module):
# trunk LLM (i.e., group query attention) -- @dubeya # trunk LLM (i.e., group query attention) -- @dubeya
# local heads # local heads
assert self.n_heads % self.n_kv_heads == 0 assert self.n_heads % self.n_kv_heads == 0
assert self.n_heads % self.model_parallel_size == 0 assert self.n_heads % self.world_size == 0
assert self.n_kv_heads % self.model_parallel_size == 0 assert self.n_kv_heads % self.world_size == 0
self.n_local_heads = self.n_heads // self.model_parallel_size self.n_local_heads = self.n_heads // self.world_size
self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size self.n_local_kv_heads = self.n_kv_heads // self.world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads self.n_rep = self.n_local_heads // self.n_local_kv_heads
def _compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor: def _compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor:
@ -1041,7 +1031,7 @@ class CrossAttentionTransformerVision(torch.nn.Module):
self.image_res = args.vision_chunk_size self.image_res = args.vision_chunk_size
self.max_num_chunks = args.vision_max_num_chunks self.max_num_chunks = args.vision_max_num_chunks
if return_intermediate is not None: if return_intermediate is not None:
return_intermediate = [int(level) for level in return_intermediate.split(",")] return_intermediate = [int(layer) for layer in return_intermediate.split(",")]
self.vision_input_dim = (len(return_intermediate) + 1) * self.vision_input_dim self.vision_input_dim = (len(return_intermediate) + 1) * self.vision_input_dim
self.patch_size = 14 self.patch_size = 14
self.vision_encoder = VisionEncoder( self.vision_encoder = VisionEncoder(
@ -1076,15 +1066,15 @@ class CrossAttentionTransformerText(torch.nn.Module):
def __init__(self, args: ModelArgs) -> None: def __init__(self, args: ModelArgs) -> None:
super().__init__() super().__init__()
self.model_parallel_size = fs_init.get_model_parallel_world_size() self.world_size = fs_init.get_model_parallel_world_size()
assert args.vocab_size > 0 assert args.vocab_size > 0
self.vocab_size = args.vocab_size self.vocab_size = args.vocab_size
self.n_layers = args.n_layers self.n_layers = args.n_layers
self.dim = args.dim self.dim = args.dim
self.head_dim = args.dim // args.n_heads self.head_dim = args.dim // args.n_heads
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size self.n_local_kv_heads = self.n_kv_heads // self.world_size
assert self.vocab_size % self.model_parallel_size == 0 assert self.vocab_size % self.world_size == 0
self.tok_embeddings = VocabParallelEmbedding(args.vocab_size, args.dim, init_method=lambda x: x) self.tok_embeddings = VocabParallelEmbedding(args.vocab_size, args.dim, init_method=lambda x: x)
self.pos_embeddings = None self.pos_embeddings = None
# final norm layer (not necessary for post-norm) # final norm layer (not necessary for post-norm)
@ -1184,6 +1174,8 @@ class CrossAttentionTransformerText(torch.nn.Module):
text_only_inference: bool = False, text_only_inference: bool = False,
): ):
assert self.cache_is_setup, "Please set up cache before calling forward" assert self.cache_is_setup, "Please set up cache before calling forward"
self.mask_cache = self.mask_cache.to(h.device)
self.freqs_cis = self.freqs_cis.to(h.device)
mask = self.mask_cache.index_select(2, position_ids) mask = self.mask_cache.index_select(2, position_ids)
freqs_cis = self.freqs_cis.index_select(0, position_ids) freqs_cis = self.freqs_cis.index_select(0, position_ids)
@ -1212,9 +1204,8 @@ class CrossAttentionTransformerText(torch.nn.Module):
output = gather_from_tensor_model_parallel_region(output) output = gather_from_tensor_model_parallel_region(output)
return output.float() return output.float()
def setup_cache(self, max_batch_size: int, dtype=torch.bfloat16): def setup_cache(self, max_batch_size: int, device: torch.device, dtype=torch.bfloat16):
# Set up the text kv caches # Set up the text kv caches
device = next(self.parameters()).device
ones = torch.ones( ones = torch.ones(
(self.max_seq_len, self.max_seq_len), (self.max_seq_len, self.max_seq_len),
dtype=torch.bool, dtype=torch.bool,
@ -1265,7 +1256,7 @@ class CrossAttentionTransformerText(torch.nn.Module):
return ( return (
cross_attention_masks.to(device=text_device, dtype=text_dtype), cross_attention_masks.to(device=text_device, dtype=text_dtype),
full_text_row_masked_out_mask, full_text_row_masked_out_mask.to(device=text_device),
) )
@ -1284,14 +1275,15 @@ class CrossAttentionTransformer(torch.nn.Module):
max_num_chunks=args.vision_max_num_chunks, max_num_chunks=args.vision_max_num_chunks,
) )
def setup_cache(self, max_batch_size: int, dtype: torch.dtype): def setup_cache(self, max_batch_size: int, device: torch.device, dtype: torch.dtype):
self.text_model.setup_cache(max_batch_size, dtype) self.text_model.setup_cache(max_batch_size, device, dtype)
def compute_vision_tokens_masks( def compute_vision_tokens_masks(
self, self,
batch_images: List[List[PIL_Image.Image]], batch_images: List[List[PIL_Image.Image]],
batch_masks: List[List[List[int]]], batch_masks: List[List[List[int]]],
total_len: int, total_len: int,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
skip_vision_encoder = False skip_vision_encoder = False
@ -1318,6 +1310,7 @@ class CrossAttentionTransformer(torch.nn.Module):
image_res=self.params.vision_chunk_size, image_res=self.params.vision_chunk_size,
max_num_images=max_num_images, max_num_images=max_num_images,
) )
stacked_images = stacked_images.to(device=device)
if skip_vision_encoder: if skip_vision_encoder:
vision_tokens = torch.zeros( vision_tokens = torch.zeros(
@ -1330,7 +1323,7 @@ class CrossAttentionTransformer(torch.nn.Module):
), ),
) )
else: else:
vision_tokens = self.vision_model(stacked_images, aspect_ratios) vision_tokens = self.vision_model(stacked_images, aspect_ratios).to(device=device)
bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape) bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape)
xattn_caches = torch.stack( xattn_caches = torch.stack(

View file

@ -15,7 +15,7 @@ import textwrap
from datetime import datetime from datetime import datetime
from typing import Any, List, Optional from typing import Any, List, Optional
from llama_stack.models.llama.datatypes import ( from llama_stack.apis.inference import (
BuiltinTool, BuiltinTool,
ToolDefinition, ToolDefinition,
ToolParamDefinition, ToolParamDefinition,

View file

@@ -4,9 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
 # type: ignore
 import os
 from typing import Any, Dict, List, Optional, cast

@@ -18,22 +15,15 @@ from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
 from torch import Tensor, nn
 from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

-from llama_stack.apis.inference import QuantizationType
-from llama_stack.log import get_logger
-from llama_stack.models.llama.sku_list import resolve_model
-from ...config import MetaReferenceQuantizedInferenceConfig
-from ...datatypes import CheckpointQuantizationFormat
+from ...datatypes import QuantizationMode
 from ...quantize_impls import (
     Fp8ScaledWeights,
     ffn_swiglu,
     load_fp8,
     quantize_fp8,
 )
-from ..args import ModelArgs
 from ..model import Transformer, TransformerBlock
+from ..multimodal.model import CrossAttentionTransformer
-
-log = get_logger(__name__, category="quantization")
def swiglu_wrapper( def swiglu_wrapper(
@ -44,30 +34,34 @@ def swiglu_wrapper(
return reduce_from_model_parallel_region(out) return reduce_from_model_parallel_region(out)
def convert_to_quantized_model(
model: Transformer | CrossAttentionTransformer,
checkpoint_dir: str,
quantization_mode: Optional[str] = None,
fp8_activation_scale_ub: Optional[float] = 1200.0,
device: Optional[torch.device] = None,
) -> Transformer | CrossAttentionTransformer:
if quantization_mode == QuantizationMode.fp8_mixed:
return convert_to_fp8_quantized_model(model, checkpoint_dir, fp8_activation_scale_ub, device)
elif quantization_mode == QuantizationMode.int4_mixed:
return convert_to_int4_quantized_model(model, checkpoint_dir, device)
else:
raise ValueError(f"Unsupported quantization mode: {quantization_mode}")
 def convert_to_fp8_quantized_model(
     model: Transformer,
-    config: MetaReferenceQuantizedInferenceConfig,
     checkpoint_dir: str,
     fp8_activation_scale_ub: Optional[float] = 1200.0,
+    device: Optional[torch.device] = None,
 ) -> Transformer:
-    if config.quantization.type == QuantizationType.bf16.value:
-        return model
-    elif config.quantization.type != QuantizationType.fp8.value:
-        raise ValueError("Only FP8 quantization is supported")
-
-    assert config.model is not None, "Model must be specified for quantized inference"
-    llama_model = resolve_model(config.model)
-    assert llama_model is not None, f"Model {config.model} not found"
-
     # Move weights to GPU with quantization
-    if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
-        log.info("Loading fp8 scales...")
-        fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
-        assert os.path.isfile(fp8_scales_path), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
+    fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
+    if os.path.isfile(fp8_scales_path):
+        print("Loading fp8 scales...")
         fp8_scales = torch.load(fp8_scales_path, weights_only=True)

-        for block in model.layers:
+        for _, block in model.named_modules():
             if isinstance(block, TransformerBlock):
                 if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
                     continue

@@ -81,8 +75,8 @@ def convert_to_fp8_quantized_model(
                     fp8_activation_scale_ub,
                 )
     else:
-        log.info("Quantizing fp8 weights from bf16...")
-        for block in model.layers:
+        print("Quantizing fp8 weights from bf16...")
+        for _, block in model.named_modules():
             if isinstance(block, TransformerBlock):
                 if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
                     continue

@@ -92,12 +86,12 @@ def convert_to_fp8_quantized_model(
                     param.weight = quantize_fp8(
                         param.weight,
                         fp8_activation_scale_ub,
-                        output_device=torch.device("cuda"),
+                        output_device=device,
                     )

     for _, parameter in model.named_parameters():
         if not isinstance(parameter, Fp8ScaledWeights):
-            parameter.data = parameter.to(device="cuda")
+            parameter.data = parameter.to(device=device)
     return model

@@ -290,11 +284,12 @@ def _prepare_model_int4_weight_int8_dynamic_activation(

 def convert_to_int4_quantized_model(
-    model: Transformer,
-    model_args: ModelArgs,
-) -> Transformer:
+    model: Transformer | CrossAttentionTransformer,
+    checkpoint_dir: str,
+    device: Optional[torch.device] = None,
+) -> Transformer | CrossAttentionTransformer:
     """Convert the model to int4 quantized model."""
+    model_args = model.params
     assert model_args.quantization_args is not None, "Quantization args must be specified."
     quantization_args = model_args.quantization_args
     if quantization_args.scheme is None:

@@ -318,5 +313,4 @@ def convert_to_int4_quantized_model(
         lora_scale = model_args.lora_args.scale

     _prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale)
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-    return cast(Transformer, model.to(device))
+    return cast(Transformer | CrossAttentionTransformer, model.to(device=device))

View file

@ -4,16 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import os import os
from logging import getLogger from logging import getLogger
from pathlib import Path from pathlib import Path

View file

@ -3,10 +3,3 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

View file

@ -4,12 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import json import json
import textwrap import textwrap

View file

@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import textwrap import textwrap
from pathlib import Path from pathlib import Path

View file

@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional

View file

@@ -12,6 +12,7 @@ from typing import Dict, List, Optional, Tuple
 import torch
 from PIL import Image as PIL_Image

+# TODO: either fork these or move them to the common package
 from ..datatypes import (
     BuiltinTool,
     RawContent,

@@ -26,10 +27,7 @@ from ..datatypes import (
 from ..llama3.tool_utils import ToolUtils
 from .args import VisionArgs
 from .datatypes import LLMInput
-from .preprocess import (
-    ResizeNormalizeImageTransform,
-    VariableSizeImageTransform,
-)
+from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform
 from .tokenizer import Tokenizer

@@ -50,7 +48,7 @@ class TransformedImage:
     aspect_ratio: Tuple[int, int]

-def convert_rgba_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
+def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
     if image.mode == "RGBA":
         image.load()  # for png.split()
         new_img = PIL_Image.new("RGB", image.size, bg)

@@ -167,7 +165,7 @@ class ChatFormat:
                 bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
                 image = PIL_Image.open(bytes_io)
-                image = convert_rgba_to_rgb(image)
+                image = convert_image_to_rgb(image)
                 image_tiles, ar = self.dynamic_image_transform(image, max_num_chunks=self.max_num_chunks)

                 if image_tiles.shape[0] > 1:

@@ -212,12 +210,9 @@ class ChatFormat:
             content = ToolUtils.encode_tool_call(t, tool_prompt_format)
             _process_content(content)

-        # Tool calls and Tool Response messages should be eom
         eom = False
         if message.role == "assistant":
-            eom = message.stop_reason == StopReason.end_of_message or message.tool_calls
-        elif message.role == "tool":
-            eom = True
+            eom = message.stop_reason == StopReason.end_of_message

         tokens.append(self.tokenizer.special_tokens["<|eom|>" if eom else "<|eot|>"])
         return tokens, images

@@ -252,11 +247,6 @@ class ChatFormat:
         if content.startswith(header_str):
             content = content[len(header_str) :]

-        ipython = content.startswith("<|python_start|>")
-        if ipython:
-            content = content[len("<|python_start|>") :]
-            content = content.replace("<|python_end|>", "")
-
         if content.endswith("<|eot|>"):
             content = content[: -len("<|eot|>")]
             stop_reason = StopReason.end_of_turn

@@ -287,11 +277,6 @@ class ChatFormat:
             }
             if tool_name in BuiltinTool.__members__:
                 tool_name = BuiltinTool[tool_name]
-        elif ipython:
-            tool_name = BuiltinTool.code_interpreter
-            tool_arguments = {
-                "code": content,
-            }

         tool_calls = []
         if tool_name is not None and tool_arguments is not None:
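
To make the end-of-message change above concrete, here is a minimal sketch of the new terminator choice, assuming message objects expose role and stop_reason as in the diff. Note the practical consequence: an assistant turn that merely issues tool calls now ends with <|eot|> unless its stop reason is explicitly end_of_message.

    # Minimal sketch of the new terminator decision (not the full encode path).
    eom = False
    if message.role == "assistant":
        # Tool calls and tool-role messages no longer force <|eom|>.
        eom = message.stop_reason == StopReason.end_of_message
    terminator = "<|eom|>" if eom else "<|eot|>"
    tokens.append(self.tokenizer.special_tokens[terminator])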

View file

@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Union from typing import List, Optional, Union

View file

@ -4,32 +4,43 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
 import codecs
 import io
 import json
 import os
 import sys
 import time
-from enum import Enum
 from pathlib import Path
 from typing import Callable, Generator, List, Optional

 import torch
 import torch.nn.functional as F
 from fairscale.nn.model_parallel.initialize import (
-    get_model_parallel_rank,
     initialize_model_parallel,
     model_parallel_is_initialized,
 )
 from termcolor import cprint

-from ..common import TokenResult
+from ..checkpoint import maybe_reshard_state_dict
+from ..datatypes import GenerationResult, QuantizationMode
 from .args import ModelArgs
-from .chat_format import (
-    ChatFormat,
-    RawContent,
-    RawMessage,
-)
+from .chat_format import ChatFormat, RawContent, RawMessage
 from .datatypes import LLMInput, MaskedEmbedding, TransformerInput
 from .model import Transformer
 from .tokenizer import Tokenizer
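
The streaming API below now yields GenerationResult objects instead of single TokenResult values. The class itself lives in ..datatypes and is not shown in this diff; the sketch below only records the fields that the constructor calls further down visibly populate, so treat it as an inferred shape, not the real definition.

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class GenerationResultShape:  # hypothetical stand-in for ..datatypes.GenerationResult
        token: int                        # token id produced at this step
        text: str                         # decoded text for that single token
        source: str                       # "input" when echoing the prompt, "output" for generated tokens
        logprobs: Optional[List[float]]   # per-token logprobs when requested
        batch_idx: int                    # which prompt in the batch this result belongs to
        finished: bool                    # True once this batch element hit a stop token
        ignore_token: bool                # padding / prompt tokens the caller should skip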
@ -37,12 +48,6 @@ from .tokenizer import Tokenizer
torch.serialization.add_safe_globals([io.BytesIO, codecs.encode]) torch.serialization.add_safe_globals([io.BytesIO, codecs.encode])
class QuantizationMode(str, Enum):
none = "none"
fp8_mixed = "fp8_mixed"
int4_mixed = "int4_mixed"
class Llama4: class Llama4:
@staticmethod @staticmethod
def build( def build(
@@ -50,7 +55,7 @@ class Llama4:
         max_seq_len: int,
         max_batch_size: int,
         world_size: Optional[int] = None,
-        quantization_mode: Optional[str] = None,
+        quantization_mode: Optional[QuantizationMode] = None,
         seed: int = 1,
     ):
         if not torch.distributed.is_initialized():

@@ -71,11 +76,9 @@ class Llama4:
         start_time = time.time()

-        checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
-        assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
-        assert world_size == len(checkpoints), (
-            f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
-        )
+        ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
+        assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
+        print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")

         with open(Path(ckpt_dir) / "params.json", "r") as f:
             params = json.loads(f.read())

@@ -92,10 +95,11 @@ class Llama4:
         assert model_args.vocab_size == tokenizer.n_words, f"{model_args.vocab_size=} vs. {tokenizer.n_words=} mismatch"
         print("Model args:\n", model_args.model_dump_json(indent=2))

-        ckpt_path = checkpoints[get_model_parallel_rank()]
-        print(f"Loading checkpoint from {ckpt_dir}...")
-        with open(ckpt_path, "rb") as f:
-            checkpoint = torch.load(f, map_location="cpu", weights_only=True)
+        state_dict = maybe_reshard_state_dict(
+            ckpt_paths,
+            n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
+            moe_num_experts=model_args.moe_args.num_experts,
+        )
         print("Loaded checkpoint")
         if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
             from .quantization.loader import convert_to_quantized_model

@@ -103,9 +107,9 @@ class Llama4:
             torch.set_default_tensor_type(torch.BFloat16Tensor)
             model = Transformer(model_args)
             print("Loading state dict...")
-            model.load_state_dict(checkpoint, strict=False)
+            model.load_state_dict(state_dict, strict=False)
             print("Done...")
-            model = convert_to_quantized_model(model, ckpt_dir)
+            model = convert_to_quantized_model(model, ckpt_dir, quantization_mode)
         else:
             if torch.cuda.is_bf16_supported():
                 torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)

@@ -114,7 +118,7 @@ class Llama4:
             model = Transformer(model_args)
             print("Loading state dict...")
-            model.load_state_dict(checkpoint, strict=False)
+            model.load_state_dict(state_dict, strict=False)
             print("Done...")

         print(f"Loaded in {time.time() - start_time:.2f} seconds")
@@ -129,7 +133,7 @@ class Llama4:
     @torch.inference_mode()
     def generate(
         self,
-        llm_input: LLMInput,
+        llm_inputs: List[LLMInput],
         temperature: float = 0.6,
         top_p: float = 0.9,
         max_gen_len: Optional[int] = None,

@@ -137,22 +141,20 @@ class Llama4:
         echo: bool = False,
         print_model_input: bool = False,
         logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
-    ) -> Generator:
+    ) -> Generator[List[GenerationResult], None, None]:
         if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.args.max_seq_len:
             max_gen_len = self.model.args.max_seq_len - 1
         params = self.model.args

         print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
-        if print_model_input and get_model_parallel_rank() == 0:
-            tokens_to_print = list(llm_input.tokens)
-            cprint(
-                "Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
-                "red",
-            )
-        prompt_tokens = [llm_input.tokens]
+        if print_model_input:
+            cprint("Input to model:\n", "yellow")
+            for inp in llm_inputs:
+                cprint(self.tokenizer.decode(inp.tokens.tolist()), "grey")
+        prompt_tokens = [inp.tokens for inp in llm_inputs]

-        bsz = 1
+        bsz = len(llm_inputs)
         assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

         min_prompt_len = min(len(t) for t in prompt_tokens)

@@ -175,24 +177,33 @@ class Llama4:
         input_text_mask = tokens != pad_id

         if echo:
-            for i, t in enumerate(llm_input.tokens):
-                yield TokenResult(
-                    token=t,
-                    text=self.tokenizer.decode([t]),
-                    logprobs=(token_logprobs[0, i : i + 1].tolist() if logprobs else None),
-                )
+            for i in range(max_prompt_len):
+                results = []
+                for j, t in enumerate(tokens[:, i]):
+                    results.append(
+                        GenerationResult(
+                            token=t.item(),
+                            text=self.tokenizer.decode([t.item()]),
+                            source="input",
+                            logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
+                            batch_idx=j,
+                            finished=False,
+                            ignore_token=t.item() == pad_id,
+                        )
+                    )
+                yield results

         stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")

         prev_pos = 0
         for cur_pos in range(min_prompt_len, total_len):
             image_embedding = None
-            if prev_pos == 0 and llm_input.images is not None and len(llm_input.images) > 0:
+            if prev_pos == 0 and any(inp.images is not None and len(inp.images) > 0 for inp in llm_inputs):
                 image_mask = tokens[:, prev_pos:cur_pos] == self.tokenizer.special_tokens["<|patch|>"]
                 image_mask = image_mask.unsqueeze(-1)
                 h = self.model.tok_embeddings(tokens[:, prev_pos:cur_pos])

-                image_batch = [llm_input.images]
+                image_batch = [inp.images if inp.images is not None else [] for inp in llm_inputs]
                 image_embedding = MaskedEmbedding(
                     embedding=self.model.vision_embeddings(image_batch, image_mask, h),
                     mask=image_mask,

@@ -228,11 +239,21 @@ class Llama4:
                 ignore_index=pad_id,
             )
         eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
-        yield TokenResult(
-            token=next_token[0].item(),
-            text=self.tokenizer.decode(next_token.tolist()),
-            logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None),
-        )
+
+        results = []
+        for idx, t in enumerate(next_token):
+            results.append(
+                GenerationResult(
+                    token=t.item(),
+                    text=self.tokenizer.decode([t.item()]),
+                    source="output",
+                    logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
+                    batch_idx=idx,
+                    finished=eos_reached[idx],
+                    ignore_token=cur_pos < len(prompt_tokens[idx]),
+                )
+            )
+        yield results

         prev_pos = cur_pos
         if all(eos_reached):
@@ -240,68 +261,47 @@ class Llama4:

     def completion(
         self,
-        content: RawContent,
+        contents: List[RawContent],
         temperature: float = 0.6,
         top_p: float = 0.9,
         max_gen_len: Optional[int] = None,
         logprobs: bool = False,
         echo: bool = False,
-    ) -> Generator:
-        llm_input = self.formatter.encode_content(content)
+    ) -> Generator[List[GenerationResult], None, None]:
+        llm_inputs = [self.formatter.encode_content(c) for c in contents]
         for result in self.generate(
-            llm_input=llm_input,
+            llm_inputs=llm_inputs,
             temperature=temperature,
             top_p=top_p,
             max_gen_len=max_gen_len,
             logprobs=logprobs,
             echo=echo,
         ):
-            if result.token in self.tokenizer.stop_tokens:
-                break
             yield result
+            if all(r.finished for r in result):
+                break

     def chat_completion(
         self,
-        messages: List[RawMessage],
+        messages_batch: List[List[RawMessage]],
         temperature: float = 0.6,
         top_p: float = 0.9,
         max_gen_len: Optional[int] = None,
         logprobs: bool = False,
         echo: bool = False,
-    ) -> Generator:
-        llm_input = self.formatter.encode_dialog_prompt(messages)
+    ) -> Generator[List[GenerationResult], None, None]:
+        llm_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
         for result in self.generate(
-            llm_input=llm_input,
+            llm_inputs=llm_inputs,
             temperature=temperature,
             top_p=top_p,
             max_gen_len=max_gen_len,
             logprobs=logprobs,
             echo=echo,
         ):
-            if result.token in self.tokenizer.stop_tokens:
-                break
             yield result
+            if all(r.finished for r in result):
+                break

-    def chat_completion_raw(
-        self,
-        messages: List[RawMessage],
-        temperature: float = 0.6,
-        top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
-        logprobs: bool = False,
-    ):
-        llm_input = self.formatter.encode_dialog_prompt(messages)
-        output_tokens = []
-        for result in self.generate(
-            llm_input=llm_input,
-            temperature=temperature,
-            top_p=top_p,
-            max_gen_len=max_gen_len,
-            logprobs=logprobs,
-        ):
-            output_tokens.append(result.token)
-
-        return llm_input.tokens, output_tokens

 def sample_top_p(probs, p):
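
A hedged end-to-end sketch of the new batched streaming interface, assuming the RawMessage constructor used elsewhere in this tree and the llama instance from the build sketch above; the prompts and accumulation logic are illustrative only.

    # Two independent dialogs decoded in one batch; each yielded item is a List[GenerationResult].
    dialogs = [
        [RawMessage(role="user", content="Write a haiku about GPUs.")],
        [RawMessage(role="user", content="What is 2 + 2?")],
    ]
    outputs = ["", ""]
    for results in llama.chat_completion(dialogs, temperature=0.6, top_p=0.9):
        for r in results:
            if r.source == "output" and not r.ignore_token:
                outputs[r.batch_idx] += r.text   # accumulate per-prompt text by batch index
    print(outputs)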

View file

@ -4,16 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import math import math
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
@ -184,7 +174,6 @@ class Attention(nn.Module):
self.head_dim, self.head_dim,
) )
).cuda() ).cuda()
self.qk_norm = None self.qk_norm = None
if self.use_qk_norm: if self.use_qk_norm:
self.qk_norm = L2Norm(args.norm_eps) self.qk_norm = L2Norm(args.norm_eps)

View file

@ -100,31 +100,21 @@ class Experts(nn.Module):
class MoE(torch.nn.Module): class MoE(torch.nn.Module):
""" """
This EC implementation is modified from the original EC module.
We refactored the token permutation and unpermutation logic and added support to tp and dp2ep sharding.
This module supports 3 sharding methods of the experts:
- tp: each TP rank has n_experts experts. Experts are sharded following the conventional row/column-parallel TP sharding.
- tp2ep: each TP rank has n_experts/tp experts. Experts are not sharded.
- dp2ep: each EP rank has n_experts/ep experts. Experts are sharded following the row/column-parallel TP sharding.
Tensors used in this module are annotated with the suffixes that indicate the shape of the tensor. Tensors used in this module are annotated with the suffixes that indicate the shape of the tensor.
Several commonly used annotations include: Several commonly used annotations include:
- a: bsz*slen - a: bsz*slen
- E: number of experts - E: number of experts
- e: number of local experts per ep (n_experts/ep) - e: number of local experts per ep (n_experts/ep)
- et: number of local experts per tp (n_experts/tp)
- D: hidden dimension - D: hidden dimension
- d: D/tp - d: D/tp
- F: model dimension - F: model dimension
- f: F/tp (used in column/row-parallel linear)
- G: number of tokens per expert (a * capacity_factor / E) - G: number of tokens per expert (a * capacity_factor / E)
- g: number of tokens per expert per TP rank (i.e., G/TP) - g: number of tokens per expert per TP rank (i.e., G/TP)
- GG: G*EP (number of tokens per expert received via inter-EP a2a when ag_along_first_dim=False)
- gg: g*EP (number of tokens per expert received via inter-EP a2a when ag_along_first_dim=True)
     Examples:
     x_aD [a, D]
     routed_in_etG_D [et*G, D]
-    x_eGGD: [e, GG, D]
+    x_eGD: [e, G, D]
     """

     def __init__(

@@ -207,13 +197,13 @@ class MoE(torch.nn.Module):
         routed_in_EG_D = routed_in_EG_D * router_scores.reshape(-1, 1)

         out_aD = self.shared_expert(x_aD)
-        routed_out_egg_D = self.experts(routed_in_EG_D.detach())
+        routed_out_eg_D = self.experts(routed_in_EG_D.detach())

         router_indices_EG_D = router_indices.reshape(-1, 1).expand(-1, D)
         out_aD.scatter_add_(
             dim=0,
             index=router_indices_EG_D,
-            src=routed_out_egg_D.view(-1, D),
+            src=routed_out_eg_D.view(-1, D),
         )
         out_aD = reduce_from_model_parallel_region(out_aD)
         return out_aD.view(-1, slen, D)
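
As a small worked example of the shape annotations in the docstring above (the concrete numbers are invented for illustration):

    # Suppose bsz=2, slen=1024, E=16 experts, D=5120, capacity_factor=1.0.
    a = 2 * 1024              # a = bsz * slen = 2048 flattened tokens
    E, D = 16, 5120
    G = a * 1 // E            # G = tokens per expert = 128
    # x_aD has shape [a, D] = [2048, 5120]
    # routed_in_EG_D has shape [E * G, D] = [2048, 5120] after token permutation
    # routed_out_eg_D (per the rename above) has shape [e, G, D] per rank,
    # and is scattered back into out_aD of shape [a, D] before the final reduce.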

View file

@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import textwrap import textwrap
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path

View file

@@ -6,20 +6,29 @@
 import logging
 import os
-from typing import Optional
+from typing import Callable, Optional

 import torch
 from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
-from torch import Tensor
+from torch import Tensor, nn
 from torch.nn import functional as F

-from ..generation import QuantizationMode
+from ...datatypes import QuantizationMode
 from ..model import Transformer, TransformerBlock
 from ..moe import MoE

 log = logging.getLogger(__name__)


+def swiglu_wrapper_no_reduce(
+    self,
+    x: Tensor,
+):
+    from ...quantize_impls import ffn_swiglu
+
+    return ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
+
+
 def experts_batched_swiglu_wrapper(
     self,
     x: Tensor,  # (e, g, D)
@ -51,24 +60,30 @@ def convert_to_quantized_model(
rank = get_model_parallel_rank() rank = get_model_parallel_rank()
def should_quantize_block(block: nn.Module) -> bool:
if not isinstance(block, TransformerBlock):
return False
is_moe = isinstance(block.feed_forward, MoE)
if quantization_mode == QuantizationMode.fp8_mixed:
# skip quantization on first and last layers
return is_moe and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1))
return is_moe
use_rich_progress = use_rich_progress and rank == 0 use_rich_progress = use_rich_progress and rank == 0
progress, log_status, update_status = logging_callbacks(use_rich_progress, rank, model) progress, log_status, update_status = logging_callbacks(use_rich_progress, rank, model, should_quantize_block)
if quantization_mode == QuantizationMode.int4_mixed: if quantization_mode == QuantizationMode.int4_mixed:
int4_scales_path = os.path.join(checkpoint_dir, f"int4_scales_{rank}.pt") int4_scales_path = os.path.join(checkpoint_dir, f"int4_scales_{rank}.pt")
int4_zero_points_path = os.path.join(checkpoint_dir, f"int4_zero_points_{rank}.pt")
if os.path.isfile(int4_scales_path): if os.path.isfile(int4_scales_path):
log_status(f"Rank {rank}: Loading int4 scales") log_status(f"Rank {rank}: Loading int4 scales")
int4_scales = torch.load(int4_scales_path, weights_only=True) int4_scales = torch.load(int4_scales_path, weights_only=True)
int4_zero_points = torch.load(int4_zero_points_path, weights_only=True)
def apply_quantization(key, weight): def apply_quantization(key, weight):
scale = int4_scales[key] scale = int4_scales[key]
zero_point = int4_zero_points[key]
return load_int4( return load_int4(
weight, weight,
scale, scale,
zero_point,
fp8_activation_scale_ub,
output_device=torch.device("cuda"), output_device=torch.device("cuda"),
) )
@ -76,7 +91,8 @@ def convert_to_quantized_model(
log_status(f"Rank {rank}: Quantizing int4 weights from bf16") log_status(f"Rank {rank}: Quantizing int4 weights from bf16")
def apply_quantization(_, weight): def apply_quantization(_, weight):
return quantize_int4(weight, fp8_activation_scale_ub, output_device=torch.device("cuda")) return quantize_int4(weight, output_device=torch.device("cuda"))
else: else:
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt") fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
if os.path.isfile(fp8_scales_path): if os.path.isfile(fp8_scales_path):
@@ -104,33 +120,38 @@ def convert_to_quantized_model(
         progress.start()

     for _, block in model.named_modules():
-        if isinstance(block, TransformerBlock):
-            # Skip quantization on first and last layers
-            if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
-                continue
-
-            # Skip quantization on dense layers
-            if not isinstance(block.feed_forward, MoE):
-                continue
-
-            update_status(f"Rank {rank} - Layer {block.layer_id}")
-
-            # Quantize only routed experts, not shared
-            prefix = f"layers.{block.layer_id}.feed_forward"
-            moe = block.feed_forward
-            moe.experts.batched_swiglu = experts_batched_swiglu_wrapper.__get__(moe.experts)
-
-            for key in ("w1", "w3", "w2"):
-                param = getattr(moe.experts, key)
-                update_status(f"Rank {rank} - Layer {block.layer_id} - MoE {key}")
-                setattr(
-                    moe.experts,
-                    key,
-                    apply_quantization(f"{prefix}.experts.{key}", param.transpose(1, 2).contiguous()),
-                )
+        if not should_quantize_block(block):
+            continue
+
+        update_status(f"Rank {rank} - Layer {block.layer_id}")
+
+        # Quantize only routed experts, not shared
+        prefix = f"layers.{block.layer_id}.feed_forward"
+        moe = block.feed_forward
+        moe.experts.batched_swiglu = experts_batched_swiglu_wrapper.__get__(moe.experts)
+
+        for key in ("w1", "w3", "w2"):
+            param = getattr(moe.experts, key)
+            update_status(f"Rank {rank} - Layer {block.layer_id} - MoE {key}")
+            setattr(
+                moe.experts,
+                key,
+                apply_quantization(
+                    f"{prefix}.experts.{key}",
+                    param.transpose(1, 2).contiguous(),
+                ),
+            )
+
+        if quantization_mode == QuantizationMode.int4_mixed:
+            # Quantize shared experts
+            moe.shared_expert.forward = swiglu_wrapper_no_reduce.__get__(moe.shared_expert)
+            for key in ("w1", "w3", "w2"):
+                param = getattr(moe.shared_expert, key)
+                update_status(f"Rank {rank} - Layer {block.layer_id} - MoE shared expert {key}")
+                param.weight = apply_quantization(f"{prefix}.shared_expert.{key}", param.weight)

         processed_blocks += 1
         update_status(message=None, completed=processed_blocks)
     update_status(f"Rank {rank} - Moving parameters to CUDA")

@@ -149,7 +170,12 @@ def convert_to_quantized_model(

 # fp8/int4 loading can be very slow so we add progress bars to make life slightly better
-def logging_callbacks(use_rich_progress: bool, rank: int, model: Transformer):
+def logging_callbacks(
+    use_rich_progress: bool,
+    rank: int,
+    model: Transformer,
+    should_quantize_block: Callable[[nn.Module], bool],
+):
     console = None
     if use_rich_progress:
         from rich.console import Console

@@ -162,15 +188,7 @@ def logging_callbacks(use_rich_progress: bool, rank: int, model: Transformer):
     elif rank == 0:  # Only log from rank 0 for non-rich logging
         log.info(message)

-    total_blocks = sum(
-        1
-        for _, block in model.named_modules()
-        if (
-            isinstance(block, TransformerBlock)
-            and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1))
-            and isinstance(block.feed_forward, MoE)
-        )
-    )
+    total_blocks = sum(1 for _, block in model.named_modules() if should_quantize_block(block))
     progress = None
     if use_rich_progress:
         from rich.progress import (
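
The diff above rebinds free functions as instance methods via the descriptor protocol (func.__get__(obj)). Here is a tiny, self-contained illustration of that trick, independent of the quantization code:

    class Linearish:
        def __init__(self):
            self.scale = 2.0

    def doubled_forward(self, x):
        # behaves like a bound method once attached via __get__
        return self.scale * x

    layer = Linearish()
    layer.forward = doubled_forward.__get__(layer)   # bind to this instance only
    print(layer.forward(3.0))                        # 6.0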

View file

@ -4,6 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
@ -59,8 +66,6 @@ LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS = [
"<|text_post_train_reserved_special_token_3|>", "<|text_post_train_reserved_special_token_3|>",
"<|text_post_train_reserved_special_token_4|>", "<|text_post_train_reserved_special_token_4|>",
"<|text_post_train_reserved_special_token_5|>", "<|text_post_train_reserved_special_token_5|>",
"<|python_start|>",
"<|python_end|>",
"<|finetune_right_pad|>", "<|finetune_right_pad|>",
] + get_reserved_special_tokens( ] + get_reserved_special_tokens(
"text_post_train", 61, 6 "text_post_train", 61, 6
@ -85,8 +90,23 @@ LLAMA4_VISION_SPECIAL_TOKENS = [
"vision", 1041, 7 "vision", 1041, 7
) # <|vision_reserved_special_token_7|>, ..., <|vision_reserved_special_token_1047|> ) # <|vision_reserved_special_token_7|>, ..., <|vision_reserved_special_token_1047|>
# 201134, ..., 201143
LLAMA4_REASONING_SPECIAL_TOKENS = [
"<|reasoning_reserved_special_token_0|>",
"<|reasoning_reserved_special_token_1|>",
"<|reasoning_reserved_special_token_2|>",
"<|reasoning_reserved_special_token_3|>",
"<|reasoning_reserved_special_token_4|>",
"<|reasoning_reserved_special_token_5|>",
"<|reasoning_reserved_special_token_6|>",
"<|reasoning_reserved_special_token_7|>",
"<|reasoning_thinking_start|>",
"<|reasoning_thinking_end|>",
]
-LLAMA4_SPECIAL_TOKENS = LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS + LLAMA4_VISION_SPECIAL_TOKENS
+LLAMA4_SPECIAL_TOKENS = (
+    LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS + LLAMA4_VISION_SPECIAL_TOKENS + LLAMA4_REASONING_SPECIAL_TOKENS
+)
BASIC_SPECIAL_TOKENS = [ BASIC_SPECIAL_TOKENS = [
"<|begin_of_text|>", "<|begin_of_text|>",
@@ -155,6 +175,9 @@ class Tokenizer:
         self.eot_id: int = self.special_tokens["<|eot|>"]
         self.eom_id: int = self.special_tokens["<|eom|>"]

+        self.thinking_start_id: int = self.special_tokens["<|reasoning_thinking_start|>"]
+        self.thinking_end_id: int = self.special_tokens["<|reasoning_thinking_end|>"]
+
         self.stop_tokens = [
             self.eos_id,
             self.special_tokens["<|eom|>"],
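
Given the new reasoning special tokens and the thinking_start_id/thinking_end_id attributes registered above, a downstream consumer could separate "thinking" text from the visible answer roughly as follows. This is a hedged sketch: only the attributes and decode() usage shown in this commit are relied on, and the helper itself is hypothetical.

    def split_thinking(tokenizer, token_ids):
        # Assumes decode() accepts a list of token ids, as used elsewhere in this commit.
        if tokenizer.thinking_start_id in token_ids and tokenizer.thinking_end_id in token_ids:
            start = token_ids.index(tokenizer.thinking_start_id)
            end = token_ids.index(tokenizer.thinking_end_id)
            thinking = tokenizer.decode(token_ids[start + 1 : end])
            visible = tokenizer.decode(token_ids[:start] + token_ids[end + 1 :])
            return thinking, visible
        return None, tokenizer.decode(token_ids)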

View file

@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import math import math
from typing import Any, Callable, Dict, List from typing import Any, Callable, Dict, List

View file

@ -28,9 +28,6 @@ from llama_stack.models.llama.datatypes import (
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.models.llama.llama4.tokenizer import Tokenizer from llama_stack.models.llama.llama4.tokenizer import Tokenizer
from llama_stack.providers.inline.inference.meta_reference.llama4.datatypes import (
LLMInput,
)
from .llama3.interface import LLama31Interface from .llama3.interface import LLama31Interface
from .llama3.template_data import ( from .llama3.template_data import (
@ -38,6 +35,7 @@ from .llama3.template_data import (
system_message_builtin_tools_only, system_message_builtin_tools_only,
system_message_custom_tools_only, system_message_custom_tools_only,
) )
from .llama4.datatypes import LLMInput
class TextCompletionContent(BaseModel): class TextCompletionContent(BaseModel):

View file

@ -4,24 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache from functools import lru_cache
from typing import List, Optional from typing import List, Optional
-from .datatypes import (
+from .sku_types import (
     CheckpointQuantizationFormat,
     CoreModelId,
     Model,
     ModelFamily,
-    SamplingParams,
-    TopPSamplingStrategy,
 )
LLAMA2_VOCAB_SIZE = 32000 LLAMA2_VOCAB_SIZE = 32000
@ -47,15 +38,6 @@ def all_registered_models() -> List[Model]:
) )
def recommended_sampling_params() -> SamplingParams:
return SamplingParams(
strategy=TopPSamplingStrategy(
temperature=1.0,
top_p=0.9,
)
)
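
With recommended_sampling_params() removed above and dropped from every Model entry below, callers that want the old defaults presumably construct them directly. A hedged sketch, assuming the SamplingParams/TopPSamplingStrategy classes keep the fields used here; their exact import location after this refactor is an assumption.

    from llama_stack.apis.inference import SamplingParams, TopPSamplingStrategy  # assumed location

    default_params = SamplingParams(
        strategy=TopPSamplingStrategy(temperature=1.0, top_p=0.9),
    )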
def llama2_family() -> List[Model]: def llama2_family() -> List[Model]:
return [ return [
*llama2_base_models(), *llama2_base_models(),
@ -150,7 +132,6 @@ def llama2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama2_7b, core_model_id=CoreModelId.llama2_7b,
description="Llama 2 7b model", description="Llama 2 7b model",
huggingface_repo="meta-llama/Llama-2-7b", huggingface_repo="meta-llama/Llama-2-7b",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 4096, "dim": 4096,
"n_layers": 32, "n_layers": 32,
@ -169,7 +150,6 @@ def llama2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama2_13b, core_model_id=CoreModelId.llama2_13b,
description="Llama 2 13b model", description="Llama 2 13b model",
huggingface_repo="meta-llama/Llama-2-13b", huggingface_repo="meta-llama/Llama-2-13b",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 5120, "dim": 5120,
"n_layers": 40, "n_layers": 40,
@ -188,7 +168,6 @@ def llama2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama2_70b, core_model_id=CoreModelId.llama2_70b,
description="Llama 2 70b model", description="Llama 2 70b model",
huggingface_repo="meta-llama/Llama-2-70b", huggingface_repo="meta-llama/Llama-2-70b",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 8192, "dim": 8192,
"n_layers": 80, "n_layers": 80,
@ -230,7 +209,6 @@ def llama3_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_70b, core_model_id=CoreModelId.llama3_70b,
description="Llama 3 70b model", description="Llama 3 70b model",
huggingface_repo="meta-llama/Llama-3-70B", huggingface_repo="meta-llama/Llama-3-70B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 8192, "dim": 8192,
"n_layers": 80, "n_layers": 80,
@ -254,7 +232,6 @@ def llama3_1_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_1_8b, core_model_id=CoreModelId.llama3_1_8b,
description="Llama 3.1 8b model", description="Llama 3.1 8b model",
huggingface_repo="meta-llama/Llama-3.1-8B", huggingface_repo="meta-llama/Llama-3.1-8B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 4096, "dim": 4096,
"n_layers": 32, "n_layers": 32,
@ -273,7 +250,6 @@ def llama3_1_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_1_70b, core_model_id=CoreModelId.llama3_1_70b,
description="Llama 3.1 70b model", description="Llama 3.1 70b model",
huggingface_repo="meta-llama/Llama-3.1-70B", huggingface_repo="meta-llama/Llama-3.1-70B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 8192, "dim": 8192,
"n_layers": 80, "n_layers": 80,
@ -293,7 +269,6 @@ def llama3_1_base_models() -> List[Model]:
variant="bf16-mp8", variant="bf16-mp8",
description="Llama 3.1 405b model (BF16 weights)", description="Llama 3.1 405b model (BF16 weights)",
huggingface_repo="meta-llama/Llama-3.1-405B", huggingface_repo="meta-llama/Llama-3.1-405B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 16384, "dim": 16384,
"n_layers": 126, "n_layers": 126,
@ -313,7 +288,6 @@ def llama3_1_base_models() -> List[Model]:
description="Llama 3.1 405b model (FP8 quantized)", description="Llama 3.1 405b model (FP8 quantized)",
huggingface_repo="meta-llama/Llama-3.1-405B-FP8", huggingface_repo="meta-llama/Llama-3.1-405B-FP8",
quantization_format=CheckpointQuantizationFormat.fp8_mixed, quantization_format=CheckpointQuantizationFormat.fp8_mixed,
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 16384, "dim": 16384,
"n_layers": 126, "n_layers": 126,
@ -333,7 +307,6 @@ def llama3_1_base_models() -> List[Model]:
variant="bf16-mp16", variant="bf16-mp16",
description="Llama 3.1 405b model (BF16 weights for mp16)", description="Llama 3.1 405b model (BF16 weights for mp16)",
huggingface_repo="meta-llama/Llama-3.1-405B", huggingface_repo="meta-llama/Llama-3.1-405B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 16384, "dim": 16384,
"n_layers": 126, "n_layers": 126,
@ -357,7 +330,6 @@ def llama3_2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_1b, core_model_id=CoreModelId.llama3_2_1b,
description="Llama 3.2 1b model", description="Llama 3.2 1b model",
huggingface_repo="meta-llama/Llama-3.2-1B", huggingface_repo="meta-llama/Llama-3.2-1B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 2048, "dim": 2048,
"n_layers": 16, "n_layers": 16,
@ -376,7 +348,6 @@ def llama3_2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_3b, core_model_id=CoreModelId.llama3_2_3b,
description="Llama 3.2 3b model", description="Llama 3.2 3b model",
huggingface_repo="meta-llama/Llama-3.2-3B", huggingface_repo="meta-llama/Llama-3.2-3B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 3072, "dim": 3072,
"n_layers": 28, "n_layers": 28,
@ -395,7 +366,6 @@ def llama3_2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_11b_vision, core_model_id=CoreModelId.llama3_2_11b_vision,
description="Llama 3.2 11b vision model", description="Llama 3.2 11b vision model",
huggingface_repo="meta-llama/Llama-3.2-11B-Vision", huggingface_repo="meta-llama/Llama-3.2-11B-Vision",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 4096, "dim": 4096,
"n_layers": 32, "n_layers": 32,
@ -417,7 +387,6 @@ def llama3_2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_90b_vision, core_model_id=CoreModelId.llama3_2_90b_vision,
description="Llama 3.2 90b vision model", description="Llama 3.2 90b vision model",
huggingface_repo="meta-llama/Llama-3.2-90B-Vision", huggingface_repo="meta-llama/Llama-3.2-90B-Vision",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 8192, "dim": 8192,
"n_layers": 80, "n_layers": 80,
@ -444,7 +413,6 @@ def llama2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama2_7b_chat, core_model_id=CoreModelId.llama2_7b_chat,
description="Llama 2 7b chat model", description="Llama 2 7b chat model",
huggingface_repo="meta-llama/Llama-2-7b-chat", huggingface_repo="meta-llama/Llama-2-7b-chat",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 4096, "dim": 4096,
"n_layers": 32, "n_layers": 32,
@ -463,7 +431,6 @@ def llama2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama2_13b_chat, core_model_id=CoreModelId.llama2_13b_chat,
description="Llama 2 13b chat model", description="Llama 2 13b chat model",
huggingface_repo="meta-llama/Llama-2-13b-chat", huggingface_repo="meta-llama/Llama-2-13b-chat",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 5120, "dim": 5120,
"n_layers": 40, "n_layers": 40,
@ -482,7 +449,6 @@ def llama2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama2_70b_chat, core_model_id=CoreModelId.llama2_70b_chat,
description="Llama 2 70b chat model", description="Llama 2 70b chat model",
huggingface_repo="meta-llama/Llama-2-70b-chat", huggingface_repo="meta-llama/Llama-2-70b-chat",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 8192, "dim": 8192,
"n_layers": 80, "n_layers": 80,
@ -506,7 +472,6 @@ def llama3_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_8b_instruct, core_model_id=CoreModelId.llama3_8b_instruct,
description="Llama 3 8b instruct model", description="Llama 3 8b instruct model",
huggingface_repo="meta-llama/Llama-3-8B-Instruct", huggingface_repo="meta-llama/Llama-3-8B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 4096, "dim": 4096,
"n_layers": 32, "n_layers": 32,
@ -525,7 +490,6 @@ def llama3_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_70b_instruct, core_model_id=CoreModelId.llama3_70b_instruct,
description="Llama 3 70b instruct model", description="Llama 3 70b instruct model",
huggingface_repo="meta-llama/Llama-3-70B-Instruct", huggingface_repo="meta-llama/Llama-3-70B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 8192, "dim": 8192,
"n_layers": 80, "n_layers": 80,
@ -549,7 +513,6 @@ def llama3_1_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_1_8b_instruct, core_model_id=CoreModelId.llama3_1_8b_instruct,
description="Llama 3.1 8b instruct model", description="Llama 3.1 8b instruct model",
huggingface_repo="meta-llama/Llama-3.1-8B-Instruct", huggingface_repo="meta-llama/Llama-3.1-8B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={ arch_args={
"dim": 4096, "dim": 4096,
"n_layers": 32, "n_layers": 32,
@ -568,7 +531,6 @@ def llama3_1_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_1_70b_instruct, core_model_id=CoreModelId.llama3_1_70b_instruct,
description="Llama 3.1 70b instruct model", description="Llama 3.1 70b instruct model",
huggingface_repo="meta-llama/Llama-3.1-70B-Instruct", huggingface_repo="meta-llama/Llama-3.1-70B-Instruct",
-            recommended_sampling_params=recommended_sampling_params(),
             arch_args={
                 "dim": 8192,
                 "n_layers": 80,
@@ -588,7 +550,6 @@ def llama3_1_instruct_models() -> List[Model]:
             variant="bf16-mp8",
             description="Llama 3.1 405b instruct model (BF16 weights)",
             huggingface_repo="meta-llama/Llama-3.1-405B-Instruct",
-            recommended_sampling_params=recommended_sampling_params(),
             arch_args={
                 "dim": 16384,
                 "n_layers": 126,
@@ -608,7 +569,6 @@ def llama3_1_instruct_models() -> List[Model]:
             description="Llama 3.1 405b instruct model (FP8 quantized)",
             huggingface_repo="meta-llama/Llama-3.1-405B-Instruct-FP8",
             quantization_format=CheckpointQuantizationFormat.fp8_mixed,
-            recommended_sampling_params=recommended_sampling_params(),
             arch_args={
                 "dim": 16384,
                 "n_layers": 126,
@@ -628,7 +588,6 @@ def llama3_1_instruct_models() -> List[Model]:
             variant="bf16-mp16",
             description="Llama 3.1 405b instruct model (BF16 weights for mp16)",
             huggingface_repo="meta-llama/Llama-3.1-405B-Instruct",
-            recommended_sampling_params=recommended_sampling_params(),
             arch_args={
                 "dim": 16384,
                 "n_layers": 126,
@@ -684,7 +643,6 @@ def llama3_2_quantized_models() -> List[Model]:
             quantization_format=CheckpointQuantizationFormat.int4,
             description="Llama 3.2 1b INT4 quantized LoRA",
             huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8",
-            recommended_sampling_params=recommended_sampling_params(),
             arch_args={
                 **arch_args_1b(),
                 "quantization_args": {
@@ -703,7 +661,6 @@ def llama3_2_quantized_models() -> List[Model]:
             quantization_format=CheckpointQuantizationFormat.int4,
             description="Llama 3.2 1b INT4 quantized SpinQuant",
             huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8",
-            recommended_sampling_params=recommended_sampling_params(),
             arch_args={
                 **arch_args_1b(),
                 "quantization_args": {
@@ -718,7 +675,6 @@ def llama3_2_quantized_models() -> List[Model]:
             quantization_format=CheckpointQuantizationFormat.int4,
             description="Llama 3.2 3b INT4 quantized LoRA",
             huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-QLORA_INT4_EO8",
-            recommended_sampling_params=recommended_sampling_params(),
             arch_args={
                 **arch_args_3b(),
                 "quantization_args": {
@@ -737,7 +693,6 @@ def llama3_2_quantized_models() -> List[Model]:
             quantization_format=CheckpointQuantizationFormat.int4,
             description="Llama 3.2 3b INT4 quantized SpinQuant",
             huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-SpinQuant_INT4_EO8",
-            recommended_sampling_params=recommended_sampling_params(),
             arch_args={
                 **arch_args_3b(),
                 "quantization_args": {
@@ -755,7 +710,6 @@ def llama3_2_instruct_models() -> List[Model]:
             core_model_id=CoreModelId.llama3_2_1b_instruct,
             description="Llama 3.2 1b instruct model",
             huggingface_repo="meta-llama/Llama-3.2-1B-Instruct",
-            recommended_sampling_params=recommended_sampling_params(),
             arch_args=arch_args_1b(),
             pth_file_count=1,
         ),
@@ -763,7 +717,6 @@ def llama3_2_instruct_models() -> List[Model]:
             core_model_id=CoreModelId.llama3_2_3b_instruct,
             description="Llama 3.2 3b instruct model",
             huggingface_repo="meta-llama/Llama-3.2-3B-Instruct",
-            recommended_sampling_params=recommended_sampling_params(),
             arch_args=arch_args_3b(),
             pth_file_count=1,
         ),
@@ -772,7 +725,6 @@ def llama3_2_instruct_models() -> List[Model]:
             core_model_id=CoreModelId.llama3_2_11b_vision_instruct,
             description="Llama 3.2 11b vision instruct model",
             huggingface_repo="meta-llama/Llama-3.2-11B-Vision-Instruct",
-            recommended_sampling_params=recommended_sampling_params(),
             arch_args={
                 "dim": 4096,
                 "n_layers": 32,
@@ -794,7 +746,6 @@ def llama3_2_instruct_models() -> List[Model]:
             core_model_id=CoreModelId.llama3_2_90b_vision_instruct,
             description="Llama 3.2 90b vision instruct model",
             huggingface_repo="meta-llama/Llama-3.2-90B-Vision-Instruct",
-            recommended_sampling_params=recommended_sampling_params(),
             arch_args={
                 "dim": 8192,
                 "n_layers": 80,
@@ -821,7 +772,6 @@ def llama3_3_instruct_models() -> List[Model]:
             core_model_id=CoreModelId.llama3_3_70b_instruct,
             description="Llama 3.3 70b instruct",
             huggingface_repo="meta-llama/Llama-3.3-70B-Instruct",
-            recommended_sampling_params=recommended_sampling_params(),
             arch_args={
                 "dim": 8192,
                 "n_layers": 80,
@@ -846,7 +796,6 @@ def safety_models() -> List[Model]:
             core_model_id=CoreModelId.llama_guard_3_11b_vision,
             description="Llama Guard v3 11b vision system safety model",
             huggingface_repo="meta-llama/Llama-Guard-3-11B-Vision",
-            recommended_sampling_params=recommended_sampling_params(),
             arch_args={
                 "dim": 4096,
                 "n_layers": 32,
@@ -870,7 +819,6 @@ def safety_models() -> List[Model]:
             description="Llama Guard v3 1b 'int4' quantized system safety model",
             huggingface_repo="meta-llama/Llama-Guard-3-1B-INT4",
             quantization_format=CheckpointQuantizationFormat.int4,
-            recommended_sampling_params=recommended_sampling_params(),
             arch_args={
                 "dim": 2048,
                 "n_layers": 12,
@@ -888,7 +836,6 @@ def safety_models() -> List[Model]:
             core_model_id=CoreModelId.llama_guard_3_1b,
             description="Llama Guard v3 1b system safety model",
             huggingface_repo="meta-llama/Llama-Guard-3-1B",
-            recommended_sampling_params=recommended_sampling_params(),
             arch_args={
                 "dim": 2048,
                 "n_layers": 16,

View file

@@ -0,0 +1,229 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, Optional
from pydantic import BaseModel, ConfigDict, Field
class CheckpointQuantizationFormat(Enum):
# default format
bf16 = "bf16"
# used for enabling fp8_rowwise inference, some weights are bf16
fp8_mixed = "fp8-mixed"
int8 = "int8"
int4 = "int4"
class ModelFamily(Enum):
llama2 = "llama2"
llama3 = "llama3"
llama3_1 = "llama3_1"
llama3_2 = "llama3_2"
llama3_3 = "llama3_3"
llama4 = "llama4"
safety = "safety"
class CoreModelId(Enum):
"""Each of these models is a unique "SKU". These root models can be served in various garbs (especially by quantizing them)"""
# Llama 2 family
llama2_7b = "Llama-2-7b"
llama2_13b = "Llama-2-13b"
llama2_70b = "Llama-2-70b"
llama2_7b_chat = "Llama-2-7b-chat"
llama2_13b_chat = "Llama-2-13b-chat"
llama2_70b_chat = "Llama-2-70b-chat"
# Llama 3 family
llama3_8b = "Llama-3-8B"
llama3_70b = "Llama-3-70B"
llama3_8b_instruct = "Llama-3-8B-Instruct"
llama3_70b_instruct = "Llama-3-70B-Instruct"
# Llama 3.1 family
llama3_1_8b = "Llama3.1-8B"
llama3_1_70b = "Llama3.1-70B"
llama3_1_405b = "Llama3.1-405B"
llama3_1_8b_instruct = "Llama3.1-8B-Instruct"
llama3_1_70b_instruct = "Llama3.1-70B-Instruct"
llama3_1_405b_instruct = "Llama3.1-405B-Instruct"
# Llama 3.2 family
llama3_2_1b = "Llama3.2-1B"
llama3_2_3b = "Llama3.2-3B"
llama3_2_1b_instruct = "Llama3.2-1B-Instruct"
llama3_2_3b_instruct = "Llama3.2-3B-Instruct"
llama3_2_11b_vision = "Llama3.2-11B-Vision"
llama3_2_90b_vision = "Llama3.2-90B-Vision"
llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct"
llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct"
# Llama 3.3 family
llama3_3_70b_instruct = "Llama3.3-70B-Instruct"
# Llama 4 family
llama4_scout_17b_16e = "Llama-4-Scout-17B-16E"
llama4_scout_17b_16e_instruct = "Llama-4-Scout-17B-16E-Instruct"
llama4_maverick_17b_128e = "Llama-4-Maverick-17B-128E"
llama4_maverick_17b_128e_instruct = "Llama-4-Maverick-17B-128E-Instruct"
# Safety models
llama_guard_3_8b = "Llama-Guard-3-8B"
llama_guard_2_8b = "Llama-Guard-2-8B"
llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision"
llama_guard_3_1b = "Llama-Guard-3-1B"
def is_multimodal(model_id) -> bool:
if model_id in [
CoreModelId.llama3_2_11b_vision,
CoreModelId.llama3_2_90b_vision,
CoreModelId.llama3_2_11b_vision_instruct,
CoreModelId.llama3_2_90b_vision_instruct,
]:
return True
else:
return False
def model_family(model_id) -> ModelFamily:
if model_id in [
CoreModelId.llama2_7b,
CoreModelId.llama2_13b,
CoreModelId.llama2_70b,
CoreModelId.llama2_7b_chat,
CoreModelId.llama2_13b_chat,
CoreModelId.llama2_70b_chat,
]:
return ModelFamily.llama2
elif model_id in [
CoreModelId.llama3_8b,
CoreModelId.llama3_70b,
CoreModelId.llama3_8b_instruct,
CoreModelId.llama3_70b_instruct,
]:
return ModelFamily.llama3
elif model_id in [
CoreModelId.llama3_1_8b,
CoreModelId.llama3_1_70b,
CoreModelId.llama3_1_405b,
CoreModelId.llama3_1_8b_instruct,
CoreModelId.llama3_1_70b_instruct,
CoreModelId.llama3_1_405b_instruct,
]:
return ModelFamily.llama3_1
elif model_id in [
CoreModelId.llama3_2_1b,
CoreModelId.llama3_2_3b,
CoreModelId.llama3_2_1b_instruct,
CoreModelId.llama3_2_3b_instruct,
CoreModelId.llama3_2_11b_vision,
CoreModelId.llama3_2_90b_vision,
CoreModelId.llama3_2_11b_vision_instruct,
CoreModelId.llama3_2_90b_vision_instruct,
]:
return ModelFamily.llama3_2
elif model_id in [
CoreModelId.llama3_3_70b_instruct,
]:
return ModelFamily.llama3_3
elif model_id in [
CoreModelId.llama4_scout_17b_16e,
CoreModelId.llama4_scout_17b_16e_instruct,
CoreModelId.llama4_maverick_17b_128e,
CoreModelId.llama4_maverick_17b_128e_instruct,
]:
return ModelFamily.llama4
elif model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_2_8b,
CoreModelId.llama_guard_3_11b_vision,
CoreModelId.llama_guard_3_1b,
]:
return ModelFamily.safety
else:
raise ValueError(f"Unknown model family for {model_id}")
class Model(BaseModel):
core_model_id: CoreModelId
description: str
huggingface_repo: Optional[str] = None
arch_args: Dict[str, Any]
variant: str = ""
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
pth_file_count: int
metadata: Dict[str, Any] = Field(default_factory=dict)
# silence pydantic until we remove the `model_` fields
model_config = ConfigDict(protected_namespaces=())
@property
def model_family(self) -> ModelFamily:
return model_family(self.core_model_id)
# The SKU is uniquely identified by (model_id, variant) combo
def descriptor(self, shorten_default_variant: bool = True) -> str:
if not self.variant:
return self.core_model_id.value
return f"{self.core_model_id.value}:{self.variant}"
@property
def is_instruct_model(self) -> bool:
return "instruct" in self.core_model_id.value
# Featured models are shown in the non-exhaustive model list
@property
def is_featured(self) -> bool:
return self.model_family in [
ModelFamily.llama3_1,
ModelFamily.llama3_2,
ModelFamily.llama3_3,
ModelFamily.llama4,
ModelFamily.safety,
]
@property
def max_seq_length(self) -> int:
if self.model_family == ModelFamily.llama2:
return 4096
elif self.core_model_id == CoreModelId.llama_guard_2_8b:
return 4096
elif self.model_family == ModelFamily.llama3:
return 8192
elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]:
return 131072
elif self.model_family == ModelFamily.llama3_2:
if self.quantization_format == CheckpointQuantizationFormat.int4:
return 8192
return 131072
elif self.model_family == ModelFamily.llama4:
if self.core_model_id in {
CoreModelId.llama4_scout_17b_16e,
CoreModelId.llama4_maverick_17b_128e,
}:
return 262144
if self.core_model_id == CoreModelId.llama4_scout_17b_16e_instruct:
return 10485760
if self.core_model_id == CoreModelId.llama4_maverick_17b_128e_instruct:
return 1048576
raise AssertionError(f"Unexpected core model id: {self.core_model_id}")
elif self.core_model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_3_11b_vision,
CoreModelId.llama_guard_3_1b,
]:
return 131072
else:
raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")
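Aside: a quick usage sketch of the new sku_types module. The Model instance below is hypothetical (it mirrors the Llama 3.2 3B entry from sku_list, with arch_args elided) and the assertions only exercise the helpers defined above.

from llama_stack.models.llama.sku_types import (
    CoreModelId,
    Model,
    ModelFamily,
    model_family,
)

# Hypothetical SKU entry; the real registry lives in llama_stack/models/llama/sku_list.py.
demo = Model(
    core_model_id=CoreModelId.llama3_2_3b_instruct,
    description="Llama 3.2 3b instruct model",
    huggingface_repo="meta-llama/Llama-3.2-3B-Instruct",
    arch_args={},  # elided for the sketch
    pth_file_count=1,
)

assert demo.model_family is ModelFamily.llama3_2
assert demo.descriptor() == "Llama3.2-3B-Instruct"  # no variant, so the bare SKU id
assert demo.max_seq_length == 131072  # bf16 Llama 3.2 checkpoints allow 128k context
assert model_family(CoreModelId.llama_guard_3_1b) is ModelFamily.safety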

View file

@@ -52,6 +52,7 @@ from llama_stack.apis.inference import (
     StopReason,
     SystemMessage,
     ToolDefinition,
+    ToolParamDefinition,
     ToolResponse,
     ToolResponseMessage,
     UserMessage,
@@ -63,7 +64,6 @@ from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     ToolCall,
-    ToolParamDefinition,
 )
 from llama_stack.providers.utils.kvstore import KVStore
 from llama_stack.providers.utils.telemetry import tracing

View file

@@ -12,20 +12,19 @@ from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerToken
 from llama_stack.apis.inference import (
     Fp8QuantizationConfig,
+    GreedySamplingStrategy,
     Int4QuantizationConfig,
     JsonSchemaResponseFormat,
     ResponseFormat,
-)
-from llama_stack.models.llama.datatypes import (
-    GreedySamplingStrategy,
-    Model,
     SamplingParams,
     TopPSamplingStrategy,
 )
+from llama_stack.models.llama.datatypes import QuantizationMode
 from llama_stack.models.llama.llama3.generation import Llama3
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
 from llama_stack.models.llama.llama4.generation import Llama4
 from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
+from llama_stack.models.llama.sku_types import Model
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
     CompletionRequestWithRawContent,
@@ -136,9 +135,9 @@ class Llama4Generator:
         if isinstance(config, MetaReferenceQuantizedInferenceConfig):
             if isinstance(config.quantization, Fp8QuantizationConfig):
-                quantization_mode = "fp8_mixed"
+                quantization_mode = QuantizationMode.fp8_mixed
             elif isinstance(config.quantization, Int4QuantizationConfig):
-                quantization_mode = "int4_mixed"
+                quantization_mode = QuantizationMode.int4_mixed
             else:
                 raise ValueError(f"Unsupported quantization mode {config.quantization}")
         else:
@@ -225,9 +224,9 @@ class Llama3Generator:
         if isinstance(config, MetaReferenceQuantizedInferenceConfig):
             if isinstance(config.quantization, Fp8QuantizationConfig):
-                quantization_mode = "fp8_mixed"
+                quantization_mode = QuantizationMode.fp8_mixed
             elif isinstance(config.quantization, Int4QuantizationConfig):
-                quantization_mode = "int4_mixed"
+                quantization_mode = QuantizationMode.int4_mixed
             else:
                 raise ValueError(f"Unsupported quantization mode {config.quantization}")
         else:
@@ -240,6 +239,9 @@ class Llama3Generator:
             world_size=llama_model.pth_file_count,
             quantization_mode=quantization_mode,
         )
+        self.tokenizer = self.inner_generator.tokenizer
+        self.args = self.inner_generator.args
+        self.formatter = self.inner_generator.formatter

     def completion(
         self,
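Aside: the string literals above are replaced by the QuantizationMode enum from llama_stack.models.llama.datatypes. A minimal standalone sketch of the same selection logic (pick_quantization_mode is a hypothetical helper, not part of this change):

from typing import Optional

from llama_stack.apis.inference import Fp8QuantizationConfig, Int4QuantizationConfig
from llama_stack.models.llama.datatypes import QuantizationMode


def pick_quantization_mode(quantization) -> Optional[QuantizationMode]:
    # Mirrors the branches in Llama3Generator/Llama4Generator: no config means
    # unquantized (None), otherwise map the config type onto the enum.
    if quantization is None:
        return None
    if isinstance(quantization, Fp8QuantizationConfig):
        return QuantizationMode.fp8_mixed
    if isinstance(quantization, Int4QuantizationConfig):
        return QuantizationMode.int4_mixed
    raise ValueError(f"Unsupported quantization mode {quantization}")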

View file

@@ -31,23 +31,21 @@ from llama_stack.apis.inference import (
     LogProbConfig,
     Message,
     ResponseFormat,
+    SamplingParams,
+    StopReason,
     TokenLogProbs,
     ToolChoice,
     ToolConfig,
-)
-from llama_stack.apis.models import Model, ModelType
-from llama_stack.models.llama.datatypes import (
-    ModelFamily,
-    SamplingParams,
-    StopReason,
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.apis.models import Model, ModelType
 from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
 from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
 from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
 from llama_stack.models.llama.sku_list import resolve_model
+from llama_stack.models.llama.sku_types import ModelFamily
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,

View file

@@ -14,9 +14,10 @@ from llama_stack.apis.inference import (
     JsonSchemaResponseFormat,
     Message,
     ToolChoice,
+    ToolDefinition,
     UserMessage,
 )
-from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition
+from llama_stack.models.llama.datatypes import BuiltinTool
 from llama_stack.providers.utils.inference.openai_compat import (
     convert_message_to_openai_dict,
     get_sampling_options,

View file

@@ -46,6 +46,8 @@ from llama_stack.apis.inference import (
     TokenLogProbs,
     ToolChoice,
     ToolConfig,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
 )
 from llama_stack.apis.models import Model
 from llama_stack.log import get_logger
@@ -55,8 +57,6 @@ from llama_stack.models.llama.datatypes import (
     ToolCall,
     ToolDefinition,
     ToolPromptFormat,
-    TopKSamplingStrategy,
-    TopPSamplingStrategy,
 )
 from llama_stack.models.llama.llama3.chat_format import ChatFormat
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer

View file

@@ -22,8 +22,8 @@ from torchtune.models.llama3_2 import lora_llama3_2_3b
 from torchtune.modules.transforms import Transform

 from llama_stack.apis.post_training import DatasetFormat
-from llama_stack.models.llama.datatypes import Model
 from llama_stack.models.llama.sku_list import resolve_model
+from llama_stack.models.llama.sku_types import Model

 BuildLoraModelCallable = Callable[..., torch.nn.Module]
 BuildTokenizerCallable = Callable[..., Llama3Tokenizer]

View file

@@ -23,7 +23,8 @@ from llama_stack.apis.safety import (
 )
 from llama_stack.apis.shields import Shield
 from llama_stack.distribution.datatypes import Api
-from llama_stack.models.llama.datatypes import CoreModelId, Role
+from llama_stack.models.llama.datatypes import Role
+from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,

View file

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     build_hf_repo_model_entry,
 )

View file

@@ -28,8 +28,8 @@ from llama_stack.apis.inference import (
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
+    TopKSamplingStrategy,
 )
-from llama_stack.models.llama.datatypes import TopKSamplingStrategy
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )

View file

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     build_hf_repo_model_entry,
 )

View file

@@ -28,7 +28,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_hf_repo_model_entry,

View file

@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 from llama_stack.apis.models.models import ModelType
-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     ProviderModelEntry,
     build_hf_repo_model_entry,

View file

@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 from llama_stack.apis.models import ModelType
-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     ProviderModelEntry,
     build_hf_repo_model_entry,

View file

@@ -29,15 +29,13 @@ from llama_stack.apis.inference import (
     LogProbConfig,
     Message,
     ResponseFormat,
+    SamplingParams,
     TextTruncation,
     ToolChoice,
     ToolConfig,
-)
-from llama_stack.models.llama.datatypes import (
-    SamplingParams,
     ToolDefinition,
-    ToolPromptFormat,
 )
+from llama_stack.models.llama.datatypes import ToolPromptFormat
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
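Aside: SamplingParams and the sampling strategy types are now imported from llama_stack.apis.inference rather than llama_stack.models.llama.datatypes. A minimal request-time sketch, assuming the field shapes used elsewhere in this commit (the values are illustrative):

from llama_stack.apis.inference import (
    GreedySamplingStrategy,
    SamplingParams,
    TopPSamplingStrategy,
)

# Deterministic decoding.
greedy = SamplingParams(strategy=GreedySamplingStrategy())

# Nucleus sampling; temperature must be > 0 and top_p defaults to 0.95.
nucleus = SamplingParams(strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.9))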

View file

@@ -19,11 +19,9 @@ from llama_stack.apis.inference import (
     CompletionRequest,
     CompletionResponse,
     CompletionResponseStreamChunk,
+    GreedySamplingStrategy,
     JsonSchemaResponseFormat,
     TokenLogProbs,
-)
-from llama_stack.models.llama.datatypes import (
-    GreedySamplingStrategy,
     TopKSamplingStrategy,
     TopPSamplingStrategy,
 )

View file

@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 from llama_stack.apis.models.models import ModelType
-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     ProviderModelEntry,
     build_hf_repo_model_entry,

View file

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     build_hf_repo_model_entry,
 )

View file

@@ -21,6 +21,7 @@ from llama_stack.apis.inference import (
     CompletionMessage,
     EmbeddingsResponse,
     EmbeddingTaskType,
+    GreedySamplingStrategy,
     Inference,
     LogProbConfig,
     Message,
@@ -35,12 +36,9 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
     ToolResponseMessage,
-    UserMessage,
-)
-from llama_stack.models.llama.datatypes import (
-    GreedySamplingStrategy,
     TopKSamplingStrategy,
     TopPSamplingStrategy,
+    UserMessage,
 )
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (

View file

@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 from llama_stack.apis.models.models import ModelType
-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     ProviderModelEntry,
     build_hf_repo_model_entry,

View file

@@ -96,7 +96,6 @@ def _convert_to_vllm_tool_calls_in_response(
             call_id=call.id,
             tool_name=call.function.name,
             arguments=json.loads(call.function.arguments),
-            arguments_json=call.function.arguments,
         )
         for call in tool_calls
     ]
@@ -176,7 +175,6 @@ async def _process_vllm_chat_completion_stream_response(
                     call_id=tool_call_buf.call_id,
                     tool_name=tool_call_buf.tool_name,
                     arguments=args,
-                    arguments_json=args_str,
                 ),
                 parse_status=ToolCallParseStatus.succeeded,
             ),
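Aside: with arguments_json dropped, a ToolCall carries only the parsed arguments mapping. A minimal conversion sketch for an OpenAI-compatible tool call (to_tool_call is a hypothetical helper; the attribute names mirror the vLLM response objects used above):

import json

from llama_stack.models.llama.datatypes import ToolCall


def to_tool_call(call) -> ToolCall:
    # `call` exposes .id and .function.{name, arguments}; the JSON string is
    # parsed once and only the resulting dict is stored on the ToolCall.
    return ToolCall(
        call_id=call.id,
        tool_name=call.function.name,
        arguments=json.loads(call.function.arguments),
    )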

View file

@@ -6,7 +6,7 @@
 from typing import List

-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     ProviderModelEntry,
     build_hf_repo_model_entry,

View file

@@ -12,8 +12,8 @@ import pytest
 from pytest import ExitCode
 from pytest_html.basereport import _process_outcome

-from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.models.llama.sku_list import all_registered_models
+from llama_stack.models.llama.sku_types import CoreModelId

 INFERENCE_APIS = ["chat_completion"]
 FUNCTIONALITIES = ["streaming", "structured_output", "tool_calling"]

View file

@@ -6,8 +6,8 @@
 from typing import List

-from llama_stack.models.llama.datatypes import * # noqa: F403
 from llama_stack.models.llama.sku_list import all_registered_models
+from llama_stack.models.llama.sku_types import * # noqa: F403


 def is_supported_safety_model(model: Model) -> bool:

View file

@@ -73,21 +73,21 @@ from llama_stack.apis.inference import (
     CompletionMessage,
     CompletionResponse,
     CompletionResponseStreamChunk,
+    GreedySamplingStrategy,
     Message,
+    SamplingParams,
     SystemMessage,
     TokenLogProbs,
     ToolResponseMessage,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
     UserMessage,
 )
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
-    GreedySamplingStrategy,
-    SamplingParams,
     StopReason,
     ToolCall,
     ToolDefinition,
-    TopKSamplingStrategy,
-    TopPSamplingStrategy,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     convert_image_content_to_url,

View file

@@ -34,7 +34,6 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import (
-    ModelFamily,
     RawContent,
     RawContentItem,
     RawMediaItem,
@@ -43,7 +42,6 @@ from llama_stack.models.llama.datatypes import (
     Role,
     StopReason,
     ToolPromptFormat,
-    is_multimodal,
 )
 from llama_stack.models.llama.llama3.chat_format import ChatFormat
 from llama_stack.models.llama.llama3.prompt_templates import (
@@ -55,6 +53,7 @@ from llama_stack.models.llama.llama3.prompt_templates import (
 )
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.models.llama.sku_list import resolve_model
+from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
 from llama_stack.providers.utils.inference import supported_inference_models

 log = get_logger(name=__name__, category="inference")

View file

@@ -224,9 +224,9 @@ exclude = [
     "^llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
     "^llama_stack/providers/inline/inference/meta_reference/config\\.py$",
     "^llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
-    "^llama_stack/providers/inline/inference/meta_reference/llama3/generation\\.py$",
-    "^llama_stack/providers/inline/inference/meta_reference/llama3/multimodal/model\\.py$",
-    "^llama_stack/providers/inline/inference/meta_reference/llama4/",
+    "^llama_stack/models/llama/llama3/generation\\.py$",
+    "^llama_stack/models/llama/llama3/multimodal/model\\.py$",
+    "^llama_stack/models/llama/llama4/",
     "^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$",
     "^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$",
     "^llama_stack/providers/inline/inference/meta_reference/quantization/loader\\.py$",

View file

@@ -11,7 +11,6 @@ import pytest
 from pytest import CollectReport
 from termcolor import cprint

-from llama_stack.models.llama.datatypes import CoreModelId
 from llama_stack.models.llama.sku_list import (
     all_registered_models,
     llama3_1_instruct_models,
@@ -20,6 +19,7 @@ from llama_stack.models.llama.sku_list import (
     llama3_instruct_models,
     safety_models,
 )
+from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.datatypes import Api

 from .metadata import API_MAPS