mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
chore: enable pyupgrade fixes (#1806)
# What does this PR do? The goal of this PR is code base modernization. Schema reflection code needed a minor adjustment to handle UnionTypes and collections.abc.AsyncIterator. (Both are preferred for latest Python releases.) Note to reviewers: almost all changes here are automatically generated by pyupgrade. Some additional unused imports were cleaned up. The only change worth of note can be found under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py` where reflection code was updated to deal with "newer" types. Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
ffe3d0b2cd
commit
9e6561a1ec
319 changed files with 2843 additions and 3033 deletions
|
@ -10,8 +10,8 @@ import json
|
|||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Callable, Generator
|
||||
from pathlib import Path
|
||||
from typing import Callable, Generator, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
@ -38,8 +38,8 @@ class Llama4:
|
|||
ckpt_dir: str,
|
||||
max_seq_len: int,
|
||||
max_batch_size: int,
|
||||
world_size: Optional[int] = None,
|
||||
quantization_mode: Optional[QuantizationMode] = None,
|
||||
world_size: int | None = None,
|
||||
quantization_mode: QuantizationMode | None = None,
|
||||
seed: int = 1,
|
||||
):
|
||||
if not torch.distributed.is_initialized():
|
||||
|
@ -63,7 +63,7 @@ class Llama4:
|
|||
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
|
||||
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
|
||||
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
|
||||
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
||||
with open(Path(ckpt_dir) / "params.json") as f:
|
||||
params = json.loads(f.read())
|
||||
|
||||
model_args: ModelArgs = ModelArgs(
|
||||
|
@ -117,15 +117,15 @@ class Llama4:
|
|||
@torch.inference_mode()
|
||||
def generate(
|
||||
self,
|
||||
llm_inputs: List[LLMInput],
|
||||
llm_inputs: list[LLMInput],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
max_gen_len: int | None = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
print_model_input: bool = False,
|
||||
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
) -> Generator[List[GenerationResult], None, None]:
|
||||
logits_processor: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
|
||||
) -> Generator[list[GenerationResult], None, None]:
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.args.max_seq_len:
|
||||
max_gen_len = self.model.args.max_seq_len - 1
|
||||
|
||||
|
@ -245,13 +245,13 @@ class Llama4:
|
|||
|
||||
def completion(
|
||||
self,
|
||||
contents: List[RawContent],
|
||||
contents: list[RawContent],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
max_gen_len: int | None = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
) -> Generator[List[GenerationResult], None, None]:
|
||||
) -> Generator[list[GenerationResult], None, None]:
|
||||
llm_inputs = [self.formatter.encode_content(c) for c in contents]
|
||||
for result in self.generate(
|
||||
llm_inputs=llm_inputs,
|
||||
|
@ -267,13 +267,13 @@ class Llama4:
|
|||
|
||||
def chat_completion(
|
||||
self,
|
||||
messages_batch: List[List[RawMessage]],
|
||||
messages_batch: list[list[RawMessage]],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
max_gen_len: int | None = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
) -> Generator[List[GenerationResult], None, None]:
|
||||
) -> Generator[list[GenerationResult], None, None]:
|
||||
llm_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
|
||||
for result in self.generate(
|
||||
llm_inputs=llm_inputs,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue