forked from phoenix-oss/llama-stack-mirror
chore: enable pyupgrade fixes (#1806)
# What does this PR do? The goal of this PR is code base modernization. Schema reflection code needed a minor adjustment to handle UnionTypes and collections.abc.AsyncIterator. (Both are preferred for latest Python releases.) Note to reviewers: almost all changes here are automatically generated by pyupgrade. Some additional unused imports were cleaned up. The only change worth of note can be found under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py` where reflection code was updated to deal with "newer" types. Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
ffe3d0b2cd
commit
9e6561a1ec
319 changed files with 2843 additions and 3033 deletions
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import fairscale.nn.model_parallel.initialize as fs_init
|
||||
import torch
|
||||
|
@ -80,7 +79,7 @@ def apply_rotary_emb(
|
|||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
||||
|
@ -162,7 +161,7 @@ class Attention(nn.Module):
|
|||
x: torch.Tensor,
|
||||
start_pos: int,
|
||||
freqs_cis: torch.Tensor,
|
||||
mask: Optional[torch.Tensor],
|
||||
mask: torch.Tensor | None,
|
||||
):
|
||||
bsz, seqlen, _ = x.shape
|
||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||
|
@ -204,7 +203,7 @@ class FeedForward(nn.Module):
|
|||
dim: int,
|
||||
hidden_dim: int,
|
||||
multiple_of: int,
|
||||
ffn_dim_multiplier: Optional[float],
|
||||
ffn_dim_multiplier: float | None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
|
@ -243,7 +242,7 @@ class TransformerBlock(nn.Module):
|
|||
x: torch.Tensor,
|
||||
start_pos: int,
|
||||
freqs_cis: torch.Tensor,
|
||||
mask: Optional[torch.Tensor],
|
||||
mask: torch.Tensor | None,
|
||||
):
|
||||
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
|
||||
out = h + self.feed_forward(self.ffn_norm(h))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue