chore: enable pyupgrade fixes (#1806)

# What does this PR do?

The goal of this PR is code base modernization.

Schema reflection code needed a minor adjustment to handle union types written with the `X | Y` syntax (`types.UnionType`) and `collections.abc.AsyncIterator`, both of which are the preferred spellings in recent Python releases.
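
A minimal sketch of why that adjustment is needed (illustrative only, not the actual `schema.py` code): `X | Y` unions report `types.UnionType` from `typing.get_origin()`, not `typing.Union`, so reflection that only checks for `typing.Union` misses them, and `collections.abc.AsyncIterator[T]` has to be recognized alongside `typing.AsyncIterator[T]`:

```python
import collections.abc
import types
import typing


def is_union(t) -> bool:
    # Union[int, None] reports typing.Union; `int | None` (3.10+) reports types.UnionType.
    origin = typing.get_origin(t)
    return origin is typing.Union or origin is types.UnionType


def is_async_iterator(t) -> bool:
    # Both typing.AsyncIterator[T] and collections.abc.AsyncIterator[T] normalize to the abc origin.
    return typing.get_origin(t) is collections.abc.AsyncIterator


assert is_union(typing.Optional[int])
assert is_union(int | None)
assert is_async_iterator(typing.AsyncIterator[str])
assert is_async_iterator(collections.abc.AsyncIterator[str])
```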

Note to reviewers: almost all changes here are automatically generated by pyupgrade. Some additional unused imports were also cleaned up. The only change worth noting can be found under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py`, where the reflection code was updated to deal with the "newer" types.
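
For readers skimming the diff, the automated changes fall into a few buckets: built-in generics replace `typing.Dict`/`List`/`Set`/`Tuple`, redundant `"r"` modes are dropped from `open()`, `IOError` becomes its alias `OSError`, zero-argument `super()` replaces `super(Cls, self)`, `isinstance` tuples become `X | Y` unions, and `.format()` templates become f-strings. An illustrative before/after (hypothetical function, not taken from any single file):

```python
import json
from typing import Dict, Optional


# Before: pre-pyupgrade idioms
def load_counts_old(path: str) -> Optional[Dict[str, int]]:
    with open(path, "r") as f:  # explicit "r" is redundant
        return json.load(f)


# After: what the automated rewrite produces (Python 3.10+ for `X | None` at runtime)
def load_counts_new(path: str) -> dict[str, int] | None:
    with open(path) as f:  # mode defaults to "r"
        return json.load(f)
```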

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Ihar Hrachyshka authored on 2025-05-01 17:23:50 -04:00 (committed by GitHub)
commit 9e6561a1ec, parent ffe3d0b2cd
319 changed files with 2843 additions and 3033 deletions

View file

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict
+from typing import Any
 from uuid import uuid4
 import pytest
@@ -37,7 +37,7 @@ def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
     return -1
-def get_boiling_point_with_metadata(liquid_name: str, celcius: bool = True) -> Dict[str, Any]:
+def get_boiling_point_with_metadata(liquid_name: str, celcius: bool = True) -> dict[str, Any]:
     """
     Returns the boiling point of a liquid in Celcius or Fahrenheit

View file

@@ -24,7 +24,7 @@ class RecordableMock:
         # Load existing cache if available and not recording
         if self.json_path.exists():
             try:
-                with open(self.json_path, "r") as f:
+                with open(self.json_path) as f:
                     self.cache = json.load(f)
             except Exception as e:
                 print(f"Error loading cache from {self.json_path}: {e}")

View file

@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import List
 import pytest
@@ -77,7 +76,7 @@ class TestPostTraining:
     async def test_get_training_jobs(self, post_training_stack):
         post_training_impl = post_training_stack
         jobs_list = await post_training_impl.get_training_jobs()
-        assert isinstance(jobs_list, List)
+        assert isinstance(jobs_list, list)
         assert jobs_list[0].job_uuid == "1234"
     @pytest.mark.asyncio

View file

@@ -20,7 +20,7 @@ class TestCase:
         # loading all test cases
         if self._jsonblob == {}:
             for api in self._apis:
-                with open(pathlib.Path(__file__).parent / f"{api}.json", "r") as f:
+                with open(pathlib.Path(__file__).parent / f"{api}.json") as f:
                     coloned = api.replace("/", ":")
                     try:
                         loaded = json.load(f)

View file

@@ -18,11 +18,11 @@ from llama_stack.distribution.configure import (
 @pytest.fixture
 def up_to_date_config():
     return yaml.safe_load(
-        """
-        version: {version}
+        f"""
+        version: {LLAMA_STACK_RUN_CONFIG_VERSION}
         image_name: foo
         apis_to_serve: []
-        built_at: {built_at}
+        built_at: {datetime.now().isoformat()}
         providers:
           inference:
           - provider_id: provider1
@@ -42,16 +42,16 @@ def up_to_date_config():
           - provider_id: provider1
             provider_type: inline::meta-reference
             config: {{}}
-        """.format(version=LLAMA_STACK_RUN_CONFIG_VERSION, built_at=datetime.now().isoformat())
+        """
     )
 @pytest.fixture
 def old_config():
     return yaml.safe_load(
-        """
+        f"""
         image_name: foo
-        built_at: {built_at}
+        built_at: {datetime.now().isoformat()}
         apis_to_serve: []
         routing_table:
           inference:
@@ -82,7 +82,7 @@ def old_config():
           telemetry:
             provider_type: noop
             config: {{}}
-        """.format(built_at=datetime.now().isoformat())
+        """
     )

View file

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict
+from typing import Any
 from unittest.mock import patch
 import pytest
@@ -23,7 +23,7 @@ class SampleConfig(BaseModel):
     )
     @classmethod
-    def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
+    def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
         return {
             "foo": "baz",
         }

View file

@@ -10,7 +10,7 @@ import logging
 import threading
 import time
 from http.server import BaseHTTPRequestHandler, HTTPServer
-from typing import Any, Dict
+from typing import Any
 from unittest.mock import AsyncMock, patch
 import pytest
@@ -55,7 +55,7 @@ from llama_stack.providers.remote.inference.vllm.vllm import (
 class MockInferenceAdapterWithSleep:
-    def __init__(self, sleep_time: int, response: Dict[str, Any]):
+    def __init__(self, sleep_time: int, response: dict[str, Any]):
         self.httpd = None
         class DelayedRequestHandler(BaseHTTPRequestHandler):

View file

@@ -22,7 +22,7 @@ from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
 class AsyncMock(MagicMock):
     async def __call__(self, *args, **kwargs):
-        return super(AsyncMock, self).__call__(*args, **kwargs)
+        return super().__call__(*args, **kwargs)
 def _return_model(model):

View file

@@ -6,7 +6,7 @@
 import inspect
 import sys
-from typing import Any, Dict, Protocol
+from typing import Any, Protocol
 from unittest.mock import AsyncMock, MagicMock
 import pytest
@@ -48,14 +48,14 @@ class SampleConfig(BaseModel):
     )
     @classmethod
-    def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
+    def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
         return {
             "foo": "baz",
         }
 class SampleImpl:
-    def __init__(self, config: SampleConfig, deps: Dict[Api, Any], provider_spec: ProviderSpec = None):
+    def __init__(self, config: SampleConfig, deps: dict[Api, Any], provider_spec: ProviderSpec = None):
         self.__provider_id__ = "test_provider"
         self.__provider_spec__ = provider_spec
         self.__provider_config__ = config

View file

@@ -50,7 +50,7 @@ import subprocess
 import time
 from collections import defaultdict
 from pathlib import Path
-from typing import Any, DefaultDict, Dict, Set, Tuple
+from typing import Any
 from tests.verifications.openai_api.fixtures.fixtures import _load_all_verification_configs
@@ -106,7 +106,7 @@ def run_tests(provider, keyword=None):
     # Check if the JSON file was created
     if temp_json_file.exists():
-        with open(temp_json_file, "r") as f:
+        with open(temp_json_file) as f:
            test_results = json.load(f)
            test_results["run_timestamp"] = timestamp
@@ -141,7 +141,7 @@ def run_multiple_tests(providers_to_run: list[str], keyword: str | None):
 def parse_results(
     result_file,
-) -> Tuple[DefaultDict[str, DefaultDict[str, Dict[str, bool]]], DefaultDict[str, Set[str]], Set[str], str]:
+) -> tuple[defaultdict[str, defaultdict[str, dict[str, bool]]], defaultdict[str, set[str]], set[str], str]:
     """Parse a single test results file.
     Returns:
@@ -156,13 +156,13 @@ def parse_results(
         # Return empty defaultdicts/set matching the type hint
         return defaultdict(lambda: defaultdict(dict)), defaultdict(set), set(), ""
-    with open(result_file, "r") as f:
+    with open(result_file) as f:
         results = json.load(f)
     # Initialize results dictionary with specific types
-    parsed_results: DefaultDict[str, DefaultDict[str, Dict[str, bool]]] = defaultdict(lambda: defaultdict(dict))
-    providers_in_file: DefaultDict[str, Set[str]] = defaultdict(set)
-    tests_in_file: Set[str] = set()
+    parsed_results: defaultdict[str, defaultdict[str, dict[str, bool]]] = defaultdict(lambda: defaultdict(dict))
+    providers_in_file: defaultdict[str, set[str]] = defaultdict(set)
+    tests_in_file: set[str] = set()
     # Extract provider from filename (e.g., "openai.json" -> "openai")
     provider: str = result_file.stem
@@ -248,10 +248,10 @@ def parse_results(
 def generate_report(
-    results_dict: Dict[str, Any],
-    providers: Dict[str, Set[str]],
-    all_tests: Set[str],
-    provider_timestamps: Dict[str, str],
+    results_dict: dict[str, Any],
+    providers: dict[str, set[str]],
+    all_tests: set[str],
+    provider_timestamps: dict[str, str],
     output_file=None,
 ):
     """Generate the markdown report.
@@ -277,8 +277,8 @@ def generate_report(
     sorted_tests = sorted(all_tests)
     # Calculate counts for each base test name
-    base_test_case_counts: DefaultDict[str, int] = defaultdict(int)
-    base_test_name_map: Dict[str, str] = {}
+    base_test_case_counts: defaultdict[str, int] = defaultdict(int)
+    base_test_name_map: dict[str, str] = {}
     for test_name in sorted_tests:
         match = re.match(r"^(.*?)( \([^)]+\))?$", test_name)
         if match:

View file

@@ -18,7 +18,7 @@ def pytest_generate_tests(metafunc):
     try:
         config_data = _load_all_verification_configs()
-    except (FileNotFoundError, IOError) as e:
+    except (OSError, FileNotFoundError) as e:
         print(f"ERROR loading verification configs: {e}")
         config_data = {"providers": {}}

View file

@@ -33,7 +33,7 @@ def _load_all_verification_configs():
     for config_path in yaml_files:
         provider_name = config_path.stem
         try:
-            with open(config_path, "r") as f:
+            with open(config_path) as f:
                 provider_config = yaml.safe_load(f)
                 if provider_config:
                     all_provider_configs[provider_name] = provider_config
@@ -41,7 +41,7 @@ def _load_all_verification_configs():
                     # Log warning if possible, or just skip empty files silently
                     print(f"Warning: Config file {config_path} is empty or invalid.")
         except Exception as e:
-            raise IOError(f"Error loading config file {config_path}: {e}") from e
+            raise OSError(f"Error loading config file {config_path}: {e}") from e
     return {"providers": all_provider_configs}
@@ -49,7 +49,7 @@ def _load_all_verification_configs():
 def case_id_generator(case):
     """Generate a test ID from the case's 'case_id' field, or use a default."""
     case_id = case.get("case_id")
-    if isinstance(case_id, (str, int)):
+    if isinstance(case_id, str | int):
         return re.sub(r"\W|^(?=\d)", "_", str(case_id))
     return None
@@ -77,7 +77,7 @@ def verification_config():
     """Pytest fixture to provide the loaded verification config."""
     try:
         return _load_all_verification_configs()
-    except (FileNotFoundError, IOError) as e:
+    except (OSError, FileNotFoundError) as e:
         pytest.fail(str(e))  # Fail test collection if config loading fails

View file

@@ -12,5 +12,5 @@ import yaml
 def load_test_cases(name: str):
     fixture_dir = Path(__file__).parent / "test_cases"
     yaml_path = fixture_dir / f"{name}.yaml"
-    with open(yaml_path, "r") as f:
+    with open(yaml_path) as f:
         return yaml.safe_load(f)