chore: enable pyupgrade fixes (#1806)

# What does this PR do?

The goal of this PR is code base modernization.

Schema reflection code needed a minor adjustment to handle UnionTypes
and collections.abc.AsyncIterator. (Both are preferred for latest Python
releases.)

Note to reviewers: almost all changes here are automatically generated
by pyupgrade. Some additional unused imports were cleaned up. The only
change worth of note can be found under `docs/openapi_generator` and
`llama_stack/strong_typing/schema.py` where reflection code was updated
to deal with "newer" types.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-05-01 17:23:50 -04:00 committed by GitHub
parent ffe3d0b2cd
commit 9e6561a1ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
319 changed files with 2843 additions and 3033 deletions

View file

@ -6,7 +6,6 @@
from dataclasses import dataclass
from enum import Enum
from typing import Optional
class QuantizationScheme(Enum):
@ -15,8 +14,8 @@ class QuantizationScheme(Enum):
@dataclass
class QuantizationArgs:
scheme: Optional[QuantizationScheme] = None
group_size: Optional[int] = None
scheme: QuantizationScheme | None = None
group_size: int | None = None
spinquant: bool = False
def __init__(self, **kwargs):
@ -39,10 +38,10 @@ class ModelArgs:
dim: int = 4096
n_layers: int = 32
n_heads: int = 32
n_kv_heads: Optional[int] = None
n_kv_heads: int | None = None
vocab_size: int = -1
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
ffn_dim_multiplier: Optional[float] = None
ffn_dim_multiplier: float | None = None
norm_eps: float = 1e-5
rope_theta: float = 500000
use_scaled_rope: bool = False
@ -55,8 +54,8 @@ class ModelArgs:
vision_max_num_chunks: int = 4
vision_num_cross_attention_layers: int = -1
quantization_args: Optional[QuantizationArgs] = None
lora_args: Optional[LoRAArgs] = None
quantization_args: QuantizationArgs | None = None
lora_args: LoRAArgs | None = None
def __init__(self, **kwargs):
for k, v in kwargs.items():

View file

@ -8,7 +8,6 @@ import io
import json
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from PIL import Image as PIL_Image
@ -29,14 +28,14 @@ from .tool_utils import ToolUtils
@dataclass
class VisionInput:
mask: List[List[int]]
images: List[PIL_Image.Image]
mask: list[list[int]]
images: list[PIL_Image.Image]
@dataclass
class LLMInput:
tokens: List[int]
vision: Optional[VisionInput] = None
tokens: list[int]
vision: VisionInput | None = None
def role_str(role: Role) -> str:
@ -50,7 +49,7 @@ def role_str(role: Role) -> str:
class ChatFormat:
possible_headers: Dict[Role, str]
possible_headers: dict[Role, str]
def __init__(self, tokenizer: Tokenizer):
self.tokenizer = tokenizer
@ -58,7 +57,7 @@ class ChatFormat:
self.possible_headers = {role: f"<|start_header_id|>{role_str(role)}<|end_header_id|>\n\n" for role in Role}
self.vision_token = self.tokenizer.special_tokens["<|image|>"]
def _encode_header(self, role: str) -> List[int]:
def _encode_header(self, role: str) -> list[int]:
tokens = []
tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
tokens.extend(self.tokenizer.encode("ipython" if role == "tool" else role, bos=False, eos=False))
@ -70,7 +69,7 @@ class ChatFormat:
tokens, images = self._encode_content(content, bos=True)
return self._model_input_from_tokens_images(tokens, images)
def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[PIL_Image.Image]]:
def _encode_content(self, content: RawContent, bos: bool = False) -> tuple[list[int], list[PIL_Image.Image]]:
tokens = []
images = []
@ -107,7 +106,7 @@ class ChatFormat:
def encode_message(
self, message: RawMessage, tool_prompt_format: ToolPromptFormat
) -> Tuple[List[int], List[PIL_Image.Image]]:
) -> tuple[list[int], list[PIL_Image.Image]]:
tokens = self._encode_header(message.role)
images = []
@ -145,8 +144,8 @@ class ChatFormat:
def encode_dialog_prompt(
self,
messages: List[RawMessage],
tool_prompt_format: Optional[ToolPromptFormat] = None,
messages: list[RawMessage],
tool_prompt_format: ToolPromptFormat | None = None,
) -> LLMInput:
tool_prompt_format = tool_prompt_format or ToolPromptFormat.json
tokens = []
@ -163,7 +162,7 @@ class ChatFormat:
return self._model_input_from_tokens_images(tokens, images)
# TODO(this should be generic, not only for assistant messages)
def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage:
def decode_assistant_message(self, tokens: list[int], stop_reason: StopReason) -> RawMessage:
content = self.tokenizer.decode(tokens)
return self.decode_assistant_message_from_content(content, stop_reason)
@ -234,7 +233,7 @@ class ChatFormat:
tool_calls=tool_calls,
)
def _model_input_from_tokens_images(self, tokens: List[int], images: List[PIL_Image.Image]) -> LLMInput:
def _model_input_from_tokens_images(self, tokens: list[int], images: list[PIL_Image.Image]) -> LLMInput:
vision_input = None
if len(images) > 0:
vision_input = VisionInput(
@ -249,9 +248,9 @@ class ChatFormat:
def create_vision_mask(
tokens: List[int],
tokens: list[int],
vision_token: int,
) -> List[List[int]]:
) -> list[list[int]]:
vision_token_locations = [i for i, token in enumerate(tokens) if token == vision_token]
if len(vision_token_locations) == 0:
return []

