chore: enable pyupgrade fixes (#1806)

# What does this PR do?

The goal of this PR is code base modernization.

Schema reflection code needed a minor adjustment to handle UnionTypes
and collections.abc.AsyncIterator. (Both are preferred for latest Python
releases.)

Note to reviewers: almost all changes here are automatically generated
by pyupgrade. Some additional unused imports were cleaned up. The only
change worth of note can be found under `docs/openapi_generator` and
`llama_stack/strong_typing/schema.py` where reflection code was updated
to deal with "newer" types.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-05-01 17:23:50 -04:00 committed by GitHub
parent ffe3d0b2cd
commit 9e6561a1ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
319 changed files with 2843 additions and 3033 deletions

View file

@ -8,7 +8,7 @@ import json
import os
import shutil
from pathlib import Path
from typing import Any, Dict, List
from typing import Any
import torch
from safetensors.torch import save_file
@ -34,7 +34,7 @@ class TorchtuneCheckpointer:
model_id: str,
training_algorithm: str,
checkpoint_dir: str,
checkpoint_files: List[str],
checkpoint_files: list[str],
output_dir: str,
model_type: str,
):
@ -54,11 +54,11 @@ class TorchtuneCheckpointer:
# get ckpt paths
self._checkpoint_path = Path.joinpath(self._checkpoint_dir, self._checkpoint_file)
def load_checkpoint(self) -> Dict[str, Any]:
def load_checkpoint(self) -> dict[str, Any]:
"""
Load Meta checkpoint from file. Currently only loading from a single file is supported.
"""
state_dict: Dict[str, Any] = {}
state_dict: dict[str, Any] = {}
model_state_dict = safe_torch_load(self._checkpoint_path)
if self._model_type == ModelType.LLAMA3_VISION:
from torchtune.models.llama3_2_vision._convert_weights import (
@ -82,7 +82,7 @@ class TorchtuneCheckpointer:
def save_checkpoint(
self,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
epoch: int,
adapter_only: bool = False,
checkpoint_format: str | None = None,
@ -100,7 +100,7 @@ class TorchtuneCheckpointer:
def _save_meta_format_checkpoint(
self,
model_file_path: Path,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
adapter_only: bool = False,
) -> None:
model_file_path.mkdir(parents=True, exist_ok=True)
@ -168,7 +168,7 @@ class TorchtuneCheckpointer:
def _save_hf_format_checkpoint(
self,
model_file_path: Path,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
) -> None:
# the config.json file contains model params needed for state dict conversion
config = json.loads(Path.joinpath(self._checkpoint_dir.parent, "config.json").read_text())
@ -179,7 +179,7 @@ class TorchtuneCheckpointer:
repo_id_path = Path.joinpath(self._checkpoint_dir.parent, REPO_ID_FNAME).with_suffix(".json")
self.repo_id = None
if repo_id_path.exists():
with open(repo_id_path, "r") as json_file:
with open(repo_id_path) as json_file:
data = json.load(json_file)
self.repo_id = data.get("repo_id")

View file

@ -10,7 +10,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Callable, Dict
from collections.abc import Callable
import torch
from pydantic import BaseModel
@ -35,7 +35,7 @@ class ModelConfig(BaseModel):
checkpoint_type: str
MODEL_CONFIGS: Dict[str, ModelConfig] = {
MODEL_CONFIGS: dict[str, ModelConfig] = {
"Llama3.2-3B-Instruct": ModelConfig(
model_definition=lora_llama3_2_3b,
tokenizer_type=llama3_tokenizer,
@ -48,7 +48,7 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
),
}
DATA_FORMATS: Dict[str, Transform] = {
DATA_FORMATS: dict[str, Transform] = {
"instruct": InputOutputToMessages,
"dialog": ShareGPTToMessages,
}