chore: enable pyupgrade fixes (#1806)

# What does this PR do?

The goal of this PR is code base modernization.

Schema reflection code needed a minor adjustment to handle UnionTypes
and collections.abc.AsyncIterator. (Both are preferred for latest Python
releases.)

Note to reviewers: almost all changes here are automatically generated
by pyupgrade. Some additional unused imports were cleaned up. The only
change worth of note can be found under `docs/openapi_generator` and
`llama_stack/strong_typing/schema.py` where reflection code was updated
to deal with "newer" types.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
Ihar Hrachyshka 2025-05-01 17:23:50 -04:00 committed by GitHub
parent ffe3d0b2cd
commit 9e6561a1ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
319 changed files with 2843 additions and 3033 deletions

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import importlib
import inspect
from typing import Any, Dict, List, Set, Tuple
from typing import Any
from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks
@ -58,7 +58,7 @@ class InvalidProviderError(Exception):
pass
def api_protocol_map() -> Dict[Api, Any]:
def api_protocol_map() -> dict[Api, Any]:
return {
Api.providers: ProvidersAPI,
Api.agents: Agents,
@ -83,7 +83,7 @@ def api_protocol_map() -> Dict[Api, Any]:
}
def additional_protocols_map() -> Dict[Api, Any]:
def additional_protocols_map() -> dict[Api, Any]:
return {
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups),
@ -104,14 +104,14 @@ class ProviderWithSpec(Provider):
spec: ProviderSpec
ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]]
ProviderRegistry = dict[Api, dict[str, ProviderSpec]]
async def resolve_impls(
run_config: StackRunConfig,
provider_registry: ProviderRegistry,
dist_registry: DistributionRegistry,
) -> Dict[Api, Any]:
) -> dict[Api, Any]:
"""
Resolves provider implementations by:
1. Validating and organizing providers.
@ -136,7 +136,7 @@ async def resolve_impls(
return await instantiate_providers(sorted_providers, router_apis, dist_registry)
def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str, Dict[str, ProviderWithSpec]]:
def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
"""Generates specifications for automatically routed APIs."""
specs = {}
for info in builtin_automatically_routed_apis():
@ -178,10 +178,10 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str,
def validate_and_prepare_providers(
run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: Set[Api], router_apis: Set[Api]
) -> Dict[str, Dict[str, ProviderWithSpec]]:
run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: set[Api], router_apis: set[Api]
) -> dict[str, dict[str, ProviderWithSpec]]:
"""Validates providers, handles deprecations, and organizes them into a spec dictionary."""
providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]] = {}
providers_with_specs: dict[str, dict[str, ProviderWithSpec]] = {}
for api_str, providers in run_config.providers.items():
api = Api(api_str)
@ -222,10 +222,10 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR
def sort_providers_by_deps(
providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]], run_config: StackRunConfig
) -> List[Tuple[str, ProviderWithSpec]]:
providers_with_specs: dict[str, dict[str, ProviderWithSpec]], run_config: StackRunConfig
) -> list[tuple[str, ProviderWithSpec]]:
"""Sorts providers based on their dependencies."""
sorted_providers: List[Tuple[str, ProviderWithSpec]] = topological_sort(
sorted_providers: list[tuple[str, ProviderWithSpec]] = topological_sort(
{k: list(v.values()) for k, v in providers_with_specs.items()}
)
@ -236,11 +236,11 @@ def sort_providers_by_deps(
async def instantiate_providers(
sorted_providers: List[Tuple[str, ProviderWithSpec]], router_apis: Set[Api], dist_registry: DistributionRegistry
) -> Dict:
sorted_providers: list[tuple[str, ProviderWithSpec]], router_apis: set[Api], dist_registry: DistributionRegistry
) -> dict:
"""Instantiates providers asynchronously while managing dependencies."""
impls: Dict[Api, Any] = {}
inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
impls: dict[Api, Any] = {}
inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
for api_str, provider in sorted_providers:
deps = {a: impls[a] for a in provider.spec.api_dependencies}
for a in provider.spec.optional_api_dependencies:
@ -263,9 +263,9 @@ async def instantiate_providers(
def topological_sort(
providers_with_specs: Dict[str, List[ProviderWithSpec]],
) -> List[Tuple[str, ProviderWithSpec]]:
def dfs(kv, visited: Set[str], stack: List[str]):
providers_with_specs: dict[str, list[ProviderWithSpec]],
) -> list[tuple[str, ProviderWithSpec]]:
def dfs(kv, visited: set[str], stack: list[str]):
api_str, providers = kv
visited.add(api_str)
@ -280,8 +280,8 @@ def topological_sort(
stack.append(api_str)
visited: Set[str] = set()
stack: List[str] = []
visited: set[str] = set()
stack: list[str] = []
for api_str, providers in providers_with_specs.items():
if api_str not in visited:
@ -298,8 +298,8 @@ def topological_sort(
# returns a class implementing the protocol corresponding to the Api
async def instantiate_provider(
provider: ProviderWithSpec,
deps: Dict[Api, Any],
inner_impls: Dict[str, Any],
deps: dict[Api, Any],
inner_impls: dict[str, Any],
dist_registry: DistributionRegistry,
):
protocols = api_protocol_map()
@ -391,8 +391,8 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
async def resolve_remote_stack_impls(
config: RemoteProviderConfig,
apis: List[str],
) -> Dict[Api, Any]:
apis: list[str],
) -> dict[Api, Any]:
protocols = api_protocol_map()
additional_protocols = additional_protocols_map()