View file

@ -15,8 +15,8 @@ import json
import os
import sys
import time
from collections.abc import Callable, Generator
from pathlib import Path
from typing import Callable, Generator, List, Optional
import torch
import torch.nn.functional as F
@ -41,8 +41,8 @@ class Llama3:
ckpt_dir: str,
max_seq_len: int,
max_batch_size: int,
world_size: Optional[int] = None,
quantization_mode: Optional[QuantizationMode] = None,
world_size: int | None = None,
quantization_mode: QuantizationMode | None = None,
seed: int = 1,
device: str = "cuda",
):
@ -82,7 +82,7 @@ class Llama3:
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
with open(Path(ckpt_dir) / "params.json", "r") as f:
with open(Path(ckpt_dir) / "params.json") as f:
params = json.loads(f.read())
model_args: ModelArgs = ModelArgs(
@ -154,15 +154,15 @@ class Llama3:
@torch.inference_mode()
def generate(
self,
llm_inputs: List[LLMInput],
llm_inputs: list[LLMInput],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
max_gen_len: int | None = None,
logprobs: bool = False,
echo: bool = False,
print_model_input: bool = False,
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> Generator[List[GenerationResult], None, None]:
logits_processor: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
) -> Generator[list[GenerationResult], None, None]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
params = self.model.params
@ -302,13 +302,13 @@ class Llama3:
def completion(
self,
contents: List[RawContent],
contents: list[RawContent],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
max_gen_len: int | None = None,
logprobs: bool = False,
echo: bool = False,
) -> Generator[List[GenerationResult], None, None]:
) -> Generator[list[GenerationResult], None, None]:
model_inputs = [self.formatter.encode_content(c) for c in contents]
for result in self.generate(
model_inputs=model_inputs,
@ -324,14 +324,14 @@ class Llama3:
def chat_completion(
self,
messages_batch: List[List[RawMessage]],
messages_batch: list[list[RawMessage]],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
max_gen_len: int | None = None,
logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
echo: bool = False,
) -> Generator[List[GenerationResult], None, None]:
) -> Generator[list[GenerationResult], None, None]:
model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
for result in self.generate(
model_inputs=model_inputs,

View file

@ -12,7 +12,6 @@
# the top-level of this source tree.
from pathlib import Path
from typing import List, Optional
from termcolor import colored
@ -131,7 +130,7 @@ class LLama31Interface:
self.formatter = ChatFormat(self.tokenizer)
self.tool_prompt_format = tool_prompt_format
def get_tokens(self, messages: List[RawMessage]) -> List[int]:
def get_tokens(self, messages: list[RawMessage]) -> list[int]:
model_input = self.formatter.encode_dialog_prompt(
messages,
self.tool_prompt_format,
@ -149,10 +148,10 @@ class LLama31Interface:
def system_messages(
self,
builtin_tools: List[BuiltinTool],
custom_tools: List[ToolDefinition],
instruction: Optional[str] = None,
) -> List[RawMessage]:
builtin_tools: list[BuiltinTool],
custom_tools: list[ToolDefinition],
instruction: str | None = None,
) -> list[RawMessage]:
messages = []
default_gen = SystemDefaultGenerator()
@ -194,8 +193,8 @@ class LLama31Interface:
self,
content: str,
stop_reason: StopReason,
tool_call: Optional[ToolCall] = None,
) -> List[RawMessage]:
tool_call: ToolCall | None = None,
) -> list[RawMessage]:
tool_calls = []
if tool_call:
tool_calls.append(tool_call)
@ -208,7 +207,7 @@ class LLama31Interface:
)
]
def user_message(self, content: str) -> List[RawMessage]:
def user_message(self, content: str) -> list[RawMessage]:
return [RawMessage(role="user", content=content)]
def display_message_as_tokens(self, message: RawMessage) -> None:
@ -228,7 +227,7 @@ class LLama31Interface:
print("\n", end="")
def list_jinja_templates() -> List[Template]:
def list_jinja_templates() -> list[Template]:
return TEMPLATES

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import math
from typing import Optional, Tuple
import fairscale.nn.model_parallel.initialize as fs_init
import torch
@ -80,7 +79,7 @@ def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
@ -162,7 +161,7 @@ class Attention(nn.Module):
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
mask: torch.Tensor | None,
):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
@ -204,7 +203,7 @@ class FeedForward(nn.Module):
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
ffn_dim_multiplier: float | None,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
@ -243,7 +242,7 @@ class TransformerBlock(nn.Module):
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
mask: torch.Tensor | None,
):
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
out = h + self.feed_forward(self.ffn_norm(h))

