forked from phoenix-oss/llama-stack-mirror
chore: enable pyupgrade fixes (#1806)
# What does this PR do? The goal of this PR is code base modernization. Schema reflection code needed a minor adjustment to handle UnionTypes and collections.abc.AsyncIterator. (Both are preferred for the latest Python releases.) Note to reviewers: almost all changes here are automatically generated by pyupgrade. Some additional unused imports were cleaned up. The only change worth noting can be found under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py` where reflection code was updated to deal with "newer" types. Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
ffe3d0b2cd
commit
9e6561a1ec
319 changed files with 2843 additions and 3033 deletions
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
|
@ -15,8 +14,8 @@ class QuantizationScheme(Enum):
|
|||
|
||||
|
||||
class QuantizationArgs(BaseModel):
|
||||
scheme: Optional[QuantizationScheme] = None
|
||||
group_size: Optional[int] = None
|
||||
scheme: QuantizationScheme | None = None
|
||||
group_size: int | None = None
|
||||
spinquant: bool = False
|
||||
|
||||
|
||||
|
@ -58,32 +57,32 @@ class ModelArgs(BaseModel):
|
|||
dim: int = -1
|
||||
n_layers: int = -1
|
||||
n_heads: int = -1
|
||||
n_kv_heads: Optional[int] = None
|
||||
head_dim: Optional[int] = None
|
||||
n_kv_heads: int | None = None
|
||||
head_dim: int | None = None
|
||||
|
||||
vocab_size: int = -1
|
||||
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
|
||||
ffn_dim_multiplier: Optional[float] = None
|
||||
ffn_exp: Optional[float] = None
|
||||
ffn_dim_multiplier: float | None = None
|
||||
ffn_exp: float | None = None
|
||||
norm_eps: float = 1e-5
|
||||
|
||||
attention_chunk_size: Optional[int] = None
|
||||
attention_chunk_size: int | None = None
|
||||
rope_theta: float = 500000
|
||||
use_scaled_rope: bool = False
|
||||
rope_scaling_factor: Optional[float] = None
|
||||
rope_high_freq_factor: Optional[float] = None
|
||||
rope_scaling_factor: float | None = None
|
||||
rope_high_freq_factor: float | None = None
|
||||
|
||||
nope_layer_interval: Optional[int] = None # No position encoding in every n layers
|
||||
nope_layer_interval: int | None = None # No position encoding in every n layers
|
||||
use_qk_norm: bool = False
|
||||
# Set to True to enable inference-time temperature tuning (useful for very long context)
|
||||
attn_temperature_tuning: bool = False
|
||||
floor_scale: float = 8192.0
|
||||
attn_scale: float = 0.1
|
||||
|
||||
vision_args: Optional[VisionArgs] = None
|
||||
moe_args: Optional[MoEArgs] = None
|
||||
quantization_args: Optional[QuantizationArgs] = None
|
||||
lora_args: Optional[LoRAArgs] = None
|
||||
vision_args: VisionArgs | None = None
|
||||
moe_args: MoEArgs | None = None
|
||||
quantization_args: QuantizationArgs | None = None
|
||||
lora_args: LoRAArgs | None = None
|
||||
|
||||
max_batch_size: int = 32
|
||||
max_seq_len: int = 2048
|
||||
|
|
|
@ -8,7 +8,6 @@ import io
|
|||
import json
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from PIL import Image as PIL_Image
|
||||
|
@ -46,10 +45,10 @@ def role_str(role: Role) -> str:
|
|||
class TransformedImage:
|
||||
image_tiles: torch.Tensor
|
||||
# is the aspect ratio needed anywhere?
|
||||
aspect_ratio: Tuple[int, int]
|
||||
aspect_ratio: tuple[int, int]
|
||||
|
||||
|
||||
def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
|
||||
def convert_image_to_rgb(image: PIL_Image.Image, bg: tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
|
||||
if image.mode == "RGBA":
|
||||
image.load() # for png.split()
|
||||
new_img = PIL_Image.new("RGB", image.size, bg)
|
||||
|
@ -59,12 +58,12 @@ def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255
|
|||
|
||||
|
||||
class ChatFormat:
|
||||
possible_headers: Dict[Role, str]
|
||||
possible_headers: dict[Role, str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: Tokenizer,
|
||||
vision_args: Optional[VisionArgs] = None,
|
||||
vision_args: VisionArgs | None = None,
|
||||
max_num_chunks: int = 16,
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
|
@ -81,7 +80,7 @@ class ChatFormat:
|
|||
vision_args.image_size.width, vision_args.image_size.height
|
||||
)
|
||||
|
||||
def _encode_header(self, role: str) -> List[int]:
|
||||
def _encode_header(self, role: str) -> list[int]:
|
||||
tokens = []
|
||||
tokens.append(self.tokenizer.special_tokens["<|header_start|>"])
|
||||
|
||||
|
@ -98,7 +97,7 @@ class ChatFormat:
|
|||
def _encode_image(
|
||||
self,
|
||||
transformed_image: TransformedImage,
|
||||
) -> List[int]:
|
||||
) -> list[int]:
|
||||
assert self.vision_args is not None, "The model is not vision-enabled"
|
||||
|
||||
image_tensor = transformed_image.image_tiles
|
||||
|
@ -140,7 +139,7 @@ class ChatFormat:
|
|||
|
||||
return tokens
|
||||
|
||||
def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[TransformedImage]]:
|
||||
def _encode_content(self, content: RawContent, bos: bool = False) -> tuple[list[int], list[TransformedImage]]:
|
||||
tokens = []
|
||||
tranformed_images = []
|
||||
|
||||
|
@ -189,7 +188,7 @@ class ChatFormat:
|
|||
|
||||
def encode_message(
|
||||
self, message: RawMessage, tool_prompt_format: ToolPromptFormat
|
||||
) -> Tuple[List[int], List[TransformedImage]]:
|
||||
) -> tuple[list[int], list[TransformedImage]]:
|
||||
tokens = self._encode_header(message.role)
|
||||
images = []
|
||||
|
||||
|
@ -223,7 +222,7 @@ class ChatFormat:
|
|||
|
||||
def encode_dialog_prompt(
|
||||
self,
|
||||
messages: List[RawMessage],
|
||||
messages: list[RawMessage],
|
||||
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
|
||||
) -> LLMInput:
|
||||
tokens = []
|
||||
|
@ -240,7 +239,7 @@ class ChatFormat:
|
|||
return self._model_input_from_tokens_images(tokens, images)
|
||||
|
||||
# TODO(this should be generic, not only for assistant messages)
|
||||
def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage:
|
||||
def decode_assistant_message(self, tokens: list[int], stop_reason: StopReason) -> RawMessage:
|
||||
content = self.tokenizer.decode(tokens)
|
||||
|
||||
return self.decode_assistant_message_from_content(content, stop_reason)
|
||||
|
@ -312,7 +311,7 @@ class ChatFormat:
|
|||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
def _model_input_from_tokens_images(self, tokens: List[int], images: List[TransformedImage]) -> LLMInput:
|
||||
def _model_input_from_tokens_images(self, tokens: list[int], images: list[TransformedImage]) -> LLMInput:
|
||||
return LLMInput(
|
||||
tokens=tokens,
|
||||
images=[x.image_tiles for x in images] if len(images) > 0 else None,
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -30,7 +29,7 @@ class LLMInput:
|
|||
tokens: torch.Tensor
|
||||
|
||||
# images are already pre-processed (resized, tiled, etc.)
|
||||
images: Optional[List[torch.Tensor]] = None
|
||||
images: list[torch.Tensor] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -45,8 +44,8 @@ class TransformerInput:
|
|||
# tokens_position defines the position of the tokens in each batch,
|
||||
# - when it is a tensor ([batch_size,]), it is the start position of the tokens in each batch
|
||||
# - when it is an int, the start position are the same for all batches
|
||||
tokens_position: Union[torch.Tensor, int]
|
||||
image_embedding: Optional[MaskedEmbedding] = None
|
||||
tokens_position: torch.Tensor | int
|
||||
image_embedding: MaskedEmbedding | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
|
||||
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
|
||||
|
@ -36,13 +36,13 @@ class FeedForward(nn.Module):
|
|||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
state_dict: dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
local_metadata: dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
missing_keys: list[str],
|
||||
unexpected_keys: list[str],
|
||||
error_msgs: list[str],
|
||||
) -> None:
|
||||
if prefix + "mlp.fc1_weight" in state_dict:
|
||||
w1, w3 = state_dict.pop(prefix + "mlp.fc1_weight").chunk(2, dim=0)
|
||||
|
|
|
@ -10,8 +10,8 @@ import json
|
|||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Callable, Generator
|
||||
from pathlib import Path
|
||||
from typing import Callable, Generator, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
@ -38,8 +38,8 @@ class Llama4:
|
|||
ckpt_dir: str,
|
||||
max_seq_len: int,
|
||||
max_batch_size: int,
|
||||
world_size: Optional[int] = None,
|
||||
quantization_mode: Optional[QuantizationMode] = None,
|
||||
world_size: int | None = None,
|
||||
quantization_mode: QuantizationMode | None = None,
|
||||
seed: int = 1,
|
||||
):
|
||||
if not torch.distributed.is_initialized():
|
||||
|
@ -63,7 +63,7 @@ class Llama4:
|
|||
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
|
||||
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
|
||||
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
|
||||
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
||||
with open(Path(ckpt_dir) / "params.json") as f:
|
||||
params = json.loads(f.read())
|
||||
|
||||
model_args: ModelArgs = ModelArgs(
|
||||
|
@ -117,15 +117,15 @@ class Llama4:
|
|||
@torch.inference_mode()
|
||||
def generate(
|
||||
self,
|
||||
llm_inputs: List[LLMInput],
|
||||
llm_inputs: list[LLMInput],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
max_gen_len: int | None = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
print_model_input: bool = False,
|
||||
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
) -> Generator[List[GenerationResult], None, None]:
|
||||
logits_processor: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
|
||||
) -> Generator[list[GenerationResult], None, None]:
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.args.max_seq_len:
|
||||
max_gen_len = self.model.args.max_seq_len - 1
|
||||
|
||||
|
@ -245,13 +245,13 @@ class Llama4:
|
|||
|
||||
def completion(
|
||||
self,
|
||||
contents: List[RawContent],
|
||||
contents: list[RawContent],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
max_gen_len: int | None = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
) -> Generator[List[GenerationResult], None, None]:
|
||||
) -> Generator[list[GenerationResult], None, None]:
|
||||
llm_inputs = [self.formatter.encode_content(c) for c in contents]
|
||||
for result in self.generate(
|
||||
llm_inputs=llm_inputs,
|
||||
|
@ -267,13 +267,13 @@ class Llama4:
|
|||
|
||||
def chat_completion(
|
||||
self,
|
||||
messages_batch: List[List[RawMessage]],
|
||||
messages_batch: list[list[RawMessage]],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
max_gen_len: int | None = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
) -> Generator[List[GenerationResult], None, None]:
|
||||
) -> Generator[list[GenerationResult], None, None]:
|
||||
llm_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
|
||||
for result in self.generate(
|
||||
llm_inputs=llm_inputs,
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any
|
||||
|
||||
import fairscale.nn.model_parallel.initialize as fs_init
|
||||
import torch
|
||||
|
@ -89,7 +89,7 @@ def apply_rotary_emb(
|
|||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
||||
|
@ -174,13 +174,13 @@ class Attention(nn.Module):
|
|||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
state_dict: dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
local_metadata: dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
missing_keys: list[str],
|
||||
unexpected_keys: list[str],
|
||||
error_msgs: list[str],
|
||||
) -> None:
|
||||
if prefix + "wqkv.weight" in state_dict:
|
||||
wqkv = state_dict.pop(prefix + "wqkv.weight")
|
||||
|
@ -200,7 +200,7 @@ class Attention(nn.Module):
|
|||
x: torch.Tensor,
|
||||
start_pos: int,
|
||||
freqs_cis: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
mask: torch.Tensor | None = None,
|
||||
):
|
||||
bsz, seqlen, _ = x.shape
|
||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||
|
@ -288,13 +288,13 @@ class TransformerBlock(nn.Module):
|
|||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
state_dict: dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
local_metadata: dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
missing_keys: list[str],
|
||||
unexpected_keys: list[str],
|
||||
error_msgs: list[str],
|
||||
) -> None:
|
||||
if prefix + "attention.wqkv.layer_norm_weight" in state_dict:
|
||||
state_dict[prefix + "attention_norm.weight"] = state_dict.pop(prefix + "attention.wqkv.layer_norm_weight")
|
||||
|
@ -318,8 +318,8 @@ class TransformerBlock(nn.Module):
|
|||
x: torch.Tensor,
|
||||
start_pos: int,
|
||||
freqs_cis: torch.Tensor,
|
||||
global_attn_mask: Optional[torch.Tensor],
|
||||
local_attn_mask: Optional[torch.Tensor],
|
||||
global_attn_mask: torch.Tensor | None,
|
||||
local_attn_mask: torch.Tensor | None,
|
||||
):
|
||||
# The iRoPE architecture uses global attention mask for NoPE layers or
|
||||
# if chunked local attention is not used
|
||||
|
@ -374,13 +374,13 @@ class Transformer(nn.Module):
|
|||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
state_dict: dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
local_metadata: dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
missing_keys: list[str],
|
||||
unexpected_keys: list[str],
|
||||
error_msgs: list[str],
|
||||
) -> None:
|
||||
if prefix + "rope.freqs" in state_dict:
|
||||
state_dict.pop(prefix + "rope.freqs")
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
# ruff: noqa: N806
|
||||
# pyre-strict
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
import fairscale.nn.model_parallel.initialize as fs_init
|
||||
import torch
|
||||
|
@ -63,13 +63,13 @@ class Experts(nn.Module):
|
|||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
state_dict: dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
local_metadata: dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
missing_keys: list[str],
|
||||
unexpected_keys: list[str],
|
||||
error_msgs: list[str],
|
||||
) -> None:
|
||||
self.prefix = prefix
|
||||
if prefix + "moe_w_in_eD_F" in state_dict:
|
||||
|
@ -158,13 +158,13 @@ class MoE(torch.nn.Module):
|
|||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
state_dict: dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
local_metadata: dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
missing_keys: list[str],
|
||||
unexpected_keys: list[str],
|
||||
error_msgs: list[str],
|
||||
) -> None:
|
||||
if prefix + "w_in_shared_FD.weight" in state_dict:
|
||||
state_dict[prefix + "shared_expert.w1.weight"] = state_dict.pop(prefix + "w_in_shared_FD.weight")
|
||||
|
@ -210,5 +210,5 @@ class MoE(torch.nn.Module):
|
|||
|
||||
|
||||
def divide_exact(numerator: int, denominator: int) -> int:
|
||||
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
|
||||
assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}"
|
||||
return numerator // denominator
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from typing import Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as tv
|
||||
|
@ -52,7 +51,7 @@ class ResizeNormalizeImageTransform:
|
|||
return self.tv_transform(image)
|
||||
|
||||
|
||||
class VariableSizeImageTransform(object):
|
||||
class VariableSizeImageTransform:
|
||||
"""
|
||||
This class accepts images of any size and dynamically resize, pads and chunks it
|
||||
based on the image aspect ratio and the number of image chunks we allow.
|
||||
|
@ -100,7 +99,7 @@ class VariableSizeImageTransform(object):
|
|||
self.resample = tv.InterpolationMode.BILINEAR
|
||||
|
||||
@staticmethod
|
||||
def get_factors(n: int) -> Set[int]:
|
||||
def get_factors(n: int) -> set[int]:
|
||||
"""
|
||||
Calculate all factors of a given number, i.e. a dividor that leaves
|
||||
no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.
|
||||
|
@ -170,9 +169,9 @@ class VariableSizeImageTransform(object):
|
|||
|
||||
@staticmethod
|
||||
def get_max_res_without_distortion(
|
||||
image_size: Tuple[int, int],
|
||||
target_size: Tuple[int, int],
|
||||
) -> Tuple[int, int]:
|
||||
image_size: tuple[int, int],
|
||||
target_size: tuple[int, int],
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Determines the maximum resolution to which an image can be resized to without distorting its
|
||||
aspect ratio, based on the target resolution.
|
||||
|
@ -223,8 +222,8 @@ class VariableSizeImageTransform(object):
|
|||
def resize_without_distortion(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
target_size: Tuple[int, int],
|
||||
max_upscaling_size: Optional[int],
|
||||
target_size: tuple[int, int],
|
||||
max_upscaling_size: int | None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Used to resize an image to target_resolution, without distortion.
|
||||
|
@ -289,10 +288,10 @@ class VariableSizeImageTransform(object):
|
|||
|
||||
def get_best_fit(
|
||||
self,
|
||||
image_size: Tuple[int, int],
|
||||
image_size: tuple[int, int],
|
||||
possible_resolutions: torch.Tensor,
|
||||
resize_to_max_canvas: bool = False,
|
||||
) -> Tuple[int, int]:
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Determines the best canvas possible from a list of possible resolutions to, without distortion,
|
||||
resize an image to.
|
||||
|
@ -392,7 +391,7 @@ class VariableSizeImageTransform(object):
|
|||
max_num_chunks: int,
|
||||
normalize_img: bool = True,
|
||||
resize_to_max_canvas: bool = False,
|
||||
) -> Tuple[torch.Tensor, Tuple[int, int]]:
|
||||
) -> tuple[torch.Tensor, tuple[int, int]]:
|
||||
"""
|
||||
Args:
|
||||
image (PIL.Image): Image to be resized.
|
||||
|
|
|
@ -12,7 +12,6 @@
|
|||
# the top-level of this source tree.
|
||||
|
||||
import textwrap
|
||||
from typing import List, Optional
|
||||
|
||||
from llama_stack.apis.inference import ToolDefinition, ToolParamDefinition
|
||||
from llama_stack.models.llama.llama3.prompt_templates.base import (
|
||||
|
@ -67,14 +66,14 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
""".strip("\n")
|
||||
)
|
||||
|
||||
def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate:
|
||||
def gen(self, custom_tools: list[ToolDefinition], system_prompt: str | None = None) -> PromptTemplate:
|
||||
system_prompt = system_prompt or self.DEFAULT_PROMPT
|
||||
return PromptTemplate(
|
||||
system_prompt,
|
||||
{"function_description": self._gen_function_description(custom_tools)},
|
||||
)
|
||||
|
||||
def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
|
||||
def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
|
||||
template_str = textwrap.dedent(
|
||||
"""
|
||||
Here is a list of functions in JSON format that you can invoke.
|
||||
|
@ -120,7 +119,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
{"tools": [t.model_dump() for t in custom_tools]},
|
||||
).render()
|
||||
|
||||
def data_examples(self) -> List[List[ToolDefinition]]:
|
||||
def data_examples(self) -> list[list[ToolDefinition]]:
|
||||
return [
|
||||
[
|
||||
ToolDefinition(
|
||||
|
|
|
@ -7,7 +7,6 @@
|
|||
import textwrap
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
|
||||
PythonListCustomToolGenerator,
|
||||
|
@ -23,7 +22,7 @@ from ..prompt_format import (
|
|||
THIS_DIR = Path(__file__).parent
|
||||
|
||||
|
||||
def usecases(base_model: bool = False) -> List[UseCase | str]:
|
||||
def usecases(base_model: bool = False) -> list[UseCase | str]:
|
||||
with open(THIS_DIR.parent / "resources/small_dog.jpg", "rb") as f:
|
||||
img_small_dog = f.read()
|
||||
with open(THIS_DIR.parent / "resources/dog.jpg", "rb") as f:
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import logging
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
|
||||
|
@ -45,8 +45,8 @@ def experts_batched_swiglu_wrapper(
|
|||
def convert_to_quantized_model(
|
||||
model: Transformer,
|
||||
checkpoint_dir: str,
|
||||
quantization_mode: Optional[str] = None,
|
||||
fp8_activation_scale_ub: Optional[float] = 1200.0,
|
||||
quantization_mode: str | None = None,
|
||||
fp8_activation_scale_ub: float | None = 1200.0,
|
||||
use_rich_progress: bool = True,
|
||||
) -> Transformer:
|
||||
from ...quantize_impls import (
|
||||
|
@ -213,7 +213,7 @@ def logging_callbacks(
|
|||
)
|
||||
task_id = progress.add_task("[blue]Converting layers...", total=total_blocks, status="Starting")
|
||||
|
||||
def update_status(message: Optional[str], completed: Optional[int] = None) -> None:
|
||||
def update_status(message: str | None, completed: int | None = None) -> None:
|
||||
if use_rich_progress:
|
||||
if message is not None:
|
||||
progress.update(task_id, status=message)
|
||||
|
|
|
@ -5,18 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from collections.abc import Collection, Iterator, Sequence, Set
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
AbstractSet,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
|
@ -114,7 +107,7 @@ class Tokenizer:
|
|||
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
|
||||
"""
|
||||
|
||||
special_tokens: Dict[str, int]
|
||||
special_tokens: dict[str, int]
|
||||
|
||||
num_reserved_special_tokens = 2048
|
||||
|
||||
|
@ -182,9 +175,9 @@ class Tokenizer:
|
|||
*,
|
||||
bos: bool,
|
||||
eos: bool,
|
||||
allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = (),
|
||||
) -> List[int]:
|
||||
allowed_special: Literal["all"] | Set[str] | None = None,
|
||||
disallowed_special: Literal["all"] | Collection[str] = (),
|
||||
) -> list[int]:
|
||||
"""
|
||||
Encodes a string into a list of token IDs.
|
||||
|
||||
|
@ -217,7 +210,7 @@ class Tokenizer:
|
|||
s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
|
||||
)
|
||||
)
|
||||
t: List[int] = []
|
||||
t: list[int] = []
|
||||
for substr in substrs:
|
||||
t.extend(
|
||||
self.model.encode(
|
||||
|
@ -243,7 +236,7 @@ class Tokenizer:
|
|||
str: The decoded string.
|
||||
"""
|
||||
# Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
|
||||
return self.model.decode(cast(List[int], t))
|
||||
return self.model.decode(cast(list[int], t))
|
||||
|
||||
@staticmethod
|
||||
def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]:
|
||||
|
|
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from typing import Any, Callable, Dict, List
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -136,13 +137,13 @@ class VisionEmbeddings(torch.nn.Module):
|
|||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
state_dict: dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
local_metadata: dict[str, Any],
|
||||
strict: bool = True,
|
||||
missing_keys: List[str] = None,
|
||||
unexpected_keys: List[str] = None,
|
||||
error_msgs: List[str] = None,
|
||||
missing_keys: list[str] = None,
|
||||
unexpected_keys: list[str] = None,
|
||||
error_msgs: list[str] = None,
|
||||
return_state_dict: bool = False,
|
||||
) -> None:
|
||||
original_sd = self.state_dict()
|
||||
|
@ -163,7 +164,7 @@ class VisionEmbeddings(torch.nn.Module):
|
|||
# each image is a tensor of shape [num_tiles, C, H, W]
|
||||
def forward(
|
||||
self,
|
||||
image_batch: List[List[torch.Tensor]],
|
||||
image_batch: list[list[torch.Tensor]],
|
||||
image_mask: torch.Tensor,
|
||||
h_ref: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
|
|
|
@ -4,7 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import fairscale.nn.model_parallel.initialize as fs_init
|
||||
import torch
|
||||
|
@ -42,9 +43,9 @@ class ColumnParallelConv2dPatch(torch.nn.Module):
|
|||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
stride: Union[int, Tuple[int, int]],
|
||||
bias: Optional[bool] = False,
|
||||
kernel_size: int | tuple[int, int],
|
||||
stride: int | tuple[int, int],
|
||||
bias: bool | None = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if isinstance(kernel_size, int):
|
||||
|
@ -134,15 +135,15 @@ class _TransformerBlock(nn.Module):
|
|||
def attention(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
freq_cis: Optional[torch.Tensor] = None,
|
||||
freq_cis: torch.Tensor | None = None,
|
||||
):
|
||||
return self.attn(x=x, start_pos=0, freqs_cis=freq_cis)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
freq_cis: Optional[torch.Tensor] = None,
|
||||
mask: torch.Tensor | None = None,
|
||||
freq_cis: torch.Tensor | None = None,
|
||||
):
|
||||
_gate_attn = 1 if not self.gated else self.gate_attn.tanh()
|
||||
_gate_ffn = 1 if not self.gated else self.gate_ffn.tanh()
|
||||
|
@ -210,8 +211,8 @@ class PackingIndex:
|
|||
class VisionEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
image_size: Tuple[int, int],
|
||||
patch_size: Tuple[int, int],
|
||||
image_size: tuple[int, int],
|
||||
patch_size: tuple[int, int],
|
||||
dim: int,
|
||||
layers: int,
|
||||
heads: int,
|
||||
|
@ -299,13 +300,13 @@ class VisionEncoder(nn.Module):
|
|||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
state_dict: dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
local_metadata: dict[str, Any],
|
||||
strict: bool = True,
|
||||
missing_keys: List[str] = None,
|
||||
unexpected_keys: List[str] = None,
|
||||
error_msgs: List[str] = None,
|
||||
missing_keys: list[str] = None,
|
||||
unexpected_keys: list[str] = None,
|
||||
error_msgs: list[str] = None,
|
||||
return_state_dict: bool = False,
|
||||
) -> None:
|
||||
orig_pos_embed = state_dict.get(prefix + "positional_embedding")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue