chore: enable pyupgrade fixes (#1806)

# What does this PR do?

The goal of this PR is code base modernization.

Schema reflection code needed a minor adjustment to handle `types.UnionType`
(the result of PEP 604 `X | Y` unions) and `collections.abc.AsyncIterator`.
(Both are the spellings preferred by recent Python releases.)
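
To make that concrete, here is a minimal sketch (Python 3.10+, not the actual
`schema.py` code; the helper names are made up for illustration) of what
reflection logic has to account for: `int | None` evaluates to a
`types.UnionType` rather than a `typing.Union`, and `AsyncIterator` now comes
from `collections.abc` instead of `typing`, so both spellings need to be
recognized.

```python
# Minimal sketch only -- illustrates the two cases, not the real schema.py code.
import types
import typing
from collections.abc import AsyncIterator


def union_members(tp) -> tuple | None:
    """Return the member types if `tp` is a union, written either way."""
    origin = typing.get_origin(tp)
    if origin is typing.Union or origin is types.UnionType:
        return typing.get_args(tp)
    return None


def is_async_iterator(tp) -> bool:
    """True for both typing.AsyncIterator[...] and collections.abc.AsyncIterator[...]."""
    return typing.get_origin(tp) is AsyncIterator


print(union_members(int | None))                     # (<class 'int'>, <class 'NoneType'>)
print(union_members(typing.Optional[int]))           # (<class 'int'>, <class 'NoneType'>)
print(is_async_iterator(AsyncIterator[str]))         # True
print(is_async_iterator(typing.AsyncIterator[str]))  # True
```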

Note to reviewers: almost all changes here were generated automatically by
pyupgrade; some imports that became unused were also cleaned up. The only
changes worthy of note are under `docs/openapi_generator` and in
`llama_stack/strong_typing/schema.py`, where the reflection code was updated
to deal with the "newer" types.
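
For a quick sense of the mechanical pattern pyupgrade applies throughout the
diff below (illustrative only; the exact invocation and target-version flag
used for this run are not recorded here):

```python
# Before (typing aliases, Optional/Tuple spellings, redundant open() mode):
#
#     from typing import Any, Dict, List, Optional, Tuple
#
#     def load(path: Optional[str]) -> Tuple[Dict[str, Any], List[str]]:
#         with open(path, "r") as f:
#             ...
#
# After (builtin generics and PEP 604 unions, as in the hunks below):
from typing import Any


def load(path: str | None) -> tuple[dict[str, Any], list[str]]:
    with open(path) as f:  # pyupgrade also drops the redundant "r" mode
        ...
```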

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Author: Ihar Hrachyshka, 2025-05-01 17:23:50 -04:00 (committed by GitHub)
Commit: 9e6561a1ec (parent: ffe3d0b2cd)
319 changed files with 2843 additions and 3033 deletions

File: (path not shown)

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict
+from typing import Any
 from llama_stack.distribution.datatypes import Api
@@ -15,7 +15,7 @@ from .config import TorchtunePostTrainingConfig
 async def get_provider_impl(
     config: TorchtunePostTrainingConfig,
-    deps: Dict[Api, Any],
+    deps: dict[Api, Any],
 ):
     from .post_training import TorchtunePostTrainingImpl

File: (path not shown)

@@ -8,7 +8,7 @@ import json
 import os
 import shutil
 from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any
 import torch
 from safetensors.torch import save_file
@@ -34,7 +34,7 @@ class TorchtuneCheckpointer:
         model_id: str,
         training_algorithm: str,
         checkpoint_dir: str,
-        checkpoint_files: List[str],
+        checkpoint_files: list[str],
         output_dir: str,
         model_type: str,
     ):
@@ -54,11 +54,11 @@ class TorchtuneCheckpointer:
         # get ckpt paths
         self._checkpoint_path = Path.joinpath(self._checkpoint_dir, self._checkpoint_file)
-    def load_checkpoint(self) -> Dict[str, Any]:
+    def load_checkpoint(self) -> dict[str, Any]:
         """
         Load Meta checkpoint from file. Currently only loading from a single file is supported.
         """
-        state_dict: Dict[str, Any] = {}
+        state_dict: dict[str, Any] = {}
         model_state_dict = safe_torch_load(self._checkpoint_path)
         if self._model_type == ModelType.LLAMA3_VISION:
             from torchtune.models.llama3_2_vision._convert_weights import (
@@ -82,7 +82,7 @@ class TorchtuneCheckpointer:
     def save_checkpoint(
         self,
-        state_dict: Dict[str, Any],
+        state_dict: dict[str, Any],
         epoch: int,
         adapter_only: bool = False,
         checkpoint_format: str | None = None,
@@ -100,7 +100,7 @@ class TorchtuneCheckpointer:
     def _save_meta_format_checkpoint(
         self,
         model_file_path: Path,
-        state_dict: Dict[str, Any],
+        state_dict: dict[str, Any],
         adapter_only: bool = False,
     ) -> None:
         model_file_path.mkdir(parents=True, exist_ok=True)
@@ -168,7 +168,7 @@ class TorchtuneCheckpointer:
     def _save_hf_format_checkpoint(
        self,
         model_file_path: Path,
-        state_dict: Dict[str, Any],
+        state_dict: dict[str, Any],
     ) -> None:
         # the config.json file contains model params needed for state dict conversion
         config = json.loads(Path.joinpath(self._checkpoint_dir.parent, "config.json").read_text())
@@ -179,7 +179,7 @@ class TorchtuneCheckpointer:
         repo_id_path = Path.joinpath(self._checkpoint_dir.parent, REPO_ID_FNAME).with_suffix(".json")
         self.repo_id = None
         if repo_id_path.exists():
-            with open(repo_id_path, "r") as json_file:
+            with open(repo_id_path) as json_file:
                 data = json.load(json_file)
                 self.repo_id = data.get("repo_id")

File: (path not shown)

@@ -10,7 +10,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Callable, Dict
+from collections.abc import Callable
 import torch
 from pydantic import BaseModel
@@ -35,7 +35,7 @@ class ModelConfig(BaseModel):
     checkpoint_type: str
-MODEL_CONFIGS: Dict[str, ModelConfig] = {
+MODEL_CONFIGS: dict[str, ModelConfig] = {
     "Llama3.2-3B-Instruct": ModelConfig(
         model_definition=lora_llama3_2_3b,
         tokenizer_type=llama3_tokenizer,
@@ -48,7 +48,7 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
     ),
 }
-DATA_FORMATS: Dict[str, Transform] = {
+DATA_FORMATS: dict[str, Transform] = {
     "instruct": InputOutputToMessages,
     "dialog": ShareGPTToMessages,
 }

File: (path not shown)

@@ -4,17 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict, Literal, Optional
+from typing import Any, Literal
 from pydantic import BaseModel
 class TorchtunePostTrainingConfig(BaseModel):
-    torch_seed: Optional[int] = None
-    checkpoint_format: Optional[Literal["meta", "huggingface"]] = "meta"
+    torch_seed: int | None = None
+    checkpoint_format: Literal["meta", "huggingface"] | None = "meta"
     @classmethod
-    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
         return {
             "checkpoint_format": "meta",
         }

File: (path not shown)

@@ -11,7 +11,8 @@
 # LICENSE file in the root directory of this source tree.
 import json
-from typing import Any, Mapping
+from collections.abc import Mapping
+from typing import Any
 from llama_stack.providers.utils.common.data_schema_validator import ColumnName

File: (path not shown)

@@ -10,7 +10,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Any, Dict, List, Mapping
+from collections.abc import Mapping
+from typing import Any
 import numpy as np
 from torch.utils.data import Dataset
@@ -27,7 +28,7 @@ from llama_stack.providers.inline.post_training.torchtune.datasets.format_adapte
 class SFTDataset(Dataset):
     def __init__(
         self,
-        rows: List[Dict[str, Any]],
+        rows: list[dict[str, Any]],
         message_transform: Transform,
         model_transform: Transform,
         dataset_type: str,
@@ -40,11 +41,11 @@ class SFTDataset(Dataset):
     def __len__(self):
         return len(self._rows)
-    def __getitem__(self, index: int) -> Dict[str, Any]:
+    def __getitem__(self, index: int) -> dict[str, Any]:
         sample = self._rows[index]
         return self._prepare_sample(sample)
-    def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
+    def _prepare_sample(self, sample: Mapping[str, Any]) -> dict[str, Any]:
         if self._dataset_type == "instruct":
             sample = llama_stack_instruct_to_torchtune_instruct(sample)
         elif self._dataset_type == "dialog":

File: (path not shown)

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 from enum import Enum
-from typing import Any, Dict, Optional
+from typing import Any
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
@@ -64,7 +64,7 @@ class TorchtunePostTrainingImpl:
         )
     @staticmethod
-    def _resources_stats_to_artifact(resources_stats: Dict[str, Any]) -> JobArtifact:
+    def _resources_stats_to_artifact(resources_stats: dict[str, Any]) -> JobArtifact:
         return JobArtifact(
             type=TrainingArtifactType.RESOURCES_STATS.value,
             name=TrainingArtifactType.RESOURCES_STATS.value,
@@ -75,11 +75,11 @@ class TorchtunePostTrainingImpl:
         self,
         job_uuid: str,
         training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
+        hyperparam_search_config: dict[str, Any],
+        logger_config: dict[str, Any],
         model: str,
-        checkpoint_dir: Optional[str],
-        algorithm_config: Optional[AlgorithmConfig],
+        checkpoint_dir: str | None,
+        algorithm_config: AlgorithmConfig | None,
     ) -> PostTrainingJob:
         if isinstance(algorithm_config, LoraFinetuningConfig):
@@ -121,8 +121,8 @@ class TorchtunePostTrainingImpl:
         finetuned_model: str,
         algorithm_config: DPOAlignmentConfig,
         training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
+        hyperparam_search_config: dict[str, Any],
+        logger_config: dict[str, Any],
     ) -> PostTrainingJob: ...
     async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
@@ -144,7 +144,7 @@ class TorchtunePostTrainingImpl:
         return data[0] if data else None
     @webmethod(route="/post-training/job/status")
-    async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
+    async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
         job = self._scheduler.get_job(job_uuid)
         match job.status:
@@ -175,6 +175,6 @@ class TorchtunePostTrainingImpl:
         self._scheduler.cancel(job_uuid)
     @webmethod(route="/post-training/job/artifacts")
-    async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
+    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
         job = self._scheduler.get_job(job_uuid)
         return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))

File: (path not shown)

@@ -11,7 +11,7 @@ import time
 from datetime import datetime, timezone
 from functools import partial
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any
 import torch
 from torch import nn
@@ -80,10 +80,10 @@ class LoraFinetuningSingleDevice:
         config: TorchtunePostTrainingConfig,
         job_uuid: str,
         training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
+        hyperparam_search_config: dict[str, Any],
+        logger_config: dict[str, Any],
         model: str,
-        checkpoint_dir: Optional[str],
+        checkpoint_dir: str | None,
         algorithm_config: LoraFinetuningConfig | QATFinetuningConfig | None,
         datasetio_api: DatasetIO,
         datasets_api: Datasets,
@@ -156,7 +156,7 @@ class LoraFinetuningSingleDevice:
         self.datasets_api = datasets_api
     async def load_checkpoint(self):
-        def get_checkpoint_files(checkpoint_dir: str) -> List[str]:
+        def get_checkpoint_files(checkpoint_dir: str) -> list[str]:
             try:
                 # List all files in the given directory
                 files = os.listdir(checkpoint_dir)
@@ -250,8 +250,8 @@ class LoraFinetuningSingleDevice:
         self,
         enable_activation_checkpointing: bool,
         enable_activation_offloading: bool,
-        base_model_state_dict: Dict[str, Any],
-        lora_weights_state_dict: Optional[Dict[str, Any]] = None,
+        base_model_state_dict: dict[str, Any],
+        lora_weights_state_dict: dict[str, Any] | None = None,
     ) -> nn.Module:
         self._lora_rank = self.algorithm_config.rank
         self._lora_alpha = self.algorithm_config.alpha
@@ -335,7 +335,7 @@ class LoraFinetuningSingleDevice:
         tokenizer: Llama3Tokenizer,
         shuffle: bool,
         batch_size: int,
-    ) -> Tuple[DistributedSampler, DataLoader]:
+    ) -> tuple[DistributedSampler, DataLoader]:
         async def fetch_rows(dataset_id: str):
             return await self.datasetio_api.iterrows(
                 dataset_id=dataset_id,
@@ -430,7 +430,7 @@ class LoraFinetuningSingleDevice:
             checkpoint_format=self._checkpoint_format,
         )
-    async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
+    async def _loss_step(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
         # Shape [b, s], needed for the loss not the model
         labels = batch.pop("labels")
         # run model
@@ -452,7 +452,7 @@ class LoraFinetuningSingleDevice:
         return loss
-    async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
+    async def train(self) -> tuple[dict[str, Any], list[Checkpoint]]:
         """
         The core training loop.
         """
@@ -464,7 +464,7 @@ class LoraFinetuningSingleDevice:
         # training artifacts
         checkpoints = []
-        memory_stats: Dict[str, Any] = {}
+        memory_stats: dict[str, Any] = {}
         # self.epochs_run should be non-zero when we're resuming from a checkpoint
         for curr_epoch in range(self.epochs_run, self.total_epochs):
@@ -565,7 +565,7 @@ class LoraFinetuningSingleDevice:
         return (memory_stats, checkpoints)
-    async def validation(self) -> Tuple[float, float]:
+    async def validation(self) -> tuple[float, float]:
         total_loss = 0.0
         total_tokens = 0
         log.info("Starting validation...")