View file

@ -14,7 +14,7 @@
import math
from collections import defaultdict
from logging import getLogger
from typing import Any, Optional, Set, Tuple
from typing import Any
import torch
import torchvision.transforms as tv
@ -26,7 +26,7 @@ IMAGE_RES = 224
logger = getLogger()
class VariableSizeImageTransform(object):
class VariableSizeImageTransform:
"""
This class accepts images of any size and dynamically resize, pads and chunks it
based on the image aspect ratio and the number of image chunks we allow.
@ -75,7 +75,7 @@ class VariableSizeImageTransform(object):
self.resample = tv.InterpolationMode.BILINEAR
@staticmethod
def get_factors(n: int) -> Set[int]:
def get_factors(n: int) -> set[int]:
"""
Calculate all factors of a given number, i.e. a dividor that leaves
no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.
@ -145,9 +145,9 @@ class VariableSizeImageTransform(object):
@staticmethod
def get_max_res_without_distortion(
image_size: Tuple[int, int],
target_size: Tuple[int, int],
) -> Tuple[int, int]:
image_size: tuple[int, int],
target_size: tuple[int, int],
) -> tuple[int, int]:
"""
Determines the maximum resolution to which an image can be resized to without distorting its
aspect ratio, based on the target resolution.
@ -198,8 +198,8 @@ class VariableSizeImageTransform(object):
def resize_without_distortion(
self,
image: torch.Tensor,
target_size: Tuple[int, int],
max_upscaling_size: Optional[int],
target_size: tuple[int, int],
max_upscaling_size: int | None,
) -> torch.Tensor:
"""
Used to resize an image to target_resolution, without distortion.
@ -261,10 +261,10 @@ class VariableSizeImageTransform(object):
def get_best_fit(
self,
image_size: Tuple[int, int],
image_size: tuple[int, int],
possible_resolutions: torch.Tensor,
resize_to_max_canvas: bool = False,
) -> Tuple[int, int]:
) -> tuple[int, int]:
"""
Determines the best canvas possible from a list of possible resolutions to, without distortion,
resize an image to.
@ -364,7 +364,7 @@ class VariableSizeImageTransform(object):
max_num_chunks: int,
normalize_img: bool = True,
resize_to_max_canvas: bool = False,
) -> Tuple[Any, Any]:
) -> tuple[Any, Any]:
"""
Args:
image (PIL.Image): Image to be resized.

View file

@ -6,8 +6,9 @@
import logging
import math
from collections.abc import Callable
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any
import fairscale.nn.model_parallel.initialize as fs_init
import torch
@ -104,9 +105,9 @@ class ColumnParallelConv2dPatch(torch.nn.Module):
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]],
bias: Optional[bool] = False,
kernel_size: int | tuple[int, int],
stride: int | tuple[int, int],
bias: bool | None = False,
) -> None:
super().__init__()
if isinstance(kernel_size, int):
@ -390,13 +391,13 @@ class VisionEncoder(nn.Module):
def load_hook(
self,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
local_metadata: dict[str, Any],
strict: bool = True,
missing_keys: List[str] = None,
unexpected_keys: List[str] = None,
error_msgs: List[str] = None,
missing_keys: list[str] = None,
unexpected_keys: list[str] = None,
error_msgs: list[str] = None,
return_state_dict: bool = False,
) -> None:
orig_pos_embed = state_dict.get(prefix + "positional_embedding")
@ -641,7 +642,7 @@ class FeedForward(nn.Module):
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
ffn_dim_multiplier: float | None,
):
"""
Initialize the FeedForward module.
@ -983,7 +984,7 @@ class CrossAttentionTransformerBlock(torch.nn.Module):
self,
x: torch.Tensor,
xattn_mask: torch.Tensor,
full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor],
full_text_row_masked_out_mask: tuple[torch.Tensor, torch.Tensor],
xattn_cache: torch.Tensor,
) -> torch.Tensor:
_attn_out = self.attention(
@ -1144,7 +1145,7 @@ class CrossAttentionTransformerText(torch.nn.Module):
def _init_fusion_schedule(
self,
num_layers: int,
) -> List[int]:
) -> list[int]:
llama_layers = list(range(self.n_llama_layers))
# uniformly spread the layers
@ -1231,7 +1232,7 @@ class CrossAttentionTransformerText(torch.nn.Module):
text_dtype,
vision_tokens,
cross_attention_masks,
) -> Tuple[Tensor, Tensor]:
) -> tuple[Tensor, Tensor]:
assert vision_tokens is not None, "Vision tokens must be provided"
vision_seqlen = vision_tokens.shape[3]
assert vision_tokens.shape[1] == cross_attention_masks.shape[2], (
@ -1280,11 +1281,11 @@ class CrossAttentionTransformer(torch.nn.Module):
def compute_vision_tokens_masks(
self,
batch_images: List[List[PIL_Image.Image]],
batch_masks: List[List[List[int]]],
batch_images: list[list[PIL_Image.Image]],
batch_masks: list[list[list[int]]],
total_len: int,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
skip_vision_encoder = False
assert len(batch_images) == len(batch_masks), "Images and masks must have the same length"
@ -1371,11 +1372,11 @@ class CrossAttentionTransformer(torch.nn.Module):
def _stack_images(
images: List[List[PIL_Image.Image]],
images: list[list[PIL_Image.Image]],
max_num_chunks: int,
image_res: int,
max_num_images: int,
) -> Tuple[torch.Tensor, List[int]]:
) -> tuple[torch.Tensor, list[int]]:
"""
Takes a list of list of images and stacks them into a tensor.
This function is needed since images can be of completely
@ -1400,8 +1401,8 @@ def _stack_images(
def _pad_masks(
all_masks: List[List[List[int]]],
all_num_chunks: List[List[int]],
all_masks: list[list[list[int]]],
all_num_chunks: list[list[int]],
total_len: int,
max_num_chunks: int,
) -> torch.Tensor:

View file

@ -12,7 +12,7 @@
# the top-level of this source tree.
from dataclasses import dataclass
from typing import Any, Dict, List
from typing import Any
from jinja2 import Template
@ -20,7 +20,7 @@ from jinja2 import Template
@dataclass
class PromptTemplate:
template: str
data: Dict[str, Any]
data: dict[str, Any]
def render(self):
template = Template(self.template)
@ -35,5 +35,5 @@ class PromptTemplateGeneratorBase:
def gen(self, *args, **kwargs) -> PromptTemplate:
raise NotImplementedError()
def data_examples(self) -> List[Any]:
def data_examples(self) -> list[Any]:
raise NotImplementedError()

View file

@ -13,7 +13,7 @@
import textwrap
from datetime import datetime
from typing import Any, List, Optional
from typing import Any
from llama_stack.apis.inference import (
BuiltinTool,
@ -39,12 +39,12 @@ class SystemDefaultGenerator(PromptTemplateGeneratorBase):
},
)
def data_examples(self) -> List[Any]:
def data_examples(self) -> list[Any]:
return [None]
class BuiltinToolGenerator(PromptTemplateGeneratorBase):
def _tool_breakdown(self, tools: List[ToolDefinition]):
def _tool_breakdown(self, tools: list[ToolDefinition]):
builtin_tools, custom_tools = [], []
for dfn in tools:
if isinstance(dfn.tool_name, BuiltinTool):
@ -54,7 +54,7 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase):
return builtin_tools, custom_tools
def gen(self, tools: List[ToolDefinition]) -> PromptTemplate:
def gen(self, tools: list[ToolDefinition]) -> PromptTemplate:
builtin_tools, custom_tools = self._tool_breakdown(tools)
template_str = textwrap.dedent(
"""
@ -75,7 +75,7 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase):
},
)
def data_examples(self) -> List[List[ToolDefinition]]:
def data_examples(self) -> list[list[ToolDefinition]]:
return [
# builtin tools
[
@ -91,7 +91,7 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase):
class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
def gen(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
template_str = textwrap.dedent(
"""
Answer the user's question by making use of the following functions if needed.
@ -137,7 +137,7 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
{"custom_tools": [t.model_dump() for t in custom_tools]},
)
def data_examples(self) -> List[List[ToolDefinition]]:
def data_examples(self) -> list[list[ToolDefinition]]:
return [
[
ToolDefinition(
@ -161,7 +161,7 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
def gen(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
template_str = textwrap.dedent(
"""
You have access to the following functions:
@ -199,7 +199,7 @@ class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
{"custom_tools": [t.model_dump() for t in custom_tools]},
)
def data_examples(self) -> List[List[ToolDefinition]]:
def data_examples(self) -> list[list[ToolDefinition]]:
return [
[
ToolDefinition(
@ -238,14 +238,14 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
""".strip("\n")
)
def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate:
def gen(self, custom_tools: list[ToolDefinition], system_prompt: str | None = None) -> PromptTemplate:
system_prompt = system_prompt or self.DEFAULT_PROMPT
return PromptTemplate(
system_prompt,
{"function_description": self._gen_function_description(custom_tools)},
)
def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
template_str = textwrap.dedent(
"""
Here is a list of functions in JSON format that you can invoke.
@ -291,7 +291,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
{"tools": [t.model_dump() for t in custom_tools]},
).render()
def data_examples(self) -> List[List[ToolDefinition]]:
def data_examples(self) -> list[list[ToolDefinition]]:
return [
[
ToolDefinition(

View file

@ -12,7 +12,6 @@
# the top-level of this source tree.
import textwrap
from typing import Optional
from .base import PromptTemplate, PromptTemplateGeneratorBase
@ -21,8 +20,8 @@ class ToolResponseGenerator(PromptTemplateGeneratorBase):
def gen(
self,
status: str,
stdout: Optional[str] = None,
stderr: Optional[str] = None,
stdout: str | None = None,
stderr: str | None = None,
):
assert status in [
"success",

View file

@ -6,7 +6,7 @@
# type: ignore
import os
from typing import Any, Dict, List, Optional, cast
from typing import Any, cast
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
@ -37,9 +37,9 @@ def swiglu_wrapper(
def convert_to_quantized_model(
model: Transformer | CrossAttentionTransformer,
checkpoint_dir: str,
quantization_mode: Optional[str] = None,
fp8_activation_scale_ub: Optional[float] = 1200.0,
device: Optional[torch.device] = None,
quantization_mode: str | None = None,
fp8_activation_scale_ub: float | None = 1200.0,
device: torch.device | None = None,
) -> Transformer | CrossAttentionTransformer:
if quantization_mode == QuantizationMode.fp8_mixed:
return convert_to_fp8_quantized_model(model, checkpoint_dir, fp8_activation_scale_ub, device)
@ -52,8 +52,8 @@ def convert_to_quantized_model(
def convert_to_fp8_quantized_model(
model: Transformer,
checkpoint_dir: str,
fp8_activation_scale_ub: Optional[float] = 1200.0,
device: Optional[torch.device] = None,
fp8_activation_scale_ub: float | None = 1200.0,
device: torch.device | None = None,
) -> Transformer:
# Move weights to GPU with quantization
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
@ -122,8 +122,8 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
precision: torch.dtype = torch.float32,
scales_precision: torch.dtype = torch.float32,
# LoRA parameters
lora_rank: Optional[int] = None,
lora_scale: Optional[float] = None,
lora_rank: int | None = None,
lora_scale: float | None = None,
) -> None:
super().__init__(
in_features,
@ -134,8 +134,8 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
precision=precision,
scales_precision=scales_precision,
)
self.lora_scale: Optional[float] = None
self.adaptor: Optional[nn.Sequential] = None
self.lora_scale: float | None = None
self.adaptor: nn.Sequential | None = None
if lora_rank is not None:
assert lora_scale is not None, "Please specify lora scale for LoRA."
# Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
@ -147,13 +147,13 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
def load_hook(
self,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
local_metadata: dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
missing_keys: list[str],
unexpected_keys: list[str],
error_msgs: list[str],
) -> None:
"""A hook to load the quantized weights from the state dict."""
if prefix + "zeros" not in state_dict:
@ -191,13 +191,13 @@ class Int8WeightEmbedding(torch.nn.Embedding):
def load_hook(
self,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
local_metadata: dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
missing_keys: list[str],
unexpected_keys: list[str],
error_msgs: list[str],
) -> None:
"""A hook to load the quantized embedding weight and scales from the state dict."""
weights = state_dict.pop(prefix + "weight")
@ -221,13 +221,13 @@ class Int8WeightLinear(torch.nn.Linear):
def load_hook(
self,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
local_metadata: dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
missing_keys: list[str],
unexpected_keys: list[str],
error_msgs: list[str],
) -> None:
"""A hook to load the quantized linear weight and scales from the state dict."""
weights = state_dict.pop(prefix + "weight")
@ -238,8 +238,8 @@ class Int8WeightLinear(torch.nn.Linear):
def _prepare_model_int4_weight_int8_dynamic_activation(
model: torch.nn.Module,
group_size: int,
lora_rank: Optional[int],
lora_scale: Optional[float],
lora_rank: int | None,
lora_scale: float | None,
):
"""Prepare the model for int4 weight and int8 dynamic activation quantization.
@ -265,7 +265,7 @@ def _prepare_model_int4_weight_int8_dynamic_activation(
)
del module
setattr(model, module_name, quantized_module)
elif isinstance(module, (ColumnParallelLinear, RowParallelLinear, nn.Linear)):
elif isinstance(module, ColumnParallelLinear | RowParallelLinear | nn.Linear):
quantized_module = Int8DynActInt4WeightLinearLoRA(
in_features=module.in_features,
out_features=module.out_features,
@ -286,7 +286,7 @@ def _prepare_model_int4_weight_int8_dynamic_activation(
def convert_to_int4_quantized_model(
model: Transformer | CrossAttentionTransformer,
checkpoint_dir: str,
device: Optional[torch.device] = None,
device: torch.device | None = None,
) -> Transformer | CrossAttentionTransformer:
"""Convert the model to int4 quantized model."""
model_args = model.params

View file

@ -5,18 +5,11 @@
# the root directory of this source tree.
import os
from collections.abc import Collection, Iterator, Sequence, Set
from logging import getLogger
from pathlib import Path
from typing import (
AbstractSet,
Collection,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Union,
cast,
)
@ -44,7 +37,7 @@ class Tokenizer:
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
"""
special_tokens: Dict[str, int]
special_tokens: dict[str, int]
num_reserved_special_tokens = 256
@ -116,9 +109,9 @@ class Tokenizer:
*,
bos: bool,
eos: bool,
allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
disallowed_special: Union[Literal["all"], Collection[str]] = (),
) -> List[int]:
allowed_special: Literal["all"] | Set[str] | None = None,
disallowed_special: Literal["all"] | Collection[str] = (),
) -> list[int]:
"""
Encodes a string into a list of token IDs.
@ -151,7 +144,7 @@ class Tokenizer:
s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
)
)
t: List[int] = []
t: list[int] = []
for substr in substrs:
t.extend(
self.model.encode(
@ -177,7 +170,7 @@ class Tokenizer:
str: The decoded string.
"""
# Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
return self.model.decode(cast(List[int], t))
return self.model.decode(cast(list[int], t))
@staticmethod
def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]:

View file

@ -6,7 +6,6 @@
import json
import re
from typing import Optional, Tuple
from llama_stack.log import get_logger
@ -172,7 +171,7 @@ class ToolUtils:
return match is not None
@staticmethod
def maybe_extract_builtin_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
def maybe_extract_builtin_tool_call(message_body: str) -> tuple[str, str] | None:
# Find the first match in the text
match = re.search(BUILTIN_TOOL_PATTERN, message_body)
@ -185,7 +184,7 @@ class ToolUtils:
return None
@staticmethod
def maybe_extract_custom_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
def maybe_extract_custom_tool_call(message_body: str) -> tuple[str, str] | None:
# NOTE: Custom function too calls are still experimental
# Sometimes, response is of the form
# {"type": "function", "name": "function_name", "parameters": {...}
@ -252,7 +251,7 @@ class ToolUtils:
def format_value(value: RecursiveType) -> str:
if isinstance(value, str):
return f'"{value}"'
elif isinstance(value, (int, float, bool)) or value is None:
elif isinstance(value, int | float | bool) or value is None:
return str(value)
elif isinstance(value, list):
return f"[{', '.join(format_value(v) for v in value)}]"