chore: enable pyupgrade fixes (#1806)
# What does this PR do?

The goal of this PR is code base modernization. Schema reflection code needed a minor adjustment to handle UnionTypes and collections.abc.AsyncIterator. (Both are preferred in the latest Python releases.)

Note to reviewers: almost all changes here are automatically generated by pyupgrade. Some additional unused imports were cleaned up. The only change worth noting can be found under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py`, where reflection code was updated to deal with "newer" types.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
parent ffe3d0b2cd
commit 9e6561a1ec
319 changed files with 2843 additions and 3033 deletions
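The one hand-written change the message calls out is worth a closer look. With PEP 604, `str | None` evaluates to a `types.UnionType` at runtime, which is a different object from `typing.Union[str, None]`, so reflection code that only matches `typing.Union` silently misses the new spelling. A minimal sketch of the kind of special-casing needed (a hypothetical helper, not the actual `llama_stack/strong_typing/schema.py` code):

```python
import types
import typing
from collections.abc import AsyncIterator

def union_members(tp: object) -> tuple | None:
    """Return the union's members whether it was written Union[X, Y] or X | Y."""
    origin = typing.get_origin(tp)
    if origin is typing.Union or origin is types.UnionType:
        return typing.get_args(tp)
    return None

# Both spellings now reflect identically:
assert union_members(typing.Union[int, None]) == (int, type(None))
assert union_members(int | None) == (int, type(None))
# collections.abc generics expose their origin for reflection too:
assert typing.get_origin(AsyncIterator[str]) is AsyncIterator
```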
```diff
@@ -8,7 +8,7 @@ import json
 import os
 import shutil
 from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any

 import torch
 from safetensors.torch import save_file
```
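The import change above is what the bulk of this 319-file diff looks like: since Python 3.9 (PEP 585), the built-in containers are directly subscriptable, so the `typing.Dict` and `typing.List` aliases are redundant. A small illustration with a hypothetical function (not taken from the diff):

```python
def count_tokens(lines: list[str]) -> dict[str, int]:
    # Identical meaning to typing.List[str] / typing.Dict[str, int],
    # but needs no typing import on Python 3.9+.
    counts: dict[str, int] = {}
    for line in lines:
        for token in line.split():
            counts[token] = counts.get(token, 0) + 1
    return counts
```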
```diff
@@ -34,7 +34,7 @@ class TorchtuneCheckpointer:
         model_id: str,
         training_algorithm: str,
         checkpoint_dir: str,
-        checkpoint_files: List[str],
+        checkpoint_files: list[str],
         output_dir: str,
         model_type: str,
     ):
```
```diff
@@ -54,11 +54,11 @@ class TorchtuneCheckpointer:
         # get ckpt paths
         self._checkpoint_path = Path.joinpath(self._checkpoint_dir, self._checkpoint_file)

-    def load_checkpoint(self) -> Dict[str, Any]:
+    def load_checkpoint(self) -> dict[str, Any]:
         """
         Load Meta checkpoint from file. Currently only loading from a single file is supported.
         """
-        state_dict: Dict[str, Any] = {}
+        state_dict: dict[str, Any] = {}
         model_state_dict = safe_torch_load(self._checkpoint_path)
         if self._model_type == ModelType.LLAMA3_VISION:
             from torchtune.models.llama3_2_vision._convert_weights import (
```
```diff
@@ -82,7 +82,7 @@ class TorchtuneCheckpointer:

     def save_checkpoint(
         self,
-        state_dict: Dict[str, Any],
+        state_dict: dict[str, Any],
         epoch: int,
         adapter_only: bool = False,
         checkpoint_format: str | None = None,
```
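Note the untouched `checkpoint_format: str | None = None` parameter: it already uses PEP 604 union syntax (Python 3.10+), the form pyupgrade substitutes for `Optional[str]` elsewhere. A hypothetical before/after:

```python
# Before: from typing import Optional
#         def resolve_format(checkpoint_format: Optional[str] = None) -> str: ...

# After (PEP 604, no typing import needed):
def resolve_format(checkpoint_format: str | None = None) -> str:
    # Assumed fallback behavior, purely for illustration.
    return checkpoint_format or "meta"
```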
```diff
@@ -100,7 +100,7 @@ class TorchtuneCheckpointer:
     def _save_meta_format_checkpoint(
         self,
         model_file_path: Path,
-        state_dict: Dict[str, Any],
+        state_dict: dict[str, Any],
         adapter_only: bool = False,
     ) -> None:
         model_file_path.mkdir(parents=True, exist_ok=True)
```
```diff
@@ -168,7 +168,7 @@ class TorchtuneCheckpointer:
     def _save_hf_format_checkpoint(
         self,
         model_file_path: Path,
-        state_dict: Dict[str, Any],
+        state_dict: dict[str, Any],
     ) -> None:
         # the config.json file contains model params needed for state dict conversion
         config = json.loads(Path.joinpath(self._checkpoint_dir.parent, "config.json").read_text())
```
```diff
@@ -179,7 +179,7 @@ class TorchtuneCheckpointer:
         repo_id_path = Path.joinpath(self._checkpoint_dir.parent, REPO_ID_FNAME).with_suffix(".json")
         self.repo_id = None
         if repo_id_path.exists():
-            with open(repo_id_path, "r") as json_file:
+            with open(repo_id_path) as json_file:
                 data = json.load(json_file)
                 self.repo_id = data.get("repo_id")
```
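The `open()` change is another stock pyupgrade fix: `"r"` is the default mode, so passing it explicitly is redundant. Condensed into a standalone sketch (hypothetical function, same idea as the hunk above):

```python
import json
from pathlib import Path

def read_repo_id(repo_id_path: Path) -> str | None:
    # open() defaults to mode="r" (text mode); the explicit "r" adds nothing.
    with open(repo_id_path) as json_file:
        return json.load(json_file).get("repo_id")
```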
```diff
@@ -10,7 +10,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Callable, Dict
+from collections.abc import Callable

 import torch
 from pydantic import BaseModel
```
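This hunk, from a second file, shows the other half of PEP 585: the ABCs in `collections.abc` became subscriptable in Python 3.9, and `typing.Callable` is now a deprecated alias for `collections.abc.Callable`. A minimal sketch:

```python
from collections.abc import Callable

def apply_twice(fn: Callable[[int], int], value: int) -> int:
    # collections.abc.Callable subscripts exactly like typing.Callable did.
    return fn(fn(value))

assert apply_twice(lambda x: x + 1, 0) == 2
```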
```diff
@@ -35,7 +35,7 @@ class ModelConfig(BaseModel):
     checkpoint_type: str


-MODEL_CONFIGS: Dict[str, ModelConfig] = {
+MODEL_CONFIGS: dict[str, ModelConfig] = {
     "Llama3.2-3B-Instruct": ModelConfig(
         model_definition=lora_llama3_2_3b,
         tokenizer_type=llama3_tokenizer,
```
```diff
@@ -48,7 +48,7 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
     ),
 }

-DATA_FORMATS: Dict[str, Transform] = {
+DATA_FORMATS: dict[str, Transform] = {
     "instruct": InputOutputToMessages,
     "dialog": ShareGPTToMessages,
 }
```
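`MODEL_CONFIGS` and `DATA_FORMATS` are plain module-level dicts, so callers can look up configurations by name and fail loudly on an unknown key. A sketch of typical consumption (hypothetical helper, not part of this diff):

```python
def get_model_config(model_id: str) -> ModelConfig:
    try:
        return MODEL_CONFIGS[model_id]
    except KeyError:
        known = ", ".join(sorted(MODEL_CONFIGS))
        raise ValueError(f"Unknown model {model_id!r}; expected one of: {known}") from None
```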