Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-19 03:10:03 +00:00)
chore: enable pyupgrade fixes (#1806)
# What does this PR do?

The goal of this PR is code base modernization. Schema reflection code needed a minor adjustment to handle UnionTypes and `collections.abc.AsyncIterator` (both are preferred in recent Python releases).

Note to reviewers: almost all changes here were generated automatically by pyupgrade; some additional unused imports were cleaned up as well. The only change worth noting is under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py`, where the reflection code was updated to deal with the "newer" types.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
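For reviewers unfamiliar with the reflection detail mentioned above: annotations written with the PEP 604 `X | Y` syntax produce `types.UnionType` at runtime rather than `typing.Union`, so reflection code that dispatches only on `typing.Union` misses them. The sketch below is not the actual `schema.py` change, just a minimal illustration (assuming Python 3.10+) of recognizing both spellings plus `collections.abc.AsyncIterator`:

```python
import collections.abc
import types
import typing


def unwrap_union(hint: object) -> tuple | None:
    """Return the member types of a union annotation, or None if it is not a union.

    Handles both typing.Optional[X] / typing.Union[X, Y] and the PEP 604
    `X | Y` syntax, which produces types.UnionType instead of typing.Union.
    """
    origin = typing.get_origin(hint)
    if origin is typing.Union or origin is types.UnionType:
        return typing.get_args(hint)
    return None


def is_async_iterator(hint: object) -> bool:
    """True for annotations such as collections.abc.AsyncIterator[str]."""
    return typing.get_origin(hint) is collections.abc.AsyncIterator


# Both spellings resolve to the same member types:
assert unwrap_union(typing.Optional[int]) == (int, type(None))
assert unwrap_union(int | None) == (int, type(None))
assert is_async_iterator(collections.abc.AsyncIterator[str])
```

Since pyupgrade rewrites `Optional[X]` into `X | None`, a schema generator has to understand the newer runtime form as well as the old one.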
parent ffe3d0b2cd · commit 9e6561a1ec
319 changed files with 2843 additions and 3033 deletions
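The hunks below repeat the same handful of mechanical rewrites: `Dict`/`List`/`Tuple` become the builtin `dict`/`list`/`tuple` generics, `Optional[X]` becomes `X | None`, `Callable` and `Mapping` move from `typing` to `collections.abc`, and the redundant `"r"` mode argument is dropped from `open()`. A hypothetical before/after on a toy function (not code from the repository) summarizes the pattern:

```python
# Before (old typing-module style):
#     from typing import Dict, List, Optional, Tuple
#
#     def load(path: str, keys: Optional[List[str]] = None) -> Tuple[Dict[str, str], int]:
#         with open(path, "r") as f:
#             ...

# After pyupgrade (builtin generics, PEP 604 unions, implicit read mode):
def load(path: str, keys: list[str] | None = None) -> tuple[dict[str, str], int]:
    with open(path) as f:
        data = {line.strip(): line.strip() for line in f if line.strip()}
    if keys is not None:
        data = {k: v for k, v in data.items() if k in keys}
    return data, len(data)
```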
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict
+from typing import Any
 
 from llama_stack.distribution.datatypes import Api
 
@@ -15,7 +15,7 @@ from .config import TorchtunePostTrainingConfig
 
 async def get_provider_impl(
     config: TorchtunePostTrainingConfig,
-    deps: Dict[Api, Any],
+    deps: dict[Api, Any],
 ):
     from .post_training import TorchtunePostTrainingImpl
 

@@ -8,7 +8,7 @@ import json
 import os
 import shutil
 from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any
 
 import torch
 from safetensors.torch import save_file
@@ -34,7 +34,7 @@ class TorchtuneCheckpointer:
         model_id: str,
         training_algorithm: str,
         checkpoint_dir: str,
-        checkpoint_files: List[str],
+        checkpoint_files: list[str],
         output_dir: str,
         model_type: str,
     ):
@@ -54,11 +54,11 @@ class TorchtuneCheckpointer:
         # get ckpt paths
         self._checkpoint_path = Path.joinpath(self._checkpoint_dir, self._checkpoint_file)
 
-    def load_checkpoint(self) -> Dict[str, Any]:
+    def load_checkpoint(self) -> dict[str, Any]:
         """
        Load Meta checkpoint from file. Currently only loading from a single file is supported.
         """
-        state_dict: Dict[str, Any] = {}
+        state_dict: dict[str, Any] = {}
         model_state_dict = safe_torch_load(self._checkpoint_path)
         if self._model_type == ModelType.LLAMA3_VISION:
             from torchtune.models.llama3_2_vision._convert_weights import (
@@ -82,7 +82,7 @@ class TorchtuneCheckpointer:
 
     def save_checkpoint(
         self,
-        state_dict: Dict[str, Any],
+        state_dict: dict[str, Any],
         epoch: int,
         adapter_only: bool = False,
         checkpoint_format: str | None = None,
@@ -100,7 +100,7 @@ class TorchtuneCheckpointer:
     def _save_meta_format_checkpoint(
         self,
         model_file_path: Path,
-        state_dict: Dict[str, Any],
+        state_dict: dict[str, Any],
         adapter_only: bool = False,
     ) -> None:
         model_file_path.mkdir(parents=True, exist_ok=True)
@@ -168,7 +168,7 @@ class TorchtuneCheckpointer:
     def _save_hf_format_checkpoint(
         self,
         model_file_path: Path,
-        state_dict: Dict[str, Any],
+        state_dict: dict[str, Any],
     ) -> None:
         # the config.json file contains model params needed for state dict conversion
         config = json.loads(Path.joinpath(self._checkpoint_dir.parent, "config.json").read_text())
@@ -179,7 +179,7 @@ class TorchtuneCheckpointer:
         repo_id_path = Path.joinpath(self._checkpoint_dir.parent, REPO_ID_FNAME).with_suffix(".json")
         self.repo_id = None
         if repo_id_path.exists():
-            with open(repo_id_path, "r") as json_file:
+            with open(repo_id_path) as json_file:
                 data = json.load(json_file)
                 self.repo_id = data.get("repo_id")
 

@@ -10,7 +10,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Callable, Dict
+from collections.abc import Callable
 
 import torch
 from pydantic import BaseModel
@@ -35,7 +35,7 @@ class ModelConfig(BaseModel):
     checkpoint_type: str
 
 
-MODEL_CONFIGS: Dict[str, ModelConfig] = {
+MODEL_CONFIGS: dict[str, ModelConfig] = {
     "Llama3.2-3B-Instruct": ModelConfig(
         model_definition=lora_llama3_2_3b,
         tokenizer_type=llama3_tokenizer,
@@ -48,7 +48,7 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
     ),
 }
 
-DATA_FORMATS: Dict[str, Transform] = {
+DATA_FORMATS: dict[str, Transform] = {
     "instruct": InputOutputToMessages,
     "dialog": ShareGPTToMessages,
 }

@@ -4,17 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict, Literal, Optional
+from typing import Any, Literal
 
 from pydantic import BaseModel
 
 
 class TorchtunePostTrainingConfig(BaseModel):
-    torch_seed: Optional[int] = None
-    checkpoint_format: Optional[Literal["meta", "huggingface"]] = "meta"
+    torch_seed: int | None = None
+    checkpoint_format: Literal["meta", "huggingface"] | None = "meta"
 
     @classmethod
-    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
         return {
             "checkpoint_format": "meta",
         }

@@ -11,7 +11,8 @@
 # LICENSE file in the root directory of this source tree.
 
 import json
-from typing import Any, Mapping
+from collections.abc import Mapping
+from typing import Any
 
 from llama_stack.providers.utils.common.data_schema_validator import ColumnName
 

@@ -10,7 +10,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Dict, List, Mapping
+from collections.abc import Mapping
+from typing import Any
 
 import numpy as np
 from torch.utils.data import Dataset
@@ -27,7 +28,7 @@ from llama_stack.providers.inline.post_training.torchtune.datasets.format_adapte
 class SFTDataset(Dataset):
     def __init__(
         self,
-        rows: List[Dict[str, Any]],
+        rows: list[dict[str, Any]],
         message_transform: Transform,
         model_transform: Transform,
         dataset_type: str,
@@ -40,11 +41,11 @@ class SFTDataset(Dataset):
     def __len__(self):
         return len(self._rows)
 
-    def __getitem__(self, index: int) -> Dict[str, Any]:
+    def __getitem__(self, index: int) -> dict[str, Any]:
         sample = self._rows[index]
         return self._prepare_sample(sample)
 
-    def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
+    def _prepare_sample(self, sample: Mapping[str, Any]) -> dict[str, Any]:
         if self._dataset_type == "instruct":
             sample = llama_stack_instruct_to_torchtune_instruct(sample)
         elif self._dataset_type == "dialog":

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 from enum import Enum
-from typing import Any, Dict, Optional
+from typing import Any
 
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
@@ -64,7 +64,7 @@ class TorchtunePostTrainingImpl:
         )
 
     @staticmethod
-    def _resources_stats_to_artifact(resources_stats: Dict[str, Any]) -> JobArtifact:
+    def _resources_stats_to_artifact(resources_stats: dict[str, Any]) -> JobArtifact:
         return JobArtifact(
             type=TrainingArtifactType.RESOURCES_STATS.value,
             name=TrainingArtifactType.RESOURCES_STATS.value,
@@ -75,11 +75,11 @@ class TorchtunePostTrainingImpl:
         self,
         job_uuid: str,
         training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
+        hyperparam_search_config: dict[str, Any],
+        logger_config: dict[str, Any],
         model: str,
-        checkpoint_dir: Optional[str],
-        algorithm_config: Optional[AlgorithmConfig],
+        checkpoint_dir: str | None,
+        algorithm_config: AlgorithmConfig | None,
     ) -> PostTrainingJob:
         if isinstance(algorithm_config, LoraFinetuningConfig):
@@ -121,8 +121,8 @@ class TorchtunePostTrainingImpl:
         finetuned_model: str,
         algorithm_config: DPOAlignmentConfig,
         training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
+        hyperparam_search_config: dict[str, Any],
+        logger_config: dict[str, Any],
     ) -> PostTrainingJob: ...
 
     async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
@@ -144,7 +144,7 @@ class TorchtunePostTrainingImpl:
         return data[0] if data else None
 
     @webmethod(route="/post-training/job/status")
-    async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]:
+    async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
         job = self._scheduler.get_job(job_uuid)
 
         match job.status:
@@ -175,6 +175,6 @@ class TorchtunePostTrainingImpl:
         self._scheduler.cancel(job_uuid)
 
     @webmethod(route="/post-training/job/artifacts")
-    async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]:
+    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
         job = self._scheduler.get_job(job_uuid)
         return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))

@@ -11,7 +11,7 @@ import time
 from datetime import datetime, timezone
 from functools import partial
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any
 
 import torch
 from torch import nn
@@ -80,10 +80,10 @@ class LoraFinetuningSingleDevice:
         config: TorchtunePostTrainingConfig,
         job_uuid: str,
         training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
+        hyperparam_search_config: dict[str, Any],
+        logger_config: dict[str, Any],
         model: str,
-        checkpoint_dir: Optional[str],
+        checkpoint_dir: str | None,
         algorithm_config: LoraFinetuningConfig | QATFinetuningConfig | None,
         datasetio_api: DatasetIO,
         datasets_api: Datasets,
@@ -156,7 +156,7 @@ class LoraFinetuningSingleDevice:
         self.datasets_api = datasets_api
 
     async def load_checkpoint(self):
-        def get_checkpoint_files(checkpoint_dir: str) -> List[str]:
+        def get_checkpoint_files(checkpoint_dir: str) -> list[str]:
            try:
                 # List all files in the given directory
                 files = os.listdir(checkpoint_dir)
@@ -250,8 +250,8 @@ class LoraFinetuningSingleDevice:
         self,
         enable_activation_checkpointing: bool,
         enable_activation_offloading: bool,
-        base_model_state_dict: Dict[str, Any],
-        lora_weights_state_dict: Optional[Dict[str, Any]] = None,
+        base_model_state_dict: dict[str, Any],
+        lora_weights_state_dict: dict[str, Any] | None = None,
     ) -> nn.Module:
         self._lora_rank = self.algorithm_config.rank
         self._lora_alpha = self.algorithm_config.alpha
@@ -335,7 +335,7 @@ class LoraFinetuningSingleDevice:
         tokenizer: Llama3Tokenizer,
         shuffle: bool,
         batch_size: int,
-    ) -> Tuple[DistributedSampler, DataLoader]:
+    ) -> tuple[DistributedSampler, DataLoader]:
         async def fetch_rows(dataset_id: str):
             return await self.datasetio_api.iterrows(
                 dataset_id=dataset_id,
@@ -430,7 +430,7 @@ class LoraFinetuningSingleDevice:
             checkpoint_format=self._checkpoint_format,
         )
 
-    async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
+    async def _loss_step(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
         # Shape [b, s], needed for the loss not the model
         labels = batch.pop("labels")
         # run model
@@ -452,7 +452,7 @@ class LoraFinetuningSingleDevice:
 
         return loss
 
-    async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
+    async def train(self) -> tuple[dict[str, Any], list[Checkpoint]]:
         """
         The core training loop.
         """
@@ -464,7 +464,7 @@ class LoraFinetuningSingleDevice:
 
         # training artifacts
         checkpoints = []
-        memory_stats: Dict[str, Any] = {}
+        memory_stats: dict[str, Any] = {}
 
         # self.epochs_run should be non-zero when we're resuming from a checkpoint
         for curr_epoch in range(self.epochs_run, self.total_epochs):
@@ -565,7 +565,7 @@ class LoraFinetuningSingleDevice:
 
         return (memory_stats, checkpoints)
 
-    async def validation(self) -> Tuple[float, float]:
+    async def validation(self) -> tuple[float, float]:
         total_loss = 0.0
         total_tokens = 0
         log.info("Starting validation...")