mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
chore: enable pyupgrade fixes (#1806)
# What does this PR do? The goal of this PR is code base modernization. Schema reflection code needed a minor adjustment to handle UnionTypes and collections.abc.AsyncIterator. (Both are preferred for the latest Python releases.) Note to reviewers: almost all changes here are automatically generated by pyupgrade. Some additional unused imports were cleaned up. The only change worthy of note can be found under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py`, where reflection code was updated to deal with "newer" types. Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
ffe3d0b2cd
commit
9e6561a1ec
319 changed files with 2843 additions and 3033 deletions
|
@ -7,7 +7,6 @@
|
|||
# type: ignore
|
||||
import collections
|
||||
import logging
|
||||
from typing import Optional, Tuple, Type, Union
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
@ -27,7 +26,7 @@ class Fp8ScaledWeights:
|
|||
# TODO: Ugly trick so torch allows us to replace parameters
|
||||
# with our custom Fp8Weights instance. Do this properly.
|
||||
@property
|
||||
def __class__(self) -> Type[nn.parameter.Parameter]:
|
||||
def __class__(self) -> type[nn.parameter.Parameter]:
|
||||
return nn.Parameter
|
||||
|
||||
@property
|
||||
|
@ -51,7 +50,7 @@ class Int4ScaledWeights:
|
|||
# TODO: Ugly trick so torch allows us to replace parameters
|
||||
# with our custom Int4Weights instance. Do this properly.
|
||||
@property
|
||||
def __class__(self) -> Type[nn.parameter.Parameter]:
|
||||
def __class__(self) -> type[nn.parameter.Parameter]:
|
||||
return nn.Parameter
|
||||
|
||||
@property
|
||||
|
@ -74,7 +73,7 @@ class Int4Weights(
|
|||
def int4_row_quantize(
|
||||
x: torch.Tensor,
|
||||
group_size: int = 128,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
n_bit = 4 # Number of target bits.
|
||||
to_quant = x.reshape(-1, group_size).to(torch.float)
|
||||
|
||||
|
@ -115,8 +114,8 @@ def pack_int4(x: torch.Tensor) -> torch.Tensor:
|
|||
|
||||
def bmm_nt(
|
||||
x: Tensor,
|
||||
w: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
num_tokens: Optional[Tensor] = None,
|
||||
w: Fp8RowwiseWeights | Int4Weights,
|
||||
num_tokens: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
if isinstance(w, Fp8ScaledWeights):
|
||||
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, w.activation_scale_ub)
|
||||
|
@ -129,10 +128,10 @@ def bmm_nt(
|
|||
|
||||
def ffn_swiglu(
|
||||
x: Tensor,
|
||||
w1: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
w3: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
w2: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
num_tokens: Optional[Tensor] = None,
|
||||
w1: Fp8RowwiseWeights | Int4Weights,
|
||||
w3: Fp8RowwiseWeights | Int4Weights,
|
||||
w2: Fp8RowwiseWeights | Int4Weights,
|
||||
num_tokens: Tensor | None = None,
|
||||
is_memory_bounded: bool = False,
|
||||
) -> Tensor:
|
||||
if (isinstance(w1, Fp8ScaledWeights) and isinstance(w3, Fp8ScaledWeights) and isinstance(w2, Fp8ScaledWeights)) or (
|
||||
|
@ -158,7 +157,7 @@ def ffn_swiglu(
|
|||
def quantize_fp8(
|
||||
w: Tensor,
|
||||
fp8_activation_scale_ub: float,
|
||||
output_device: Optional[torch.device] = None,
|
||||
output_device: torch.device | None = None,
|
||||
) -> Fp8RowwiseWeights:
|
||||
"""Quantize [n, k] weight tensor.
|
||||
|
||||
|
@ -184,7 +183,7 @@ def quantize_fp8(
|
|||
@torch.inference_mode()
|
||||
def quantize_int4(
|
||||
w: Tensor,
|
||||
output_device: Optional[torch.device] = None,
|
||||
output_device: torch.device | None = None,
|
||||
) -> Int4Weights:
|
||||
"""Quantize [n, k/2] weight tensor.
|
||||
|
||||
|
@ -213,7 +212,7 @@ def load_fp8(
|
|||
w: Tensor,
|
||||
w_scale: Tensor,
|
||||
fp8_activation_scale_ub: float,
|
||||
output_device: Optional[torch.device] = None,
|
||||
output_device: torch.device | None = None,
|
||||
) -> Fp8RowwiseWeights:
|
||||
"""Load FP8 [n, k] weight tensor.
|
||||
|
||||
|
@ -239,7 +238,7 @@ def load_int4(
|
|||
w: Tensor,
|
||||
scale: Tensor,
|
||||
zero_point: Tensor,
|
||||
output_device: Optional[torch.device] = None,
|
||||
output_device: torch.device | None = None,
|
||||
) -> Int4Weights:
|
||||
"""Load INT4 [n, k/2] weight tensor.
|
||||
|
||||
|
@ -256,9 +255,9 @@ def load_int4(
|
|||
|
||||
def fc_dynamic(
|
||||
x: Tensor,
|
||||
w: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
activation_scale_ub: Optional[Tensor] = None,
|
||||
num_tokens: Optional[Tensor] = None,
|
||||
w: Fp8RowwiseWeights | Int4Weights,
|
||||
activation_scale_ub: Tensor | None = None,
|
||||
num_tokens: Tensor | None = None,
|
||||
is_memory_bounded: bool = False,
|
||||
) -> Tensor:
|
||||
"""
|
||||
|
@ -275,11 +274,11 @@ def fc_dynamic(
|
|||
|
||||
def ffn_swiglu_dynamic(
|
||||
x: Tensor,
|
||||
w1: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
w3: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
w2: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
activation_scale_ub: Optional[Tensor] = None,
|
||||
num_tokens: Optional[Tensor] = None,
|
||||
w1: Fp8RowwiseWeights | Int4Weights,
|
||||
w3: Fp8RowwiseWeights | Int4Weights,
|
||||
w2: Fp8RowwiseWeights | Int4Weights,
|
||||
activation_scale_ub: Tensor | None = None,
|
||||
num_tokens: Tensor | None = None,
|
||||
is_memory_bounded: bool = False,
|
||||
) -> Tensor:
|
||||
assert x.dim() == 3 or x.dim() == 2
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue