mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
chore: enable pyupgrade fixes (#1806)
# What does this PR do? The goal of this PR is code base modernization. Schema reflection code needed a minor adjustment to handle UnionTypes and collections.abc.AsyncIterator. (Both are preferred for latest Python releases.) Note to reviewers: almost all changes here are automatically generated by pyupgrade. Some additional unused imports were cleaned up. The only change worth of note can be found under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py` where reflection code was updated to deal with "newer" types. Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
ffe3d0b2cd
commit
9e6561a1ec
319 changed files with 2843 additions and 3033 deletions
|
@ -13,7 +13,6 @@ from dataclasses import dataclass
|
|||
from datetime import datetime, timezone
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
@ -102,7 +101,7 @@ class DownloadTask:
|
|||
output_file: str
|
||||
total_size: int = 0
|
||||
downloaded_size: int = 0
|
||||
task_id: Optional[int] = None
|
||||
task_id: int | None = None
|
||||
retries: int = 0
|
||||
max_retries: int = 3
|
||||
|
||||
|
@ -262,7 +261,7 @@ class ParallelDownloader:
|
|||
self.progress.update(task.task_id, description=f"[red]Failed: {task.output_file}[/red]")
|
||||
raise DownloadError(f"Download failed for {task.output_file}: {str(e)}") from e
|
||||
|
||||
def has_disk_space(self, tasks: List[DownloadTask]) -> bool:
|
||||
def has_disk_space(self, tasks: list[DownloadTask]) -> bool:
|
||||
try:
|
||||
total_remaining_size = sum(task.total_size - task.downloaded_size for task in tasks)
|
||||
dir_path = os.path.dirname(os.path.abspath(tasks[0].output_file))
|
||||
|
@ -282,7 +281,7 @@ class ParallelDownloader:
|
|||
except Exception as e:
|
||||
raise DownloadError(f"Failed to check disk space: {str(e)}") from e
|
||||
|
||||
async def download_all(self, tasks: List[DownloadTask]) -> None:
|
||||
async def download_all(self, tasks: list[DownloadTask]) -> None:
|
||||
if not tasks:
|
||||
raise ValueError("No download tasks provided")
|
||||
|
||||
|
@ -391,20 +390,20 @@ def _meta_download(
|
|||
|
||||
class ModelEntry(BaseModel):
|
||||
model_id: str
|
||||
files: Dict[str, str]
|
||||
files: dict[str, str]
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class Manifest(BaseModel):
|
||||
models: List[ModelEntry]
|
||||
models: list[ModelEntry]
|
||||
expires_on: datetime
|
||||
|
||||
|
||||
def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
|
||||
with open(manifest_file, "r") as f:
|
||||
with open(manifest_file) as f:
|
||||
d = json.load(f)
|
||||
manifest = Manifest(**d)
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
@ -22,7 +22,7 @@ class PromptGuardModel(BaseModel):
|
|||
max_seq_length: int = 512
|
||||
is_instruct_model: bool = False
|
||||
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
|
||||
arch_args: Dict[str, Any] = Field(default_factory=dict)
|
||||
arch_args: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
def descriptor(self) -> str:
|
||||
return self.model_id
|
||||
|
@ -44,11 +44,11 @@ def prompt_guard_model_skus():
|
|||
]
|
||||
|
||||
|
||||
def prompt_guard_model_sku_map() -> Dict[str, Any]:
|
||||
def prompt_guard_model_sku_map() -> dict[str, Any]:
|
||||
return {model.model_id: model for model in prompt_guard_model_skus()}
|
||||
|
||||
|
||||
def prompt_guard_download_info_map() -> Dict[str, LlamaDownloadInfo]:
|
||||
def prompt_guard_download_info_map() -> dict[str, LlamaDownloadInfo]:
|
||||
return {
|
||||
model.model_id: LlamaDownloadInfo(
|
||||
folder="Prompt-Guard" if model.model_id == "Prompt-Guard-86M" else model.model_id,
|
||||
|
|
|
@ -13,7 +13,6 @@ import sys
|
|||
import textwrap
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
import yaml
|
||||
from prompt_toolkit import prompt
|
||||
|
@ -46,14 +45,14 @@ from llama_stack.providers.datatypes import Api
|
|||
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def available_templates_specs() -> Dict[str, BuildConfig]:
|
||||
@lru_cache
|
||||
def available_templates_specs() -> dict[str, BuildConfig]:
|
||||
import yaml
|
||||
|
||||
template_specs = {}
|
||||
for p in TEMPLATES_PATH.rglob("*build.yaml"):
|
||||
template_name = p.parent.name
|
||||
with open(p, "r") as f:
|
||||
with open(p) as f:
|
||||
build_config = BuildConfig(**yaml.safe_load(f))
|
||||
template_specs[template_name] = build_config
|
||||
return template_specs
|
||||
|
@ -178,7 +177,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
if not available_providers:
|
||||
continue
|
||||
api_provider = prompt(
|
||||
"> Enter provider for API {}: ".format(api.value),
|
||||
f"> Enter provider for API {api.value}: ",
|
||||
completer=WordCompleter(available_providers),
|
||||
complete_while_typing=True,
|
||||
validator=Validator.from_callable(
|
||||
|
@ -201,7 +200,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
|||
|
||||
build_config = BuildConfig(image_type=image_type, distribution_spec=distribution_spec)
|
||||
else:
|
||||
with open(args.config, "r") as f:
|
||||
with open(args.config) as f:
|
||||
try:
|
||||
build_config = BuildConfig(**yaml.safe_load(f))
|
||||
except Exception as e:
|
||||
|
@ -332,9 +331,9 @@ def _generate_run_config(
|
|||
|
||||
def _run_stack_build_command_from_build_config(
|
||||
build_config: BuildConfig,
|
||||
image_name: Optional[str] = None,
|
||||
template_name: Optional[str] = None,
|
||||
config_path: Optional[str] = None,
|
||||
image_name: str | None = None,
|
||||
template_name: str | None = None,
|
||||
config_path: str | None = None,
|
||||
) -> str:
|
||||
image_name = image_name or build_config.image_name
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Iterable
|
||||
from collections.abc import Iterable
|
||||
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
|
|
@ -9,7 +9,6 @@ import hashlib
|
|||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from rich.console import Console
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn
|
||||
|
@ -21,7 +20,7 @@ from llama_stack.cli.subcommand import Subcommand
|
|||
class VerificationResult:
|
||||
filename: str
|
||||
expected_hash: str
|
||||
actual_hash: Optional[str]
|
||||
actual_hash: str | None
|
||||
exists: bool
|
||||
matches: bool
|
||||
|
||||
|
@ -60,9 +59,9 @@ def calculate_md5(filepath: Path, chunk_size: int = 8192) -> str:
|
|||
return md5_hash.hexdigest()
|
||||
|
||||
|
||||
def load_checksums(checklist_path: Path) -> Dict[str, str]:
|
||||
def load_checksums(checklist_path: Path) -> dict[str, str]:
|
||||
checksums = {}
|
||||
with open(checklist_path, "r") as f:
|
||||
with open(checklist_path) as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
md5sum, filepath = line.strip().split(" ", 1)
|
||||
|
@ -72,7 +71,7 @@ def load_checksums(checklist_path: Path) -> Dict[str, str]:
|
|||
return checksums
|
||||
|
||||
|
||||
def verify_files(model_dir: Path, checksums: Dict[str, str], console: Console) -> List[VerificationResult]:
|
||||
def verify_files(model_dir: Path, checksums: dict[str, str], console: Console) -> list[VerificationResult]:
|
||||
results = []
|
||||
|
||||
with Progress(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue