chore: enable pyupgrade fixes (#1806)

# What does this PR do?

The goal of this PR is to modernize the code base.

Schema reflection code needed a minor adjustment to handle `types.UnionType` (the `X | Y` union
syntax) and `collections.abc.AsyncIterator`, since both are the preferred spellings in recent
Python releases.
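
For illustration, here is a minimal sketch of the kind of dispatch such reflection code needs once
annotations use these spellings. It is only a sketch, not the actual change in
`llama_stack/strong_typing/schema.py`, and the `describe_type` helper is hypothetical:

```python
import collections.abc
import types
import typing
from typing import Any, AsyncIterator, Optional


def describe_type(tp: Any) -> str:
    """Hypothetical reflection helper that classifies a type annotation."""
    origin = typing.get_origin(tp)
    # `int | None` is a `types.UnionType` (Python 3.10+), which get_origin()
    # reports instead of `typing.Union`, so both spellings need a branch.
    if origin is typing.Union or isinstance(tp, types.UnionType):
        members = ", ".join(describe_type(arg) for arg in typing.get_args(tp))
        return f"union[{members}]"
    # typing.AsyncIterator[...] and collections.abc.AsyncIterator[...] resolve
    # to the same origin class, so one check covers the old and new spellings.
    if origin is collections.abc.AsyncIterator:
        (item,) = typing.get_args(tp)
        return f"async_iterator[{describe_type(item)}]"
    return getattr(tp, "__name__", repr(tp))


# Old and new spellings should reflect identically.
assert describe_type(Optional[int]) == describe_type(int | None)
assert describe_type(AsyncIterator[str]) == describe_type(collections.abc.AsyncIterator[str])
```

The key point is that `typing.get_origin(int | None)` returns `types.UnionType` rather than
`typing.Union`, so reflection code that only checked for `typing.Union` would silently skip
annotations written with the new syntax.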

Note to reviewers: almost all changes here were generated automatically by pyupgrade. Some
additional unused imports were cleaned up as well. The only changes worth a closer look are
under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py`, where the reflection
code was updated to deal with the "newer" types.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Authored by Ihar Hrachyshka on 2025-05-01 17:23:50 -04:00; committed via GitHub.
Commit 9e6561a1ec (parent ffe3d0b2cd)
319 changed files with 2843 additions and 3033 deletions


@ -7,14 +7,14 @@
import concurrent.futures
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any
import numpy as np
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank, get_model_parallel_world_size
def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[int]:
def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> list[int]:
"""Map a new MP rank to a list of old MP ranks given a change in MP size."""
if new_mp_size % old_mp_size == 0:
# Read old MP shard and split it into smaller ones
@ -31,12 +31,12 @@ def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[in
def maybe_reshard_state_dict(
ckpt_paths: List[Path],
ckpt_paths: list[Path],
n_kv_heads: int,
moe_num_experts: Optional[int] = None,
map_location: Union[str, torch.device] = "cpu",
moe_num_experts: int | None = None,
map_location: str | torch.device = "cpu",
mmap: bool = True,
) -> Dict[str, torch.Tensor]:
) -> dict[str, torch.Tensor]:
if str(map_location) == "cpu":
torch.set_default_tensor_type(torch.BFloat16Tensor)
else:
@ -97,18 +97,18 @@ _MOE_WEIGHT_COLUMN_KEY = {"feed_forward.experts.moe_w_out_eF_D"}
def reshard_mp(
state_dicts: List[Dict[str, torch.Tensor]],
state_dicts: list[dict[str, torch.Tensor]],
size: int,
rank: int,
repeat_qk_qv: int = 1,
) -> Dict[str, torch.Tensor]:
) -> dict[str, torch.Tensor]:
"""
Reshard a list of state dicts into a single state dict given a change in MP size.
If the list has more than one state dict, we concatenate the values of the same
key across all state dicts. Otherwise, we just slice it for the current MP rank.
"""
def concat_or_chunk(tensors: List[torch.Tensor], dim: int) -> torch.Tensor:
def concat_or_chunk(tensors: list[torch.Tensor], dim: int) -> torch.Tensor:
if len(tensors) > 1:
return torch.cat(tensors, dim=dim)
return tensors[0].chunk(size, dim=dim)[rank].clone()
@ -144,7 +144,7 @@ def reshard_mp(
column_regex = re.compile("|".join(column_keys))
row_regex = re.compile("|".join(row_keys))
output: Dict[str, torch.Tensor] = {}
output: dict[str, torch.Tensor] = {}
with concurrent.futures.ThreadPoolExecutor() as executor:
# Note: only processes keys in the first state dict.
# Assumes keys are the same across all state dicts.
@ -154,7 +154,7 @@ def reshard_mp(
return output
def convert_moe_weights(state_dict: Dict[str, Any], num_experts: int) -> Dict[str, Any]:
def convert_moe_weights(state_dict: dict[str, Any], num_experts: int) -> dict[str, Any]:
routed_keys = _MOE_WEIGHT_ROW_KEY | _MOE_WEIGHT_COLUMN_KEY
routed_regex = re.compile("|".join(routed_keys))
keys = list(state_dict.keys())


@ -7,10 +7,9 @@
import base64
from enum import Enum
from io import BytesIO
from typing import Any, Dict, List, Literal, Optional, Union
from typing import Annotated, Any, Literal
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
from typing_extensions import Annotated
# The goal is that these set of types are relevant for all Llama models.
# That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to
@ -31,21 +30,21 @@ class BuiltinTool(Enum):
code_interpreter = "code_interpreter"
Primitive = Union[str, int, float, bool, None]
RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
Primitive = str | int | float | bool | None
RecursiveType = Primitive | list[Primitive] | dict[str, Primitive]
class ToolCall(BaseModel):
call_id: str
tool_name: Union[BuiltinTool, str]
tool_name: BuiltinTool | str
# Plan is to deprecate the Dict in favor of a JSON string
# that is parsed on the client side instead of trying to manage
# the recursive type here.
# Making this a union so that client side can start prepping for this change.
# Eventually, we will remove both the Dict and arguments_json field,
# and arguments will just be a str
arguments: Union[str, Dict[str, RecursiveType]]
arguments_json: Optional[str] = None
arguments: str | dict[str, RecursiveType]
arguments_json: str | None = None
@field_validator("tool_name", mode="before")
@classmethod
@ -91,15 +90,15 @@ class StopReason(Enum):
class ToolParamDefinition(BaseModel):
param_type: str
description: Optional[str] = None
required: Optional[bool] = True
default: Optional[Any] = None
description: str | None = None
required: bool | None = True
default: Any | None = None
class ToolDefinition(BaseModel):
tool_name: Union[BuiltinTool, str]
description: Optional[str] = None
parameters: Optional[Dict[str, ToolParamDefinition]] = None
tool_name: BuiltinTool | str
description: str | None = None
parameters: dict[str, ToolParamDefinition] | None = None
@field_validator("tool_name", mode="before")
@classmethod
@ -119,7 +118,7 @@ class RawMediaItem(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
@field_serializer("data")
def serialize_data(self, data: Optional[bytes], _info):
def serialize_data(self, data: bytes | None, _info):
if data is None:
return None
return base64.b64encode(data).decode("utf-8")
@ -137,9 +136,9 @@ class RawTextItem(BaseModel):
text: str
RawContentItem = Annotated[Union[RawTextItem, RawMediaItem], Field(discriminator="type")]
RawContentItem = Annotated[RawTextItem | RawMediaItem, Field(discriminator="type")]
RawContent = str | RawContentItem | List[RawContentItem]
RawContent = str | RawContentItem | list[RawContentItem]
class RawMessage(BaseModel):
@ -147,17 +146,17 @@ class RawMessage(BaseModel):
content: RawContent
# This is for RAG but likely should be absorbed into content
context: Optional[RawContent] = None
context: RawContent | None = None
# These are for the output message coming from the assistant
stop_reason: Optional[StopReason] = None
tool_calls: List[ToolCall] = Field(default_factory=list)
stop_reason: StopReason | None = None
tool_calls: list[ToolCall] = Field(default_factory=list)
class GenerationResult(BaseModel):
token: int
text: str
logprobs: Optional[List[float]] = None
logprobs: list[float] | None = None
source: Literal["input"] | Literal["output"]


@ -6,7 +6,6 @@
from dataclasses import dataclass
from enum import Enum
from typing import Optional
class QuantizationScheme(Enum):
@ -15,8 +14,8 @@ class QuantizationScheme(Enum):
@dataclass
class QuantizationArgs:
scheme: Optional[QuantizationScheme] = None
group_size: Optional[int] = None
scheme: QuantizationScheme | None = None
group_size: int | None = None
spinquant: bool = False
def __init__(self, **kwargs):
@ -39,10 +38,10 @@ class ModelArgs:
dim: int = 4096
n_layers: int = 32
n_heads: int = 32
n_kv_heads: Optional[int] = None
n_kv_heads: int | None = None
vocab_size: int = -1
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
ffn_dim_multiplier: Optional[float] = None
ffn_dim_multiplier: float | None = None
norm_eps: float = 1e-5
rope_theta: float = 500000
use_scaled_rope: bool = False
@ -55,8 +54,8 @@ class ModelArgs:
vision_max_num_chunks: int = 4
vision_num_cross_attention_layers: int = -1
quantization_args: Optional[QuantizationArgs] = None
lora_args: Optional[LoRAArgs] = None
quantization_args: QuantizationArgs | None = None
lora_args: LoRAArgs | None = None
def __init__(self, **kwargs):
for k, v in kwargs.items():


@ -8,7 +8,6 @@ import io
import json
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from PIL import Image as PIL_Image
@ -29,14 +28,14 @@ from .tool_utils import ToolUtils
@dataclass
class VisionInput:
mask: List[List[int]]
images: List[PIL_Image.Image]
mask: list[list[int]]
images: list[PIL_Image.Image]
@dataclass
class LLMInput:
tokens: List[int]
vision: Optional[VisionInput] = None
tokens: list[int]
vision: VisionInput | None = None
def role_str(role: Role) -> str:
@ -50,7 +49,7 @@ def role_str(role: Role) -> str:
class ChatFormat:
possible_headers: Dict[Role, str]
possible_headers: dict[Role, str]
def __init__(self, tokenizer: Tokenizer):
self.tokenizer = tokenizer
@ -58,7 +57,7 @@ class ChatFormat:
self.possible_headers = {role: f"<|start_header_id|>{role_str(role)}<|end_header_id|>\n\n" for role in Role}
self.vision_token = self.tokenizer.special_tokens["<|image|>"]
def _encode_header(self, role: str) -> List[int]:
def _encode_header(self, role: str) -> list[int]:
tokens = []
tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
tokens.extend(self.tokenizer.encode("ipython" if role == "tool" else role, bos=False, eos=False))
@ -70,7 +69,7 @@ class ChatFormat:
tokens, images = self._encode_content(content, bos=True)
return self._model_input_from_tokens_images(tokens, images)
def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[PIL_Image.Image]]:
def _encode_content(self, content: RawContent, bos: bool = False) -> tuple[list[int], list[PIL_Image.Image]]:
tokens = []
images = []
@ -107,7 +106,7 @@ class ChatFormat:
def encode_message(
self, message: RawMessage, tool_prompt_format: ToolPromptFormat
) -> Tuple[List[int], List[PIL_Image.Image]]:
) -> tuple[list[int], list[PIL_Image.Image]]:
tokens = self._encode_header(message.role)
images = []
@ -145,8 +144,8 @@ class ChatFormat:
def encode_dialog_prompt(
self,
messages: List[RawMessage],
tool_prompt_format: Optional[ToolPromptFormat] = None,
messages: list[RawMessage],
tool_prompt_format: ToolPromptFormat | None = None,
) -> LLMInput:
tool_prompt_format = tool_prompt_format or ToolPromptFormat.json
tokens = []
@ -163,7 +162,7 @@ class ChatFormat:
return self._model_input_from_tokens_images(tokens, images)
# TODO(this should be generic, not only for assistant messages)
def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage:
def decode_assistant_message(self, tokens: list[int], stop_reason: StopReason) -> RawMessage:
content = self.tokenizer.decode(tokens)
return self.decode_assistant_message_from_content(content, stop_reason)
@ -234,7 +233,7 @@ class ChatFormat:
tool_calls=tool_calls,
)
def _model_input_from_tokens_images(self, tokens: List[int], images: List[PIL_Image.Image]) -> LLMInput:
def _model_input_from_tokens_images(self, tokens: list[int], images: list[PIL_Image.Image]) -> LLMInput:
vision_input = None
if len(images) > 0:
vision_input = VisionInput(
@ -249,9 +248,9 @@ class ChatFormat:
def create_vision_mask(
tokens: List[int],
tokens: list[int],
vision_token: int,
) -> List[List[int]]:
) -> list[list[int]]:
vision_token_locations = [i for i, token in enumerate(tokens) if token == vision_token]
if len(vision_token_locations) == 0:
return []


@ -15,8 +15,8 @@ import json
import os
import sys
import time
from collections.abc import Callable, Generator
from pathlib import Path
from typing import Callable, Generator, List, Optional
import torch
import torch.nn.functional as F
@ -41,8 +41,8 @@ class Llama3:
ckpt_dir: str,
max_seq_len: int,
max_batch_size: int,
world_size: Optional[int] = None,
quantization_mode: Optional[QuantizationMode] = None,
world_size: int | None = None,
quantization_mode: QuantizationMode | None = None,
seed: int = 1,
device: str = "cuda",
):
@ -82,7 +82,7 @@ class Llama3:
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
with open(Path(ckpt_dir) / "params.json", "r") as f:
with open(Path(ckpt_dir) / "params.json") as f:
params = json.loads(f.read())
model_args: ModelArgs = ModelArgs(
@ -154,15 +154,15 @@ class Llama3:
@torch.inference_mode()
def generate(
self,
llm_inputs: List[LLMInput],
llm_inputs: list[LLMInput],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
max_gen_len: int | None = None,
logprobs: bool = False,
echo: bool = False,
print_model_input: bool = False,
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> Generator[List[GenerationResult], None, None]:
logits_processor: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
) -> Generator[list[GenerationResult], None, None]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
params = self.model.params
@ -302,13 +302,13 @@ class Llama3:
def completion(
self,
contents: List[RawContent],
contents: list[RawContent],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
max_gen_len: int | None = None,
logprobs: bool = False,
echo: bool = False,
) -> Generator[List[GenerationResult], None, None]:
) -> Generator[list[GenerationResult], None, None]:
model_inputs = [self.formatter.encode_content(c) for c in contents]
for result in self.generate(
model_inputs=model_inputs,
@ -324,14 +324,14 @@ class Llama3:
def chat_completion(
self,
messages_batch: List[List[RawMessage]],
messages_batch: list[list[RawMessage]],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
max_gen_len: int | None = None,
logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
echo: bool = False,
) -> Generator[List[GenerationResult], None, None]:
) -> Generator[list[GenerationResult], None, None]:
model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
for result in self.generate(
model_inputs=model_inputs,


@ -12,7 +12,6 @@
# the top-level of this source tree.
from pathlib import Path
from typing import List, Optional
from termcolor import colored
@ -131,7 +130,7 @@ class LLama31Interface:
self.formatter = ChatFormat(self.tokenizer)
self.tool_prompt_format = tool_prompt_format
def get_tokens(self, messages: List[RawMessage]) -> List[int]:
def get_tokens(self, messages: list[RawMessage]) -> list[int]:
model_input = self.formatter.encode_dialog_prompt(
messages,
self.tool_prompt_format,
@ -149,10 +148,10 @@ class LLama31Interface:
def system_messages(
self,
builtin_tools: List[BuiltinTool],
custom_tools: List[ToolDefinition],
instruction: Optional[str] = None,
) -> List[RawMessage]:
builtin_tools: list[BuiltinTool],
custom_tools: list[ToolDefinition],
instruction: str | None = None,
) -> list[RawMessage]:
messages = []
default_gen = SystemDefaultGenerator()
@ -194,8 +193,8 @@ class LLama31Interface:
self,
content: str,
stop_reason: StopReason,
tool_call: Optional[ToolCall] = None,
) -> List[RawMessage]:
tool_call: ToolCall | None = None,
) -> list[RawMessage]:
tool_calls = []
if tool_call:
tool_calls.append(tool_call)
@ -208,7 +207,7 @@ class LLama31Interface:
)
]
def user_message(self, content: str) -> List[RawMessage]:
def user_message(self, content: str) -> list[RawMessage]:
return [RawMessage(role="user", content=content)]
def display_message_as_tokens(self, message: RawMessage) -> None:
@ -228,7 +227,7 @@ class LLama31Interface:
print("\n", end="")
def list_jinja_templates() -> List[Template]:
def list_jinja_templates() -> list[Template]:
return TEMPLATES


@ -5,7 +5,6 @@
# the root directory of this source tree.
import math
from typing import Optional, Tuple
import fairscale.nn.model_parallel.initialize as fs_init
import torch
@ -80,7 +79,7 @@ def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
@ -162,7 +161,7 @@ class Attention(nn.Module):
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
mask: torch.Tensor | None,
):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
@ -204,7 +203,7 @@ class FeedForward(nn.Module):
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
ffn_dim_multiplier: float | None,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
@ -243,7 +242,7 @@ class TransformerBlock(nn.Module):
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
mask: torch.Tensor | None,
):
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
out = h + self.feed_forward(self.ffn_norm(h))


@ -14,7 +14,7 @@
import math
from collections import defaultdict
from logging import getLogger
from typing import Any, Optional, Set, Tuple
from typing import Any
import torch
import torchvision.transforms as tv
@ -26,7 +26,7 @@ IMAGE_RES = 224
logger = getLogger()
class VariableSizeImageTransform(object):
class VariableSizeImageTransform:
"""
This class accepts images of any size and dynamically resize, pads and chunks it
based on the image aspect ratio and the number of image chunks we allow.
@ -75,7 +75,7 @@ class VariableSizeImageTransform(object):
self.resample = tv.InterpolationMode.BILINEAR
@staticmethod
def get_factors(n: int) -> Set[int]:
def get_factors(n: int) -> set[int]:
"""
Calculate all factors of a given number, i.e. a dividor that leaves
no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.
@ -145,9 +145,9 @@ class VariableSizeImageTransform(object):
@staticmethod
def get_max_res_without_distortion(
image_size: Tuple[int, int],
target_size: Tuple[int, int],
) -> Tuple[int, int]:
image_size: tuple[int, int],
target_size: tuple[int, int],
) -> tuple[int, int]:
"""
Determines the maximum resolution to which an image can be resized to without distorting its
aspect ratio, based on the target resolution.
@ -198,8 +198,8 @@ class VariableSizeImageTransform(object):
def resize_without_distortion(
self,
image: torch.Tensor,
target_size: Tuple[int, int],
max_upscaling_size: Optional[int],
target_size: tuple[int, int],
max_upscaling_size: int | None,
) -> torch.Tensor:
"""
Used to resize an image to target_resolution, without distortion.
@ -261,10 +261,10 @@ class VariableSizeImageTransform(object):
def get_best_fit(
self,
image_size: Tuple[int, int],
image_size: tuple[int, int],
possible_resolutions: torch.Tensor,
resize_to_max_canvas: bool = False,
) -> Tuple[int, int]:
) -> tuple[int, int]:
"""
Determines the best canvas possible from a list of possible resolutions to, without distortion,
resize an image to.
@ -364,7 +364,7 @@ class VariableSizeImageTransform(object):
max_num_chunks: int,
normalize_img: bool = True,
resize_to_max_canvas: bool = False,
) -> Tuple[Any, Any]:
) -> tuple[Any, Any]:
"""
Args:
image (PIL.Image): Image to be resized.


@ -6,8 +6,9 @@
import logging
import math
from collections.abc import Callable
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any
import fairscale.nn.model_parallel.initialize as fs_init
import torch
@ -104,9 +105,9 @@ class ColumnParallelConv2dPatch(torch.nn.Module):
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]],
bias: Optional[bool] = False,
kernel_size: int | tuple[int, int],
stride: int | tuple[int, int],
bias: bool | None = False,
) -> None:
super().__init__()
if isinstance(kernel_size, int):
@ -390,13 +391,13 @@ class VisionEncoder(nn.Module):
def load_hook(
self,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
local_metadata: dict[str, Any],
strict: bool = True,
missing_keys: List[str] = None,
unexpected_keys: List[str] = None,
error_msgs: List[str] = None,
missing_keys: list[str] = None,
unexpected_keys: list[str] = None,
error_msgs: list[str] = None,
return_state_dict: bool = False,
) -> None:
orig_pos_embed = state_dict.get(prefix + "positional_embedding")
@ -641,7 +642,7 @@ class FeedForward(nn.Module):
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
ffn_dim_multiplier: float | None,
):
"""
Initialize the FeedForward module.
@ -983,7 +984,7 @@ class CrossAttentionTransformerBlock(torch.nn.Module):
self,
x: torch.Tensor,
xattn_mask: torch.Tensor,
full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor],
full_text_row_masked_out_mask: tuple[torch.Tensor, torch.Tensor],
xattn_cache: torch.Tensor,
) -> torch.Tensor:
_attn_out = self.attention(
@ -1144,7 +1145,7 @@ class CrossAttentionTransformerText(torch.nn.Module):
def _init_fusion_schedule(
self,
num_layers: int,
) -> List[int]:
) -> list[int]:
llama_layers = list(range(self.n_llama_layers))
# uniformly spread the layers
@ -1231,7 +1232,7 @@ class CrossAttentionTransformerText(torch.nn.Module):
text_dtype,
vision_tokens,
cross_attention_masks,
) -> Tuple[Tensor, Tensor]:
) -> tuple[Tensor, Tensor]:
assert vision_tokens is not None, "Vision tokens must be provided"
vision_seqlen = vision_tokens.shape[3]
assert vision_tokens.shape[1] == cross_attention_masks.shape[2], (
@ -1280,11 +1281,11 @@ class CrossAttentionTransformer(torch.nn.Module):
def compute_vision_tokens_masks(
self,
batch_images: List[List[PIL_Image.Image]],
batch_masks: List[List[List[int]]],
batch_images: list[list[PIL_Image.Image]],
batch_masks: list[list[list[int]]],
total_len: int,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
skip_vision_encoder = False
assert len(batch_images) == len(batch_masks), "Images and masks must have the same length"
@ -1371,11 +1372,11 @@ class CrossAttentionTransformer(torch.nn.Module):
def _stack_images(
images: List[List[PIL_Image.Image]],
images: list[list[PIL_Image.Image]],
max_num_chunks: int,
image_res: int,
max_num_images: int,
) -> Tuple[torch.Tensor, List[int]]:
) -> tuple[torch.Tensor, list[int]]:
"""
Takes a list of list of images and stacks them into a tensor.
This function is needed since images can be of completely
@ -1400,8 +1401,8 @@ def _stack_images(
def _pad_masks(
all_masks: List[List[List[int]]],
all_num_chunks: List[List[int]],
all_masks: list[list[list[int]]],
all_num_chunks: list[list[int]],
total_len: int,
max_num_chunks: int,
) -> torch.Tensor:


@ -12,7 +12,7 @@
# the top-level of this source tree.
from dataclasses import dataclass
from typing import Any, Dict, List
from typing import Any
from jinja2 import Template
@ -20,7 +20,7 @@ from jinja2 import Template
@dataclass
class PromptTemplate:
template: str
data: Dict[str, Any]
data: dict[str, Any]
def render(self):
template = Template(self.template)
@ -35,5 +35,5 @@ class PromptTemplateGeneratorBase:
def gen(self, *args, **kwargs) -> PromptTemplate:
raise NotImplementedError()
def data_examples(self) -> List[Any]:
def data_examples(self) -> list[Any]:
raise NotImplementedError()


@ -13,7 +13,7 @@
import textwrap
from datetime import datetime
from typing import Any, List, Optional
from typing import Any
from llama_stack.apis.inference import (
BuiltinTool,
@ -39,12 +39,12 @@ class SystemDefaultGenerator(PromptTemplateGeneratorBase):
},
)
def data_examples(self) -> List[Any]:
def data_examples(self) -> list[Any]:
return [None]
class BuiltinToolGenerator(PromptTemplateGeneratorBase):
def _tool_breakdown(self, tools: List[ToolDefinition]):
def _tool_breakdown(self, tools: list[ToolDefinition]):
builtin_tools, custom_tools = [], []
for dfn in tools:
if isinstance(dfn.tool_name, BuiltinTool):
@ -54,7 +54,7 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase):
return builtin_tools, custom_tools
def gen(self, tools: List[ToolDefinition]) -> PromptTemplate:
def gen(self, tools: list[ToolDefinition]) -> PromptTemplate:
builtin_tools, custom_tools = self._tool_breakdown(tools)
template_str = textwrap.dedent(
"""
@ -75,7 +75,7 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase):
},
)
def data_examples(self) -> List[List[ToolDefinition]]:
def data_examples(self) -> list[list[ToolDefinition]]:
return [
# builtin tools
[
@ -91,7 +91,7 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase):
class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
def gen(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
template_str = textwrap.dedent(
"""
Answer the user's question by making use of the following functions if needed.
@ -137,7 +137,7 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
{"custom_tools": [t.model_dump() for t in custom_tools]},
)
def data_examples(self) -> List[List[ToolDefinition]]:
def data_examples(self) -> list[list[ToolDefinition]]:
return [
[
ToolDefinition(
@ -161,7 +161,7 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
def gen(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
template_str = textwrap.dedent(
"""
You have access to the following functions:
@ -199,7 +199,7 @@ class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
{"custom_tools": [t.model_dump() for t in custom_tools]},
)
def data_examples(self) -> List[List[ToolDefinition]]:
def data_examples(self) -> list[list[ToolDefinition]]:
return [
[
ToolDefinition(
@ -238,14 +238,14 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
""".strip("\n")
)
def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate:
def gen(self, custom_tools: list[ToolDefinition], system_prompt: str | None = None) -> PromptTemplate:
system_prompt = system_prompt or self.DEFAULT_PROMPT
return PromptTemplate(
system_prompt,
{"function_description": self._gen_function_description(custom_tools)},
)
def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
template_str = textwrap.dedent(
"""
Here is a list of functions in JSON format that you can invoke.
@ -291,7 +291,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
{"tools": [t.model_dump() for t in custom_tools]},
).render()
def data_examples(self) -> List[List[ToolDefinition]]:
def data_examples(self) -> list[list[ToolDefinition]]:
return [
[
ToolDefinition(


@ -12,7 +12,6 @@
# the top-level of this source tree.
import textwrap
from typing import Optional
from .base import PromptTemplate, PromptTemplateGeneratorBase
@ -21,8 +20,8 @@ class ToolResponseGenerator(PromptTemplateGeneratorBase):
def gen(
self,
status: str,
stdout: Optional[str] = None,
stderr: Optional[str] = None,
stdout: str | None = None,
stderr: str | None = None,
):
assert status in [
"success",


@ -6,7 +6,7 @@
# type: ignore
import os
from typing import Any, Dict, List, Optional, cast
from typing import Any, cast
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
@ -37,9 +37,9 @@ def swiglu_wrapper(
def convert_to_quantized_model(
model: Transformer | CrossAttentionTransformer,
checkpoint_dir: str,
quantization_mode: Optional[str] = None,
fp8_activation_scale_ub: Optional[float] = 1200.0,
device: Optional[torch.device] = None,
quantization_mode: str | None = None,
fp8_activation_scale_ub: float | None = 1200.0,
device: torch.device | None = None,
) -> Transformer | CrossAttentionTransformer:
if quantization_mode == QuantizationMode.fp8_mixed:
return convert_to_fp8_quantized_model(model, checkpoint_dir, fp8_activation_scale_ub, device)
@ -52,8 +52,8 @@ def convert_to_quantized_model(
def convert_to_fp8_quantized_model(
model: Transformer,
checkpoint_dir: str,
fp8_activation_scale_ub: Optional[float] = 1200.0,
device: Optional[torch.device] = None,
fp8_activation_scale_ub: float | None = 1200.0,
device: torch.device | None = None,
) -> Transformer:
# Move weights to GPU with quantization
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
@ -122,8 +122,8 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
precision: torch.dtype = torch.float32,
scales_precision: torch.dtype = torch.float32,
# LoRA parameters
lora_rank: Optional[int] = None,
lora_scale: Optional[float] = None,
lora_rank: int | None = None,
lora_scale: float | None = None,
) -> None:
super().__init__(
in_features,
@ -134,8 +134,8 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
precision=precision,
scales_precision=scales_precision,
)
self.lora_scale: Optional[float] = None
self.adaptor: Optional[nn.Sequential] = None
self.lora_scale: float | None = None
self.adaptor: nn.Sequential | None = None
if lora_rank is not None:
assert lora_scale is not None, "Please specify lora scale for LoRA."
# Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
@ -147,13 +147,13 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
def load_hook(
self,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
local_metadata: dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
missing_keys: list[str],
unexpected_keys: list[str],
error_msgs: list[str],
) -> None:
"""A hook to load the quantized weights from the state dict."""
if prefix + "zeros" not in state_dict:
@ -191,13 +191,13 @@ class Int8WeightEmbedding(torch.nn.Embedding):
def load_hook(
self,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
local_metadata: dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
missing_keys: list[str],
unexpected_keys: list[str],
error_msgs: list[str],
) -> None:
"""A hook to load the quantized embedding weight and scales from the state dict."""
weights = state_dict.pop(prefix + "weight")
@ -221,13 +221,13 @@ class Int8WeightLinear(torch.nn.Linear):
def load_hook(
self,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
local_metadata: dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
missing_keys: list[str],
unexpected_keys: list[str],
error_msgs: list[str],
) -> None:
"""A hook to load the quantized linear weight and scales from the state dict."""
weights = state_dict.pop(prefix + "weight")
@ -238,8 +238,8 @@ class Int8WeightLinear(torch.nn.Linear):
def _prepare_model_int4_weight_int8_dynamic_activation(
model: torch.nn.Module,
group_size: int,
lora_rank: Optional[int],
lora_scale: Optional[float],
lora_rank: int | None,
lora_scale: float | None,
):
"""Prepare the model for int4 weight and int8 dynamic activation quantization.
@ -265,7 +265,7 @@ def _prepare_model_int4_weight_int8_dynamic_activation(
)
del module
setattr(model, module_name, quantized_module)
elif isinstance(module, (ColumnParallelLinear, RowParallelLinear, nn.Linear)):
elif isinstance(module, ColumnParallelLinear | RowParallelLinear | nn.Linear):
quantized_module = Int8DynActInt4WeightLinearLoRA(
in_features=module.in_features,
out_features=module.out_features,
@ -286,7 +286,7 @@ def _prepare_model_int4_weight_int8_dynamic_activation(
def convert_to_int4_quantized_model(
model: Transformer | CrossAttentionTransformer,
checkpoint_dir: str,
device: Optional[torch.device] = None,
device: torch.device | None = None,
) -> Transformer | CrossAttentionTransformer:
"""Convert the model to int4 quantized model."""
model_args = model.params


@ -5,18 +5,11 @@
# the root directory of this source tree.
import os
from collections.abc import Collection, Iterator, Sequence, Set
from logging import getLogger
from pathlib import Path
from typing import (
AbstractSet,
Collection,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Union,
cast,
)
@ -44,7 +37,7 @@ class Tokenizer:
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
"""
special_tokens: Dict[str, int]
special_tokens: dict[str, int]
num_reserved_special_tokens = 256
@ -116,9 +109,9 @@ class Tokenizer:
*,
bos: bool,
eos: bool,
allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
disallowed_special: Union[Literal["all"], Collection[str]] = (),
) -> List[int]:
allowed_special: Literal["all"] | Set[str] | None = None,
disallowed_special: Literal["all"] | Collection[str] = (),
) -> list[int]:
"""
Encodes a string into a list of token IDs.
@ -151,7 +144,7 @@ class Tokenizer:
s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
)
)
t: List[int] = []
t: list[int] = []
for substr in substrs:
t.extend(
self.model.encode(
@ -177,7 +170,7 @@ class Tokenizer:
str: The decoded string.
"""
# Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
return self.model.decode(cast(List[int], t))
return self.model.decode(cast(list[int], t))
@staticmethod
def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]:


@ -6,7 +6,6 @@
import json
import re
from typing import Optional, Tuple
from llama_stack.log import get_logger
@ -172,7 +171,7 @@ class ToolUtils:
return match is not None
@staticmethod
def maybe_extract_builtin_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
def maybe_extract_builtin_tool_call(message_body: str) -> tuple[str, str] | None:
# Find the first match in the text
match = re.search(BUILTIN_TOOL_PATTERN, message_body)
@ -185,7 +184,7 @@ class ToolUtils:
return None
@staticmethod
def maybe_extract_custom_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
def maybe_extract_custom_tool_call(message_body: str) -> tuple[str, str] | None:
# NOTE: Custom function too calls are still experimental
# Sometimes, response is of the form
# {"type": "function", "name": "function_name", "parameters": {...}
@ -252,7 +251,7 @@ class ToolUtils:
def format_value(value: RecursiveType) -> str:
if isinstance(value, str):
return f'"{value}"'
elif isinstance(value, (int, float, bool)) or value is None:
elif isinstance(value, int | float | bool) or value is None:
return str(value)
elif isinstance(value, list):
return f"[{', '.join(format_value(v) for v in value)}]"


@ -12,7 +12,6 @@
# the top-level of this source tree.
import textwrap
from typing import List
from llama_stack.models.llama.datatypes import (
BuiltinTool,
@ -73,7 +72,7 @@ def wolfram_alpha_response():
)
def usecases() -> List[UseCase | str]:
def usecases() -> list[UseCase | str]:
return [
textwrap.dedent(
"""


@ -12,7 +12,6 @@
# the top-level of this source tree.
import textwrap
from typing import List
from llama_stack.models.llama.datatypes import (
BuiltinTool,
@ -74,7 +73,7 @@ def wolfram_alpha_response():
)
def usecases() -> List[UseCase | str]:
def usecases() -> list[UseCase | str]:
return [
textwrap.dedent(
"""


@ -5,7 +5,6 @@
# the root directory of this source tree.
from enum import Enum
from typing import Optional
from pydantic import BaseModel, model_validator
@ -15,8 +14,8 @@ class QuantizationScheme(Enum):
class QuantizationArgs(BaseModel):
scheme: Optional[QuantizationScheme] = None
group_size: Optional[int] = None
scheme: QuantizationScheme | None = None
group_size: int | None = None
spinquant: bool = False
@ -58,32 +57,32 @@ class ModelArgs(BaseModel):
dim: int = -1
n_layers: int = -1
n_heads: int = -1
n_kv_heads: Optional[int] = None
head_dim: Optional[int] = None
n_kv_heads: int | None = None
head_dim: int | None = None
vocab_size: int = -1
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
ffn_dim_multiplier: Optional[float] = None
ffn_exp: Optional[float] = None
ffn_dim_multiplier: float | None = None
ffn_exp: float | None = None
norm_eps: float = 1e-5
attention_chunk_size: Optional[int] = None
attention_chunk_size: int | None = None
rope_theta: float = 500000
use_scaled_rope: bool = False
rope_scaling_factor: Optional[float] = None
rope_high_freq_factor: Optional[float] = None
rope_scaling_factor: float | None = None
rope_high_freq_factor: float | None = None
nope_layer_interval: Optional[int] = None # No position encoding in every n layers
nope_layer_interval: int | None = None # No position encoding in every n layers
use_qk_norm: bool = False
# Set to True to enable inference-time temperature tuning (useful for very long context)
attn_temperature_tuning: bool = False
floor_scale: float = 8192.0
attn_scale: float = 0.1
vision_args: Optional[VisionArgs] = None
moe_args: Optional[MoEArgs] = None
quantization_args: Optional[QuantizationArgs] = None
lora_args: Optional[LoRAArgs] = None
vision_args: VisionArgs | None = None
moe_args: MoEArgs | None = None
quantization_args: QuantizationArgs | None = None
lora_args: LoRAArgs | None = None
max_batch_size: int = 32
max_seq_len: int = 2048


@ -8,7 +8,6 @@ import io
import json
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch
from PIL import Image as PIL_Image
@ -46,10 +45,10 @@ def role_str(role: Role) -> str:
class TransformedImage:
image_tiles: torch.Tensor
# is the aspect ratio needed anywhere?
aspect_ratio: Tuple[int, int]
aspect_ratio: tuple[int, int]
def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
def convert_image_to_rgb(image: PIL_Image.Image, bg: tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
if image.mode == "RGBA":
image.load() # for png.split()
new_img = PIL_Image.new("RGB", image.size, bg)
@ -59,12 +58,12 @@ def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255
class ChatFormat:
possible_headers: Dict[Role, str]
possible_headers: dict[Role, str]
def __init__(
self,
tokenizer: Tokenizer,
vision_args: Optional[VisionArgs] = None,
vision_args: VisionArgs | None = None,
max_num_chunks: int = 16,
):
self.tokenizer = tokenizer
@ -81,7 +80,7 @@ class ChatFormat:
vision_args.image_size.width, vision_args.image_size.height
)
def _encode_header(self, role: str) -> List[int]:
def _encode_header(self, role: str) -> list[int]:
tokens = []
tokens.append(self.tokenizer.special_tokens["<|header_start|>"])
@ -98,7 +97,7 @@ class ChatFormat:
def _encode_image(
self,
transformed_image: TransformedImage,
) -> List[int]:
) -> list[int]:
assert self.vision_args is not None, "The model is not vision-enabled"
image_tensor = transformed_image.image_tiles
@ -140,7 +139,7 @@ class ChatFormat:
return tokens
def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[TransformedImage]]:
def _encode_content(self, content: RawContent, bos: bool = False) -> tuple[list[int], list[TransformedImage]]:
tokens = []
tranformed_images = []
@ -189,7 +188,7 @@ class ChatFormat:
def encode_message(
self, message: RawMessage, tool_prompt_format: ToolPromptFormat
) -> Tuple[List[int], List[TransformedImage]]:
) -> tuple[list[int], list[TransformedImage]]:
tokens = self._encode_header(message.role)
images = []
@ -223,7 +222,7 @@ class ChatFormat:
def encode_dialog_prompt(
self,
messages: List[RawMessage],
messages: list[RawMessage],
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
) -> LLMInput:
tokens = []
@ -240,7 +239,7 @@ class ChatFormat:
return self._model_input_from_tokens_images(tokens, images)
# TODO(this should be generic, not only for assistant messages)
def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage:
def decode_assistant_message(self, tokens: list[int], stop_reason: StopReason) -> RawMessage:
content = self.tokenizer.decode(tokens)
return self.decode_assistant_message_from_content(content, stop_reason)
@ -312,7 +311,7 @@ class ChatFormat:
tool_calls=tool_calls,
)
def _model_input_from_tokens_images(self, tokens: List[int], images: List[TransformedImage]) -> LLMInput:
def _model_input_from_tokens_images(self, tokens: list[int], images: list[TransformedImage]) -> LLMInput:
return LLMInput(
tokens=tokens,
images=[x.image_tiles for x in images] if len(images) > 0 else None,


@ -5,7 +5,6 @@
# the root directory of this source tree.
from dataclasses import dataclass
from typing import List, Optional, Union
import torch
@ -30,7 +29,7 @@ class LLMInput:
tokens: torch.Tensor
# images are already pre-processed (resized, tiled, etc.)
images: Optional[List[torch.Tensor]] = None
images: list[torch.Tensor] | None = None
@dataclass
@ -45,8 +44,8 @@ class TransformerInput:
# tokens_position defines the position of the tokens in each batch,
# - when it is a tensor ([batch_size,]), it is the start position of the tokens in each batch
# - when it is an int, the start position are the same for all batches
tokens_position: Union[torch.Tensor, int]
image_embedding: Optional[MaskedEmbedding] = None
tokens_position: torch.Tensor | int
image_embedding: MaskedEmbedding | None = None
@dataclass


@ -11,7 +11,7 @@
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from typing import Any, Dict, List
from typing import Any
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
@ -36,13 +36,13 @@ class FeedForward(nn.Module):
def load_hook(
self,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
local_metadata: dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
missing_keys: list[str],
unexpected_keys: list[str],
error_msgs: list[str],
) -> None:
if prefix + "mlp.fc1_weight" in state_dict:
w1, w3 = state_dict.pop(prefix + "mlp.fc1_weight").chunk(2, dim=0)


@ -10,8 +10,8 @@ import json
import os
import sys
import time
from collections.abc import Callable, Generator
from pathlib import Path
from typing import Callable, Generator, List, Optional
import torch
import torch.nn.functional as F
@ -38,8 +38,8 @@ class Llama4:
ckpt_dir: str,
max_seq_len: int,
max_batch_size: int,
world_size: Optional[int] = None,
quantization_mode: Optional[QuantizationMode] = None,
world_size: int | None = None,
quantization_mode: QuantizationMode | None = None,
seed: int = 1,
):
if not torch.distributed.is_initialized():
@ -63,7 +63,7 @@ class Llama4:
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
with open(Path(ckpt_dir) / "params.json", "r") as f:
with open(Path(ckpt_dir) / "params.json") as f:
params = json.loads(f.read())
model_args: ModelArgs = ModelArgs(
@ -117,15 +117,15 @@ class Llama4:
@torch.inference_mode()
def generate(
self,
llm_inputs: List[LLMInput],
llm_inputs: list[LLMInput],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
max_gen_len: int | None = None,
logprobs: bool = False,
echo: bool = False,
print_model_input: bool = False,
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> Generator[List[GenerationResult], None, None]:
logits_processor: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
) -> Generator[list[GenerationResult], None, None]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.args.max_seq_len:
max_gen_len = self.model.args.max_seq_len - 1
@ -245,13 +245,13 @@ class Llama4:
def completion(
self,
contents: List[RawContent],
contents: list[RawContent],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
max_gen_len: int | None = None,
logprobs: bool = False,
echo: bool = False,
) -> Generator[List[GenerationResult], None, None]:
) -> Generator[list[GenerationResult], None, None]:
llm_inputs = [self.formatter.encode_content(c) for c in contents]
for result in self.generate(
llm_inputs=llm_inputs,
@ -267,13 +267,13 @@ class Llama4:
def chat_completion(
self,
messages_batch: List[List[RawMessage]],
messages_batch: list[list[RawMessage]],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
max_gen_len: int | None = None,
logprobs: bool = False,
echo: bool = False,
) -> Generator[List[GenerationResult], None, None]:
) -> Generator[list[GenerationResult], None, None]:
llm_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
for result in self.generate(
llm_inputs=llm_inputs,


@ -5,7 +5,7 @@
# the root directory of this source tree.
import math
from typing import Any, Dict, List, Optional, Tuple
from typing import Any
import fairscale.nn.model_parallel.initialize as fs_init
import torch
@ -89,7 +89,7 @@ def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
@ -174,13 +174,13 @@ class Attention(nn.Module):
def load_hook(
self,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
local_metadata: dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
missing_keys: list[str],
unexpected_keys: list[str],
error_msgs: list[str],
) -> None:
if prefix + "wqkv.weight" in state_dict:
wqkv = state_dict.pop(prefix + "wqkv.weight")
@ -200,7 +200,7 @@ class Attention(nn.Module):
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor] = None,
mask: torch.Tensor | None = None,
):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
@ -288,13 +288,13 @@ class TransformerBlock(nn.Module):
def load_hook(
self,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
local_metadata: dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
missing_keys: list[str],
unexpected_keys: list[str],
error_msgs: list[str],
) -> None:
if prefix + "attention.wqkv.layer_norm_weight" in state_dict:
state_dict[prefix + "attention_norm.weight"] = state_dict.pop(prefix + "attention.wqkv.layer_norm_weight")
@ -318,8 +318,8 @@ class TransformerBlock(nn.Module):
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
global_attn_mask: Optional[torch.Tensor],
local_attn_mask: Optional[torch.Tensor],
global_attn_mask: torch.Tensor | None,
local_attn_mask: torch.Tensor | None,
):
# The iRoPE architecture uses global attention mask for NoPE layers or
# if chunked local attention is not used
@ -374,13 +374,13 @@ class Transformer(nn.Module):
def load_hook(
self,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
local_metadata: dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
missing_keys: list[str],
unexpected_keys: list[str],
error_msgs: list[str],
) -> None:
if prefix + "rope.freqs" in state_dict:
state_dict.pop(prefix + "rope.freqs")


@ -6,7 +6,7 @@
# ruff: noqa: N806
# pyre-strict
from typing import Any, Dict, List
from typing import Any
import fairscale.nn.model_parallel.initialize as fs_init
import torch
@ -63,13 +63,13 @@ class Experts(nn.Module):
def load_hook(
self,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
local_metadata: dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
missing_keys: list[str],
unexpected_keys: list[str],
error_msgs: list[str],
) -> None:
self.prefix = prefix
if prefix + "moe_w_in_eD_F" in state_dict:
@ -158,13 +158,13 @@ class MoE(torch.nn.Module):
def load_hook(
self,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
local_metadata: dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
missing_keys: list[str],
unexpected_keys: list[str],
error_msgs: list[str],
) -> None:
if prefix + "w_in_shared_FD.weight" in state_dict:
state_dict[prefix + "shared_expert.w1.weight"] = state_dict.pop(prefix + "w_in_shared_FD.weight")
@ -210,5 +210,5 @@ class MoE(torch.nn.Module):
def divide_exact(numerator: int, denominator: int) -> int:
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}"
return numerator // denominator


@ -13,7 +13,6 @@
import math
from collections import defaultdict
from typing import Optional, Set, Tuple
import torch
import torchvision.transforms as tv
@ -52,7 +51,7 @@ class ResizeNormalizeImageTransform:
return self.tv_transform(image)
class VariableSizeImageTransform(object):
class VariableSizeImageTransform:
"""
This class accepts images of any size and dynamically resize, pads and chunks it
based on the image aspect ratio and the number of image chunks we allow.
@ -100,7 +99,7 @@ class VariableSizeImageTransform(object):
self.resample = tv.InterpolationMode.BILINEAR
@staticmethod
def get_factors(n: int) -> Set[int]:
def get_factors(n: int) -> set[int]:
"""
Calculate all factors of a given number, i.e. a dividor that leaves
no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.
@ -170,9 +169,9 @@ class VariableSizeImageTransform(object):
@staticmethod
def get_max_res_without_distortion(
image_size: Tuple[int, int],
target_size: Tuple[int, int],
) -> Tuple[int, int]:
image_size: tuple[int, int],
target_size: tuple[int, int],
) -> tuple[int, int]:
"""
Determines the maximum resolution to which an image can be resized to without distorting its
aspect ratio, based on the target resolution.
@ -223,8 +222,8 @@ class VariableSizeImageTransform(object):
def resize_without_distortion(
self,
image: torch.Tensor,
target_size: Tuple[int, int],
max_upscaling_size: Optional[int],
target_size: tuple[int, int],
max_upscaling_size: int | None,
) -> torch.Tensor:
"""
Used to resize an image to target_resolution, without distortion.
@ -289,10 +288,10 @@ class VariableSizeImageTransform(object):
def get_best_fit(
self,
image_size: Tuple[int, int],
image_size: tuple[int, int],
possible_resolutions: torch.Tensor,
resize_to_max_canvas: bool = False,
) -> Tuple[int, int]:
) -> tuple[int, int]:
"""
Determines the best canvas possible from a list of possible resolutions to, without distortion,
resize an image to.
@ -392,7 +391,7 @@ class VariableSizeImageTransform(object):
max_num_chunks: int,
normalize_img: bool = True,
resize_to_max_canvas: bool = False,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
) -> tuple[torch.Tensor, tuple[int, int]]:
"""
Args:
image (PIL.Image): Image to be resized.


@ -12,7 +12,6 @@
# the top-level of this source tree.
import textwrap
from typing import List, Optional
from llama_stack.apis.inference import ToolDefinition, ToolParamDefinition
from llama_stack.models.llama.llama3.prompt_templates.base import (
@ -67,14 +66,14 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
""".strip("\n")
)
def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate:
def gen(self, custom_tools: list[ToolDefinition], system_prompt: str | None = None) -> PromptTemplate:
system_prompt = system_prompt or self.DEFAULT_PROMPT
return PromptTemplate(
system_prompt,
{"function_description": self._gen_function_description(custom_tools)},
)
def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
template_str = textwrap.dedent(
"""
Here is a list of functions in JSON format that you can invoke.
@ -120,7 +119,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
{"tools": [t.model_dump() for t in custom_tools]},
).render()
def data_examples(self) -> List[List[ToolDefinition]]:
def data_examples(self) -> list[list[ToolDefinition]]:
return [
[
ToolDefinition(


@ -7,7 +7,6 @@
import textwrap
from io import BytesIO
from pathlib import Path
from typing import List
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
PythonListCustomToolGenerator,
@ -23,7 +22,7 @@ from ..prompt_format import (
THIS_DIR = Path(__file__).parent
def usecases(base_model: bool = False) -> List[UseCase | str]:
def usecases(base_model: bool = False) -> list[UseCase | str]:
with open(THIS_DIR.parent / "resources/small_dog.jpg", "rb") as f:
img_small_dog = f.read()
with open(THIS_DIR.parent / "resources/dog.jpg", "rb") as f:


@ -6,7 +6,7 @@
import logging
import os
from typing import Callable, Optional
from collections.abc import Callable
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
@ -45,8 +45,8 @@ def experts_batched_swiglu_wrapper(
def convert_to_quantized_model(
model: Transformer,
checkpoint_dir: str,
quantization_mode: Optional[str] = None,
fp8_activation_scale_ub: Optional[float] = 1200.0,
quantization_mode: str | None = None,
fp8_activation_scale_ub: float | None = 1200.0,
use_rich_progress: bool = True,
) -> Transformer:
from ...quantize_impls import (
@ -213,7 +213,7 @@ def logging_callbacks(
)
task_id = progress.add_task("[blue]Converting layers...", total=total_blocks, status="Starting")
def update_status(message: Optional[str], completed: Optional[int] = None) -> None:
def update_status(message: str | None, completed: int | None = None) -> None:
if use_rich_progress:
if message is not None:
progress.update(task_id, status=message)


@ -5,18 +5,11 @@
# the root directory of this source tree.
import os
from collections.abc import Collection, Iterator, Sequence, Set
from logging import getLogger
from pathlib import Path
from typing import (
AbstractSet,
Collection,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Union,
cast,
)
@ -114,7 +107,7 @@ class Tokenizer:
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
"""
special_tokens: Dict[str, int]
special_tokens: dict[str, int]
num_reserved_special_tokens = 2048
@ -182,9 +175,9 @@ class Tokenizer:
*,
bos: bool,
eos: bool,
allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
disallowed_special: Union[Literal["all"], Collection[str]] = (),
) -> List[int]:
allowed_special: Literal["all"] | Set[str] | None = None,
disallowed_special: Literal["all"] | Collection[str] = (),
) -> list[int]:
"""
Encodes a string into a list of token IDs.
@ -217,7 +210,7 @@ class Tokenizer:
s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
)
)
t: List[int] = []
t: list[int] = []
for substr in substrs:
t.extend(
self.model.encode(
@ -243,7 +236,7 @@ class Tokenizer:
str: The decoded string.
"""
# Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
return self.model.decode(cast(List[int], t))
return self.model.decode(cast(list[int], t))
@staticmethod
def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]:


@ -5,7 +5,8 @@
# the root directory of this source tree.
import math
from typing import Any, Callable, Dict, List
from collections.abc import Callable
from typing import Any
import torch
import torch.nn as nn
@ -136,13 +137,13 @@ class VisionEmbeddings(torch.nn.Module):
def load_hook(
self,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
local_metadata: dict[str, Any],
strict: bool = True,
missing_keys: List[str] = None,
unexpected_keys: List[str] = None,
error_msgs: List[str] = None,
missing_keys: list[str] = None,
unexpected_keys: list[str] = None,
error_msgs: list[str] = None,
return_state_dict: bool = False,
) -> None:
original_sd = self.state_dict()
@ -163,7 +164,7 @@ class VisionEmbeddings(torch.nn.Module):
# each image is a tensor of shape [num_tiles, C, H, W]
def forward(
self,
image_batch: List[List[torch.Tensor]],
image_batch: list[list[torch.Tensor]],
image_mask: torch.Tensor,
h_ref: torch.Tensor,
) -> torch.Tensor:
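
A minimal sketch of a load-hook-style signature with the new builtin generics (toy code, not the `VisionEmbeddings` implementation; note that spelling the `None` defaults as `list[str] | None` is a manual refinement here, pyupgrade itself only rewrites the container names):

```python
# Toy load-hook signature using dict[...] / list[...] annotations.
from typing import Any


def load_hook(
    state_dict: dict[str, Any],
    prefix: str,
    local_metadata: dict[str, Any],
    strict: bool = True,
    missing_keys: list[str] | None = None,  # `| None` added here; the PR keeps `list[str] = None`
    unexpected_keys: list[str] | None = None,
    error_msgs: list[str] | None = None,
) -> None:
    if missing_keys is not None and prefix + "weight" not in state_dict:
        missing_keys.append(prefix + "weight")


missing: list[str] = []
load_hook({"bias": 0.0}, "layer.", {}, missing_keys=missing)
print(missing)  # ['layer.weight']
```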

View file

@ -4,7 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from collections.abc import Callable
from typing import Any
import fairscale.nn.model_parallel.initialize as fs_init
import torch
@ -42,9 +43,9 @@ class ColumnParallelConv2dPatch(torch.nn.Module):
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]],
bias: Optional[bool] = False,
kernel_size: int | tuple[int, int],
stride: int | tuple[int, int],
bias: bool | None = False,
) -> None:
super().__init__()
if isinstance(kernel_size, int):
@ -134,15 +135,15 @@ class _TransformerBlock(nn.Module):
def attention(
self,
x: torch.Tensor,
freq_cis: Optional[torch.Tensor] = None,
freq_cis: torch.Tensor | None = None,
):
return self.attn(x=x, start_pos=0, freqs_cis=freq_cis)
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
freq_cis: Optional[torch.Tensor] = None,
mask: torch.Tensor | None = None,
freq_cis: torch.Tensor | None = None,
):
_gate_attn = 1 if not self.gated else self.gate_attn.tanh()
_gate_ffn = 1 if not self.gated else self.gate_ffn.tanh()
@ -210,8 +211,8 @@ class PackingIndex:
class VisionEncoder(nn.Module):
def __init__(
self,
image_size: Tuple[int, int],
patch_size: Tuple[int, int],
image_size: tuple[int, int],
patch_size: tuple[int, int],
dim: int,
layers: int,
heads: int,
@ -299,13 +300,13 @@ class VisionEncoder(nn.Module):
def load_hook(
self,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
local_metadata: dict[str, Any],
strict: bool = True,
missing_keys: List[str] = None,
unexpected_keys: List[str] = None,
error_msgs: List[str] = None,
missing_keys: list[str] = None,
unexpected_keys: list[str] = None,
error_msgs: list[str] = None,
return_state_dict: bool = False,
) -> None:
orig_pos_embed = state_dict.get(prefix + "positional_embedding")
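
The `Union[int, Tuple[int, int]]` → `int | tuple[int, int]` pattern used for `kernel_size` and `stride` still normalizes exactly as before; a short sketch (Python 3.10+; `normalize_2d` is a hypothetical helper):

```python
# isinstance() behaves identically after the Union -> | rewrite.
def normalize_2d(value: int | tuple[int, int]) -> tuple[int, int]:
    if isinstance(value, int):
        return (value, value)
    return value


print(normalize_2d(3))       # (3, 3)
print(normalize_2d((3, 5)))  # (3, 5)
```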

View file

@ -14,7 +14,6 @@
import json
import textwrap
from pathlib import Path
from typing import List
from pydantic import BaseModel, Field
@ -44,7 +43,7 @@ class TextCompletionContent(BaseModel):
class UseCase(BaseModel):
title: str = ""
description: str = ""
dialogs: List[List[RawMessage] | TextCompletionContent | str] = Field(default_factory=list)
dialogs: list[list[RawMessage] | TextCompletionContent | str] = Field(default_factory=list)
notes: str = ""
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json
max_gen_len: int = 512
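
Pydantic accepts the rewritten nested union field unchanged; a hedged sketch under the assumption of pydantic v2 (`Note` is a toy model, not `UseCase`):

```python
# Toy pydantic model using builtin generics inside a union field.
from pydantic import BaseModel, Field


class Note(BaseModel):
    title: str = ""
    dialogs: list[list[str] | str] = Field(default_factory=list)


n = Note(title="demo", dialogs=[["hi", "there"], "single entry"])
print(n.model_dump())
```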

View file

@ -7,7 +7,6 @@
# type: ignore
import collections
import logging
from typing import Optional, Tuple, Type, Union
log = logging.getLogger(__name__)
@ -27,7 +26,7 @@ class Fp8ScaledWeights:
# TODO: Ugly trick so torch allows us to replace parameters
# with our custom Fp8Weights instance. Do this properly.
@property
def __class__(self) -> Type[nn.parameter.Parameter]:
def __class__(self) -> type[nn.parameter.Parameter]:
return nn.Parameter
@property
@ -51,7 +50,7 @@ class Int4ScaledWeights:
# TODO: Ugly trick so torch allows us to replace parameters
# with our custom Int4Weights instance. Do this properly.
@property
def __class__(self) -> Type[nn.parameter.Parameter]:
def __class__(self) -> type[nn.parameter.Parameter]:
return nn.Parameter
@property
@ -74,7 +73,7 @@ class Int4Weights(
def int4_row_quantize(
x: torch.Tensor,
group_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
n_bit = 4 # Number of target bits.
to_quant = x.reshape(-1, group_size).to(torch.float)
@ -115,8 +114,8 @@ def pack_int4(x: torch.Tensor) -> torch.Tensor:
def bmm_nt(
x: Tensor,
w: Union[Fp8RowwiseWeights, Int4Weights],
num_tokens: Optional[Tensor] = None,
w: Fp8RowwiseWeights | Int4Weights,
num_tokens: Tensor | None = None,
) -> Tensor:
if isinstance(w, Fp8ScaledWeights):
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, w.activation_scale_ub)
@ -129,10 +128,10 @@ def bmm_nt(
def ffn_swiglu(
x: Tensor,
w1: Union[Fp8RowwiseWeights, Int4Weights],
w3: Union[Fp8RowwiseWeights, Int4Weights],
w2: Union[Fp8RowwiseWeights, Int4Weights],
num_tokens: Optional[Tensor] = None,
w1: Fp8RowwiseWeights | Int4Weights,
w3: Fp8RowwiseWeights | Int4Weights,
w2: Fp8RowwiseWeights | Int4Weights,
num_tokens: Tensor | None = None,
is_memory_bounded: bool = False,
) -> Tensor:
if (isinstance(w1, Fp8ScaledWeights) and isinstance(w3, Fp8ScaledWeights) and isinstance(w2, Fp8ScaledWeights)) or (
@ -158,7 +157,7 @@ def ffn_swiglu(
def quantize_fp8(
w: Tensor,
fp8_activation_scale_ub: float,
output_device: Optional[torch.device] = None,
output_device: torch.device | None = None,
) -> Fp8RowwiseWeights:
"""Quantize [n, k] weight tensor.
@ -184,7 +183,7 @@ def quantize_fp8(
@torch.inference_mode()
def quantize_int4(
w: Tensor,
output_device: Optional[torch.device] = None,
output_device: torch.device | None = None,
) -> Int4Weights:
"""Quantize [n, k/2] weight tensor.
@ -213,7 +212,7 @@ def load_fp8(
w: Tensor,
w_scale: Tensor,
fp8_activation_scale_ub: float,
output_device: Optional[torch.device] = None,
output_device: torch.device | None = None,
) -> Fp8RowwiseWeights:
"""Load FP8 [n, k] weight tensor.
@ -239,7 +238,7 @@ def load_int4(
w: Tensor,
scale: Tensor,
zero_point: Tensor,
output_device: Optional[torch.device] = None,
output_device: torch.device | None = None,
) -> Int4Weights:
"""Load INT4 [n, k/2] weight tensor.
@ -256,9 +255,9 @@ def load_int4(
def fc_dynamic(
x: Tensor,
w: Union[Fp8RowwiseWeights, Int4Weights],
activation_scale_ub: Optional[Tensor] = None,
num_tokens: Optional[Tensor] = None,
w: Fp8RowwiseWeights | Int4Weights,
activation_scale_ub: Tensor | None = None,
num_tokens: Tensor | None = None,
is_memory_bounded: bool = False,
) -> Tensor:
"""
@ -275,11 +274,11 @@ def fc_dynamic(
def ffn_swiglu_dynamic(
x: Tensor,
w1: Union[Fp8RowwiseWeights, Int4Weights],
w3: Union[Fp8RowwiseWeights, Int4Weights],
w2: Union[Fp8RowwiseWeights, Int4Weights],
activation_scale_ub: Optional[Tensor] = None,
num_tokens: Optional[Tensor] = None,
w1: Fp8RowwiseWeights | Int4Weights,
w3: Fp8RowwiseWeights | Int4Weights,
w2: Fp8RowwiseWeights | Int4Weights,
activation_scale_ub: Tensor | None = None,
num_tokens: Tensor | None = None,
is_memory_bounded: bool = False,
) -> Tensor:
assert x.dim() == 3 or x.dim() == 2
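
A sketch of the `Type[...]` → `type[...]` and `Union[A, B]` → `A | B` rewrites in this file (Python 3.10+; `Fp8Weights` and `Int4Weights` below are toy classes, not the real quantized-weight types):

```python
# type[...] and X | Y in signatures; isinstance() accepts unions on 3.10+.
class Fp8Weights:
    pass


class Int4Weights:
    pass


def describe(w: Fp8Weights | Int4Weights) -> type[Fp8Weights] | type[Int4Weights]:
    # Since Python 3.10, isinstance() also takes X | Y unions directly.
    assert isinstance(w, Fp8Weights | Int4Weights)
    return type(w)


print(describe(Int4Weights()))  # <class '__main__.Int4Weights'>
```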

View file

@ -6,7 +6,6 @@
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional
from .sku_types import (
CheckpointQuantizationFormat,
@ -19,14 +18,14 @@ LLAMA2_VOCAB_SIZE = 32000
LLAMA3_VOCAB_SIZE = 128256
def resolve_model(descriptor: str) -> Optional[Model]:
def resolve_model(descriptor: str) -> Model | None:
for m in all_registered_models():
if descriptor in (m.descriptor(), m.huggingface_repo):
return m
return None
def all_registered_models() -> List[Model]:
def all_registered_models() -> list[Model]:
return (
llama2_family()
+ llama3_family()
@ -38,48 +37,48 @@ def all_registered_models() -> List[Model]:
)
def llama2_family() -> List[Model]:
def llama2_family() -> list[Model]:
return [
*llama2_base_models(),
*llama2_instruct_models(),
]
def llama3_family() -> List[Model]:
def llama3_family() -> list[Model]:
return [
*llama3_base_models(),
*llama3_instruct_models(),
]
def llama3_1_family() -> List[Model]:
def llama3_1_family() -> list[Model]:
return [
*llama3_1_base_models(),
*llama3_1_instruct_models(),
]
def llama3_2_family() -> List[Model]:
def llama3_2_family() -> list[Model]:
return [
*llama3_2_base_models(),
*llama3_2_instruct_models(),
]
def llama3_3_family() -> List[Model]:
def llama3_3_family() -> list[Model]:
return [
*llama3_3_instruct_models(),
]
def llama4_family() -> List[Model]:
def llama4_family() -> list[Model]:
return [
*llama4_base_models(),
*llama4_instruct_models(),
]
def llama4_base_models() -> List[Model]:
def llama4_base_models() -> list[Model]:
return [
Model(
core_model_id=CoreModelId.llama4_scout_17b_16e,
@ -98,7 +97,7 @@ def llama4_base_models() -> List[Model]:
]
def llama4_instruct_models() -> List[Model]:
def llama4_instruct_models() -> list[Model]:
return [
Model(
core_model_id=CoreModelId.llama4_scout_17b_16e_instruct,
@ -126,7 +125,7 @@ def llama4_instruct_models() -> List[Model]:
]
def llama2_base_models() -> List[Model]:
def llama2_base_models() -> list[Model]:
return [
Model(
core_model_id=CoreModelId.llama2_7b,
@ -185,7 +184,7 @@ def llama2_base_models() -> List[Model]:
]
def llama3_base_models() -> List[Model]:
def llama3_base_models() -> list[Model]:
return [
Model(
core_model_id=CoreModelId.llama3_8b,
@ -226,7 +225,7 @@ def llama3_base_models() -> List[Model]:
]
def llama3_1_base_models() -> List[Model]:
def llama3_1_base_models() -> list[Model]:
return [
Model(
core_model_id=CoreModelId.llama3_1_8b,
@ -324,7 +323,7 @@ def llama3_1_base_models() -> List[Model]:
]
def llama3_2_base_models() -> List[Model]:
def llama3_2_base_models() -> list[Model]:
return [
Model(
core_model_id=CoreModelId.llama3_2_1b,
@ -407,7 +406,7 @@ def llama3_2_base_models() -> List[Model]:
]
def llama2_instruct_models() -> List[Model]:
def llama2_instruct_models() -> list[Model]:
return [
Model(
core_model_id=CoreModelId.llama2_7b_chat,
@ -466,7 +465,7 @@ def llama2_instruct_models() -> List[Model]:
]
def llama3_instruct_models() -> List[Model]:
def llama3_instruct_models() -> list[Model]:
return [
Model(
core_model_id=CoreModelId.llama3_8b_instruct,
@ -507,7 +506,7 @@ def llama3_instruct_models() -> List[Model]:
]
def llama3_1_instruct_models() -> List[Model]:
def llama3_1_instruct_models() -> list[Model]:
return [
Model(
core_model_id=CoreModelId.llama3_1_8b_instruct,
@ -635,7 +634,7 @@ def arch_args_3b() -> dict:
}
def llama3_2_quantized_models() -> List[Model]:
def llama3_2_quantized_models() -> list[Model]:
return [
Model(
core_model_id=CoreModelId.llama3_2_1b_instruct,
@ -704,7 +703,7 @@ def llama3_2_quantized_models() -> List[Model]:
]
def llama3_2_instruct_models() -> List[Model]:
def llama3_2_instruct_models() -> list[Model]:
return [
Model(
core_model_id=CoreModelId.llama3_2_1b_instruct,
@ -766,7 +765,7 @@ def llama3_2_instruct_models() -> List[Model]:
]
def llama3_3_instruct_models() -> List[Model]:
def llama3_3_instruct_models() -> list[Model]:
return [
Model(
core_model_id=CoreModelId.llama3_3_70b_instruct,
@ -790,7 +789,7 @@ def llama3_3_instruct_models() -> List[Model]:
@lru_cache
def safety_models() -> List[Model]:
def safety_models() -> list[Model]:
return [
Model(
core_model_id=CoreModelId.llama_guard_4_12b,
@ -919,7 +918,7 @@ def safety_models() -> List[Model]:
@dataclass
class LlamaDownloadInfo:
folder: str
files: List[str]
files: list[str]
pth_size: int
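
The `List[Model]` → `list[Model]` changes in this file are annotation-only; the resulting annotations are ordinary runtime objects, as this small sketch shows (`families` is a hypothetical function, not part of the SKU list):

```python
# Builtin-generic annotations are plain types.GenericAlias objects.
def families() -> list[str]:
    return ["llama2", "llama3", "llama3_1", "llama3_2", "llama3_3", "llama4"]


print(families.__annotations__["return"])                     # list[str]
print(families.__annotations__["return"].__origin__ is list)  # True
```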

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
@ -159,13 +159,13 @@ def model_family(model_id) -> ModelFamily:
class Model(BaseModel):
core_model_id: CoreModelId
description: str
huggingface_repo: Optional[str] = None
arch_args: Dict[str, Any]
huggingface_repo: str | None = None
arch_args: dict[str, Any]
variant: str = ""
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
pth_file_count: int
metadata: Dict[str, Any] = Field(default_factory=dict)
metadata: dict[str, Any] = Field(default_factory=dict)
# silence pydantic until we remove the `model_` fields
model_config = ConfigDict(protected_namespaces=())
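
Finally, a hedged sketch of a pydantic model combining `str | None` and `dict[str, Any]` fields in the style of the class above (pydantic v2 assumed; `Entry` is a toy model, not the real `Model`):

```python
# Toy pydantic model mixing X | None and dict[str, Any] fields.
from typing import Any

from pydantic import BaseModel, ConfigDict, Field


class Entry(BaseModel):
    description: str
    huggingface_repo: str | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)

    # Mirrors the config above: allows field names starting with "model_"
    # without triggering pydantic's protected-namespace warning.
    model_config = ConfigDict(protected_namespaces=())


print(Entry(description="demo").metadata)  # {}
```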