chore: enable pyupgrade fixes (#1806)

# What does this PR do?

The goal of this PR is code base modernization.

Schema reflection code needed a minor adjustment to handle UnionTypes
and collections.abc.AsyncIterator. (Both are preferred for latest Python
releases.)

Note to reviewers: almost all changes here are automatically generated
by pyupgrade. Some additional unused imports were cleaned up. The only
change worth of note can be found under `docs/openapi_generator` and
`llama_stack/strong_typing/schema.py` where reflection code was updated
to deal with "newer" types.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-05-01 17:23:50 -04:00 committed by GitHub
parent ffe3d0b2cd
commit 9e6561a1ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
319 changed files with 2843 additions and 3033 deletions

View file

@ -13,7 +13,6 @@ from dataclasses import dataclass
from datetime import datetime, timezone
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional
import httpx
from pydantic import BaseModel, ConfigDict
@ -102,7 +101,7 @@ class DownloadTask:
output_file: str
total_size: int = 0
downloaded_size: int = 0
task_id: Optional[int] = None
task_id: int | None = None
retries: int = 0
max_retries: int = 3
@ -262,7 +261,7 @@ class ParallelDownloader:
self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]")
raise DownloadError(f"Download failed for {task.output_file}: {str(e)}") from e
def has_disk_space(self, tasks: List[DownloadTask]) -> bool:
def has_disk_space(self, tasks: list[DownloadTask]) -> bool:
try:
total_remaining_size = sum(task.total_size - task.downloaded_size for task in tasks)
dir_path = os.path.dirname(os.path.abspath(tasks[0].output_file))
@ -282,7 +281,7 @@ class ParallelDownloader:
except Exception as e:
raise DownloadError(f"Failed to check disk space: {str(e)}") from e
async def download_all(self, tasks: List[DownloadTask]) -> None:
async def download_all(self, tasks: list[DownloadTask]) -> None:
if not tasks:
raise ValueError("No download tasks provided")
@ -391,20 +390,20 @@ def _meta_download(
class ModelEntry(BaseModel):
model_id: str
files: Dict[str, str]
files: dict[str, str]
model_config = ConfigDict(protected_namespaces=())
class Manifest(BaseModel):
models: List[ModelEntry]
models: list[ModelEntry]
expires_on: datetime
def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
from llama_stack.distribution.utils.model_utils import model_local_dir
with open(manifest_file, "r") as f:
with open(manifest_file) as f:
d = json.load(f)
manifest = Manifest(**d)

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
@ -22,7 +22,7 @@ class PromptGuardModel(BaseModel):
max_seq_length: int = 512
is_instruct_model: bool = False
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
arch_args: Dict[str, Any] = Field(default_factory=dict)
arch_args: dict[str, Any] = Field(default_factory=dict)
def descriptor(self) -> str:
return self.model_id
@ -44,11 +44,11 @@ def prompt_guard_model_skus():
]
def prompt_guard_model_sku_map() -> Dict[str, Any]:
def prompt_guard_model_sku_map() -> dict[str, Any]:
return {model.model_id: model for model in prompt_guard_model_skus()}
def prompt_guard_download_info_map() -> Dict[str, LlamaDownloadInfo]:
def prompt_guard_download_info_map() -> dict[str, LlamaDownloadInfo]:
return {
model.model_id: LlamaDownloadInfo(
folder="Prompt-Guard" if model.model_id == "Prompt-Guard-86M" else model.model_id,

View file

@ -13,7 +13,6 @@ import sys
import textwrap
from functools import lru_cache
from pathlib import Path
from typing import Dict, Optional
import yaml
from prompt_toolkit import prompt
@ -46,14 +45,14 @@ from llama_stack.providers.datatypes import Api
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"
@lru_cache()
def available_templates_specs() -> Dict[str, BuildConfig]:
@lru_cache
def available_templates_specs() -> dict[str, BuildConfig]:
import yaml
template_specs = {}
for p in TEMPLATES_PATH.rglob("*build.yaml"):
template_name = p.parent.name
with open(p, "r") as f:
with open(p) as f:
build_config = BuildConfig(**yaml.safe_load(f))
template_specs[template_name] = build_config
return template_specs
@ -178,7 +177,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
if not available_providers:
continue
api_provider = prompt(
"> Enter provider for API {}: ".format(api.value),
f"> Enter provider for API {api.value}: ",
completer=WordCompleter(available_providers),
complete_while_typing=True,
validator=Validator.from_callable(
@ -201,7 +200,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
build_config = BuildConfig(image_type=image_type, distribution_spec=distribution_spec)
else:
with open(args.config, "r") as f:
with open(args.config) as f:
try:
build_config = BuildConfig(**yaml.safe_load(f))
except Exception as e:
@ -332,9 +331,9 @@ def _generate_run_config(
def _run_stack_build_command_from_build_config(
build_config: BuildConfig,
image_name: Optional[str] = None,
template_name: Optional[str] = None,
config_path: Optional[str] = None,
image_name: str | None = None,
template_name: str | None = None,
config_path: str | None = None,
) -> str:
image_name = image_name or build_config.image_name
if build_config.image_type == LlamaStackImageType.CONTAINER.value:

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Iterable
from collections.abc import Iterable
from rich.console import Console
from rich.table import Table

View file

@ -9,7 +9,6 @@ import hashlib
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn
@ -21,7 +20,7 @@ from llama_stack.cli.subcommand import Subcommand
class VerificationResult:
filename: str
expected_hash: str
actual_hash: Optional[str]
actual_hash: str | None
exists: bool
matches: bool
@ -60,9 +59,9 @@ def calculate_md5(filepath: Path, chunk_size: int = 8192) -> str:
return md5_hash.hexdigest()
def load_checksums(checklist_path: Path) -> Dict[str, str]:
def load_checksums(checklist_path: Path) -> dict[str, str]:
checksums = {}
with open(checklist_path, "r") as f:
with open(checklist_path) as f:
for line in f:
if line.strip():
md5sum, filepath = line.strip().split(" ", 1)
@ -72,7 +71,7 @@ def load_checksums(checklist_path: Path) -> Dict[str, str]:
return checksums
def verify_files(model_dir: Path, checksums: Dict[str, str], console: Console) -> List[VerificationResult]:
def verify_files(model_dir: Path, checksums: dict[str, str], console: Console) -> list[VerificationResult]:
results = []
with Progress(