chore: enable pyupgrade fixes (#1806)
# What does this PR do?

The goal of this PR is codebase modernization. Schema reflection code needed a minor adjustment to handle union types (`types.UnionType`) and `collections.abc.AsyncIterator`. (Both are preferred on recent Python releases.)

Note to reviewers: almost all changes here are automatically generated by pyupgrade. Some additional unused imports were cleaned up. The only change worth noting can be found under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py`, where the reflection code was updated to deal with the "newer" types.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
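To make the "minor adjustment" concrete, here is a minimal sketch, assuming Python 3.10+, of the kind of check the reflection code needs. The function names are hypothetical and this is not the actual code under `docs/openapi_generator` or `llama_stack/strong_typing/schema.py`; it only illustrates that the PEP 604 `X | Y` spelling and `collections.abc.AsyncIterator` have to be recognized alongside the older `typing` forms.

```python
# Illustrative sketch only; not the actual reflection code in
# llama_stack/strong_typing/schema.py or docs/openapi_generator.
import collections.abc
import types
import typing
from typing import Any, Union, get_args, get_origin


def is_union(hint: Any) -> bool:
    """True for Optional[X]/Union[X, Y] as well as the PEP 604 form X | Y."""
    return get_origin(hint) in (Union, types.UnionType)


def is_async_iterator(hint: Any) -> bool:
    """True for typing.AsyncIterator[T] and collections.abc.AsyncIterator[T]."""
    # get_origin() normalizes the typing alias to the collections.abc class.
    return get_origin(hint) is collections.abc.AsyncIterator


def union_members(hint: Any) -> tuple[Any, ...]:
    """Member types of a union, or the hint itself if it is not a union."""
    return get_args(hint) if is_union(hint) else (hint,)


# Both the old and the new spellings are handled identically.
assert is_union(typing.Optional[int]) and is_union(int | None)
assert union_members(int | str) == (int, str)
assert is_async_iterator(typing.AsyncIterator[str])
assert is_async_iterator(collections.abc.AsyncIterator[str])
```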
parent: ffe3d0b2cd
commit: 9e6561a1ec

319 changed files with 2843 additions and 3033 deletions
@@ -11,7 +11,7 @@ import time
 from datetime import datetime, timezone
 from functools import partial
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any
 
 import torch
 from torch import nn
@@ -80,10 +80,10 @@ class LoraFinetuningSingleDevice:
         config: TorchtunePostTrainingConfig,
         job_uuid: str,
         training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
+        hyperparam_search_config: dict[str, Any],
+        logger_config: dict[str, Any],
         model: str,
-        checkpoint_dir: Optional[str],
+        checkpoint_dir: str | None,
         algorithm_config: LoraFinetuningConfig | QATFinetuningConfig | None,
         datasetio_api: DatasetIO,
         datasets_api: Datasets,
@@ -156,7 +156,7 @@ class LoraFinetuningSingleDevice:
         self.datasets_api = datasets_api
 
     async def load_checkpoint(self):
-        def get_checkpoint_files(checkpoint_dir: str) -> List[str]:
+        def get_checkpoint_files(checkpoint_dir: str) -> list[str]:
             try:
                 # List all files in the given directory
                 files = os.listdir(checkpoint_dir)
@@ -250,8 +250,8 @@ class LoraFinetuningSingleDevice:
         self,
         enable_activation_checkpointing: bool,
         enable_activation_offloading: bool,
-        base_model_state_dict: Dict[str, Any],
-        lora_weights_state_dict: Optional[Dict[str, Any]] = None,
+        base_model_state_dict: dict[str, Any],
+        lora_weights_state_dict: dict[str, Any] | None = None,
     ) -> nn.Module:
         self._lora_rank = self.algorithm_config.rank
         self._lora_alpha = self.algorithm_config.alpha
@@ -335,7 +335,7 @@ class LoraFinetuningSingleDevice:
         tokenizer: Llama3Tokenizer,
         shuffle: bool,
         batch_size: int,
-    ) -> Tuple[DistributedSampler, DataLoader]:
+    ) -> tuple[DistributedSampler, DataLoader]:
         async def fetch_rows(dataset_id: str):
             return await self.datasetio_api.iterrows(
                 dataset_id=dataset_id,
@@ -430,7 +430,7 @@ class LoraFinetuningSingleDevice:
             checkpoint_format=self._checkpoint_format,
         )
 
-    async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
+    async def _loss_step(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
         # Shape [b, s], needed for the loss not the model
         labels = batch.pop("labels")
         # run model
@@ -452,7 +452,7 @@ class LoraFinetuningSingleDevice:
 
         return loss
 
-    async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
+    async def train(self) -> tuple[dict[str, Any], list[Checkpoint]]:
         """
         The core training loop.
         """
@@ -464,7 +464,7 @@ class LoraFinetuningSingleDevice:
 
         # training artifacts
        checkpoints = []
-        memory_stats: Dict[str, Any] = {}
+        memory_stats: dict[str, Any] = {}
 
         # self.epochs_run should be non-zero when we're resuming from a checkpoint
         for curr_epoch in range(self.epochs_run, self.total_epochs):
@@ -565,7 +565,7 @@ class LoraFinetuningSingleDevice:
 
         return (memory_stats, checkpoints)
 
-    async def validation(self) -> Tuple[float, float]:
+    async def validation(self) -> tuple[float, float]:
         total_loss = 0.0
         total_tokens = 0
         log.info("Starting validation...")
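As a footnote to the annotation rewrites shown in the diff above, here is a small illustrative check (assuming Python 3.10+, not part of this change) that the old and new spellings expose the same type information at runtime, apart from the union origin, which is what the reflection adjustment accounts for:

```python
# Illustrative check, not part of this PR: on Python 3.10+ the rewritten
# annotations carry the same information as the originals they replace.
import types
from typing import Any, Dict, Optional, Union, get_args, get_origin

# Builtin generics and their typing aliases share origin and args.
assert get_origin(Dict[str, Any]) is get_origin(dict[str, Any]) is dict
assert get_args(Dict[str, Any]) == get_args(dict[str, Any]) == (str, Any)

# The two union spellings have identical args but different origins,
# which is why schema reflection has to recognize types.UnionType as well.
assert get_args(Optional[str]) == get_args(str | None) == (str, type(None))
assert get_origin(Optional[str]) is Union
assert get_origin(str | None) is types.UnionType
```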