chore: enable pyupgrade fixes (#1806)

# What does this PR do?

The goal of this PR is code base modernization: adopting PEP 585 builtin generics (`dict`, `list`, `type`) and PEP 604 union syntax (`X | Y`, `X | None`) now that the supported Python versions allow them.

Schema reflection code needed a minor adjustment to handle `types.UnionType`
(the result of the `X | Y` syntax) and `collections.abc.AsyncIterator`. (Both
are preferred in the latest Python releases.)
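
For illustration, here is a minimal sketch of what handling both union spellings and `AsyncIterator` during reflection involves. The helper names (`is_union`, `is_async_iterator`) are hypothetical and not taken from `llama_stack/strong_typing/schema.py`; this assumes Python 3.10+:

```python
# Illustrative sketch only; not the actual reflection code in this PR.
import collections.abc
import types
from typing import Union, get_args, get_origin


def is_union(tp) -> bool:
    # typing.Union[int, str] reports typing.Union as its origin, while
    # the PEP 604 form int | str is an instance of types.UnionType.
    return get_origin(tp) is Union or isinstance(tp, types.UnionType)


def is_async_iterator(tp) -> bool:
    # Both typing.AsyncIterator[T] and collections.abc.AsyncIterator[T]
    # normalize to collections.abc.AsyncIterator via get_origin().
    return get_origin(tp) is collections.abc.AsyncIterator


assert is_union(int | str) and is_union(Union[int, str])
assert get_args(int | str) == (int, str)
assert is_async_iterator(collections.abc.AsyncIterator[int])
```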

Note to reviewers: almost all changes here were generated automatically by
pyupgrade; some additional unused imports were cleaned up as well. The only
changes worthy of note can be found under `docs/openapi_generator` and
`llama_stack/strong_typing/schema.py`, where reflection code was updated
to deal with "newer" types.
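
For reviewers unfamiliar with pyupgrade, the rewrites below all follow the same pattern. A representative before/after (a hypothetical function, not taken from any file in this PR):

```python
# Before: typing-module generics and Optional/Union
# from typing import Dict, List, Optional, Union
# def lookup(registry: Dict[str, List[int]], key: Optional[str] = None) -> Union[int, None]:
#     ...

# After: PEP 585 builtin generics and PEP 604 unions (Python 3.10+)
def lookup(registry: dict[str, list[int]], key: str | None = None) -> int | None:
    if key is None:
        return None
    values = registry.get(key)
    return values[0] if values else None
```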

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Ihar Hrachyshka, 2025-05-01 17:23:50 -04:00, committed by GitHub
commit 9e6561a1ec (parent ffe3d0b2cd)
319 changed files with 2843 additions and 3033 deletions


@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from typing import Any
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.log import get_logger
@ -14,8 +14,8 @@ logger = get_logger(__name__, category="core")
def check_access(
obj_identifier: str,
obj_attributes: Optional[AccessAttributes],
user_attributes: Optional[Dict[str, Any]] = None,
obj_attributes: AccessAttributes | None,
user_attributes: dict[str, Any] | None = None,
) -> bool:
"""Check if the current user has access to the given object, based on access attributes.


@ -8,7 +8,7 @@ import inspect
import json
from collections.abc import AsyncIterator
from enum import Enum
from typing import Any, Type, Union, get_args, get_origin
from typing import Any, Union, get_args, get_origin
import httpx
from pydantic import BaseModel, parse_obj_as
@ -27,7 +27,7 @@ async def get_client_impl(protocol, config: RemoteProviderConfig, _deps: Any):
return impl
def create_api_client_class(protocol) -> Type:
def create_api_client_class(protocol) -> type:
if protocol in _CLIENT_CLASSES:
return _CLIENT_CLASSES[protocol]


@ -5,7 +5,7 @@
# the root directory of this source tree.
import logging
import textwrap
from typing import Any, Dict
from typing import Any
from llama_stack.distribution.datatypes import (
LLAMA_STACK_RUN_CONFIG_VERSION,
@ -24,7 +24,7 @@ from llama_stack.providers.datatypes import Api, ProviderSpec
logger = logging.getLogger(__name__)
def configure_single_provider(registry: Dict[str, ProviderSpec], provider: Provider) -> Provider:
def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider:
provider_spec = registry[provider.provider_type]
config_type = instantiate_class_type(provider_spec.config_class)
try:
@ -120,8 +120,8 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
def upgrade_from_routing_table(
config_dict: Dict[str, Any],
) -> Dict[str, Any]:
config_dict: dict[str, Any],
) -> dict[str, Any]:
def get_providers(entries):
return [
Provider(
@ -163,7 +163,7 @@ def upgrade_from_routing_table(
return config_dict
def parse_and_maybe_upgrade_config(config_dict: Dict[str, Any]) -> StackRunConfig:
def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig:
version = config_dict.get("version", None)
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
return StackRunConfig(**config_dict)


@ -5,7 +5,7 @@
# the root directory of this source tree.
from enum import Enum
from typing import Annotated, Any, Dict, List, Optional, Union
from typing import Annotated, Any
from pydantic import BaseModel, Field
@ -30,7 +30,7 @@ LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
LLAMA_STACK_RUN_CONFIG_VERSION = "2"
RoutingKey = Union[str, List[str]]
RoutingKey = str | list[str]
class AccessAttributes(BaseModel):
@ -47,17 +47,17 @@ class AccessAttributes(BaseModel):
"""
# Standard attribute categories - the minimal set we need now
roles: Optional[List[str]] = Field(
roles: list[str] | None = Field(
default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
)
teams: Optional[List[str]] = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
teams: list[str] | None = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
projects: Optional[List[str]] = Field(
projects: list[str] | None = Field(
default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
)
namespaces: Optional[List[str]] = Field(
namespaces: list[str] | None = Field(
default=None, description="Namespace-based access control for resource isolation"
)
@ -106,7 +106,7 @@ class ResourceWithACL(Resource):
# ^ User must have access to the customer-insights project AND have confidential namespace
"""
access_attributes: Optional[AccessAttributes] = None
access_attributes: AccessAttributes | None = None
# Use the extended Resource for all routable objects
@ -142,41 +142,21 @@ class ToolGroupWithACL(ToolGroup, ResourceWithACL):
pass
RoutableObject = Union[
Model,
Shield,
VectorDB,
Dataset,
ScoringFn,
Benchmark,
Tool,
ToolGroup,
]
RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | Tool | ToolGroup
RoutableObjectWithProvider = Annotated[
Union[
ModelWithACL,
ShieldWithACL,
VectorDBWithACL,
DatasetWithACL,
ScoringFnWithACL,
BenchmarkWithACL,
ToolWithACL,
ToolGroupWithACL,
],
ModelWithACL
| ShieldWithACL
| VectorDBWithACL
| DatasetWithACL
| ScoringFnWithACL
| BenchmarkWithACL
| ToolWithACL
| ToolGroupWithACL,
Field(discriminator="type"),
]
RoutedProtocol = Union[
Inference,
Safety,
VectorIO,
DatasetIO,
Scoring,
Eval,
ToolRuntime,
]
RoutedProtocol = Inference | Safety | VectorIO | DatasetIO | Scoring | Eval | ToolRuntime
# Example: /inference, /safety
@ -184,15 +164,15 @@ class AutoRoutedProviderSpec(ProviderSpec):
provider_type: str = "router"
config_class: str = ""
container_image: Optional[str] = None
container_image: str | None = None
routing_table_api: Api
module: str
provider_data_validator: Optional[str] = Field(
provider_data_validator: str | None = Field(
default=None,
)
@property
def pip_packages(self) -> List[str]:
def pip_packages(self) -> list[str]:
raise AssertionError("Should not be called on AutoRoutedProviderSpec")
@ -200,20 +180,20 @@ class AutoRoutedProviderSpec(ProviderSpec):
class RoutingTableProviderSpec(ProviderSpec):
provider_type: str = "routing_table"
config_class: str = ""
container_image: Optional[str] = None
container_image: str | None = None
router_api: Api
module: str
pip_packages: List[str] = Field(default_factory=list)
pip_packages: list[str] = Field(default_factory=list)
class DistributionSpec(BaseModel):
description: Optional[str] = Field(
description: str | None = Field(
default="",
description="Description of the distribution",
)
container_image: Optional[str] = None
providers: Dict[str, Union[str, List[str]]] = Field(
container_image: str | None = None
providers: dict[str, str | list[str]] = Field(
default_factory=dict,
description="""
Provider Types for each of the APIs provided by this distribution. If you
@ -225,12 +205,12 @@ in the runtime configuration to help route to the correct provider.""",
class Provider(BaseModel):
provider_id: str
provider_type: str
config: Dict[str, Any]
config: dict[str, Any]
class LoggingConfig(BaseModel):
category_levels: Dict[str, str] = Field(
default_factory=Dict,
category_levels: dict[str, str] = Field(
default_factory=dict,
description="""
Dictionary of different logging configurations for different portions (ex: core, server) of llama stack""",
)
@ -248,7 +228,7 @@ class AuthenticationConfig(BaseModel):
...,
description="Type of authentication provider (e.g., 'kubernetes', 'custom')",
)
config: Dict[str, str] = Field(
config: dict[str, str] = Field(
...,
description="Provider-specific configuration",
)
@ -261,15 +241,15 @@ class ServerConfig(BaseModel):
ge=1024,
le=65535,
)
tls_certfile: Optional[str] = Field(
tls_certfile: str | None = Field(
default=None,
description="Path to TLS certificate file for HTTPS",
)
tls_keyfile: Optional[str] = Field(
tls_keyfile: str | None = Field(
default=None,
description="Path to TLS key file for HTTPS",
)
auth: Optional[AuthenticationConfig] = Field(
auth: AuthenticationConfig | None = Field(
default=None,
description="Authentication configuration for the server",
)
@ -285,23 +265,23 @@ Reference to the distribution this package refers to. For unregistered (adhoc) p
this could be just a hash
""",
)
container_image: Optional[str] = Field(
container_image: str | None = Field(
default=None,
description="Reference to the container image if this package refers to a container",
)
apis: List[str] = Field(
apis: list[str] = Field(
default_factory=list,
description="""
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
)
providers: Dict[str, List[Provider]] = Field(
providers: dict[str, list[Provider]] = Field(
description="""
One or more providers to use for each API. The same provider_type (e.g., meta-reference)
can be instantiated multiple times (with different configs) if necessary.
""",
)
metadata_store: Optional[KVStoreConfig] = Field(
metadata_store: KVStoreConfig | None = Field(
default=None,
description="""
Configuration for the persistence store used by the distribution registry. If not specified,
@ -309,22 +289,22 @@ a default SQLite store will be used.""",
)
# registry of "resources" in the distribution
models: List[ModelInput] = Field(default_factory=list)
shields: List[ShieldInput] = Field(default_factory=list)
vector_dbs: List[VectorDBInput] = Field(default_factory=list)
datasets: List[DatasetInput] = Field(default_factory=list)
scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
benchmarks: List[BenchmarkInput] = Field(default_factory=list)
tool_groups: List[ToolGroupInput] = Field(default_factory=list)
models: list[ModelInput] = Field(default_factory=list)
shields: list[ShieldInput] = Field(default_factory=list)
vector_dbs: list[VectorDBInput] = Field(default_factory=list)
datasets: list[DatasetInput] = Field(default_factory=list)
scoring_fns: list[ScoringFnInput] = Field(default_factory=list)
benchmarks: list[BenchmarkInput] = Field(default_factory=list)
tool_groups: list[ToolGroupInput] = Field(default_factory=list)
logging: Optional[LoggingConfig] = Field(default=None, description="Configuration for Llama Stack Logging")
logging: LoggingConfig | None = Field(default=None, description="Configuration for Llama Stack Logging")
server: ServerConfig = Field(
default_factory=ServerConfig,
description="Configuration for the HTTP(S) server",
)
external_providers_dir: Optional[str] = Field(
external_providers_dir: str | None = Field(
default=None,
description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
)
@ -338,11 +318,11 @@ class BuildConfig(BaseModel):
default="conda",
description="Type of package to build (conda | container | venv)",
)
image_name: Optional[str] = Field(
image_name: str | None = Field(
default=None,
description="Name of the distribution to build",
)
external_providers_dir: Optional[str] = Field(
external_providers_dir: str | None = Field(
default=None,
description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. "
"pip_packages MUST contain the provider package name.",


@ -7,7 +7,7 @@
import glob
import importlib
import os
from typing import Any, Dict, List
from typing import Any
import yaml
from pydantic import BaseModel
@ -24,7 +24,7 @@ from llama_stack.providers.datatypes import (
logger = get_logger(name=__name__, category="core")
def stack_apis() -> List[Api]:
def stack_apis() -> list[Api]:
return list(Api)
@ -33,7 +33,7 @@ class AutoRoutedApiInfo(BaseModel):
router_api: Api
def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
return [
AutoRoutedApiInfo(
routing_table_api=Api.models,
@ -66,12 +66,12 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
]
def providable_apis() -> List[Api]:
def providable_apis() -> list[Api]:
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]
def _load_remote_provider_spec(spec_data: Dict[str, Any], api: Api) -> ProviderSpec:
def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec:
adapter = AdapterSpec(**spec_data["adapter"])
spec = remote_provider_spec(
api=api,
@ -81,7 +81,7 @@ def _load_remote_provider_spec(spec_data: Dict[str, Any], api: Api) -> ProviderS
return spec
def _load_inline_provider_spec(spec_data: Dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
spec = InlineProviderSpec(
api=api,
provider_type=f"inline::{provider_name}",
@ -98,7 +98,7 @@ def _load_inline_provider_spec(spec_data: Dict[str, Any], api: Api, provider_nam
def get_provider_registry(
config=None,
) -> Dict[Api, Dict[str, ProviderSpec]]:
) -> dict[Api, dict[str, ProviderSpec]]:
"""Get the provider registry, optionally including external providers.
This function loads both built-in providers and external providers from YAML files.
@ -133,7 +133,7 @@ def get_provider_registry(
ValueError: If any provider spec is invalid
"""
ret: Dict[Api, Dict[str, ProviderSpec]] = {}
ret: dict[Api, dict[str, ProviderSpec]] = {}
for api in providable_apis():
name = api.name.lower()
logger.debug(f"Importing module {name}")


@ -12,7 +12,7 @@ import os
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from pathlib import Path
from typing import Any, Optional, TypeVar, Union, get_args, get_origin
from typing import Any, TypeVar, Union, get_args, get_origin
import httpx
import yaml
@ -119,8 +119,8 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
self,
config_path_or_template_name: str,
skip_logger_removal: bool = False,
custom_provider_registry: Optional[ProviderRegistry] = None,
provider_data: Optional[dict[str, Any]] = None,
custom_provider_registry: ProviderRegistry | None = None,
provider_data: dict[str, Any] | None = None,
):
super().__init__()
self.async_client = AsyncLlamaStackAsLibraryClient(
@ -181,8 +181,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
def __init__(
self,
config_path_or_template_name: str,
custom_provider_registry: Optional[ProviderRegistry] = None,
provider_data: Optional[dict[str, Any]] = None,
custom_provider_registry: ProviderRegistry | None = None,
provider_data: dict[str, Any] | None = None,
):
super().__init__()
# when using the library client, we should not log to console since many
@ -371,7 +371,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
)
return await response.parse()
def _convert_body(self, path: str, method: str, body: Optional[dict] = None) -> dict:
def _convert_body(self, path: str, method: str, body: dict | None = None) -> dict:
if not body:
return {}


@ -5,7 +5,7 @@
# the root directory of this source tree.
import asyncio
from typing import Any, Dict
from typing import Any
from pydantic import BaseModel
@ -73,14 +73,14 @@ class ProviderImpl(Providers):
raise ValueError(f"Provider {provider_id} not found")
async def get_providers_health(self) -> Dict[str, Dict[str, HealthResponse]]:
async def get_providers_health(self) -> dict[str, dict[str, HealthResponse]]:
"""Get health status for all providers.
Returns:
Dict[str, Dict[str, HealthResponse]]: A dictionary mapping API names to provider health statuses.
Each API maps to a dictionary of provider IDs to their health responses.
"""
providers_health: Dict[str, Dict[str, HealthResponse]] = {}
providers_health: dict[str, dict[str, HealthResponse]] = {}
timeout = 1.0
async def check_provider_health(impl: Any) -> tuple[str, HealthResponse] | None:


@ -7,7 +7,8 @@
import contextvars
import json
import logging
from typing import Any, ContextManager, Dict, List, Optional
from contextlib import AbstractContextManager
from typing import Any
from .utils.dynamic import instantiate_class_type
@ -17,11 +18,11 @@ log = logging.getLogger(__name__)
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
class RequestProviderDataContext(ContextManager):
class RequestProviderDataContext(AbstractContextManager):
"""Context manager for request provider data"""
def __init__(
self, provider_data: Optional[Dict[str, Any]] = None, auth_attributes: Optional[Dict[str, List[str]]] = None
self, provider_data: dict[str, Any] | None = None, auth_attributes: dict[str, list[str]] | None = None
):
self.provider_data = provider_data or {}
if auth_attributes:
@ -63,7 +64,7 @@ class NeedsRequestProviderData:
return None
def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, Any]]:
def parse_request_provider_data(headers: dict[str, str]) -> dict[str, Any] | None:
"""Parse provider data from request headers"""
keys = [
"X-LlamaStack-Provider-Data",
@ -86,14 +87,14 @@ def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, A
def request_provider_data_context(
headers: Dict[str, str], auth_attributes: Optional[Dict[str, List[str]]] = None
) -> ContextManager:
headers: dict[str, str], auth_attributes: dict[str, list[str]] | None = None
) -> AbstractContextManager:
"""Context manager that sets request provider data from headers and auth attributes for the duration of the context"""
provider_data = parse_request_provider_data(headers)
return RequestProviderDataContext(provider_data, auth_attributes)
def get_auth_attributes() -> Optional[Dict[str, List[str]]]:
def get_auth_attributes() -> dict[str, list[str]] | None:
"""Helper to retrieve auth attributes from the provider data context"""
provider_data = PROVIDER_DATA_VAR.get()
if not provider_data:


@ -5,7 +5,7 @@
# the root directory of this source tree.
import importlib
import inspect
from typing import Any, Dict, List, Set, Tuple
from typing import Any
from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks
@ -58,7 +58,7 @@ class InvalidProviderError(Exception):
pass
def api_protocol_map() -> Dict[Api, Any]:
def api_protocol_map() -> dict[Api, Any]:
return {
Api.providers: ProvidersAPI,
Api.agents: Agents,
@ -83,7 +83,7 @@ def api_protocol_map() -> Dict[Api, Any]:
}
def additional_protocols_map() -> Dict[Api, Any]:
def additional_protocols_map() -> dict[Api, Any]:
return {
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups),
@ -104,14 +104,14 @@ class ProviderWithSpec(Provider):
spec: ProviderSpec
ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]]
ProviderRegistry = dict[Api, dict[str, ProviderSpec]]
async def resolve_impls(
run_config: StackRunConfig,
provider_registry: ProviderRegistry,
dist_registry: DistributionRegistry,
) -> Dict[Api, Any]:
) -> dict[Api, Any]:
"""
Resolves provider implementations by:
1. Validating and organizing providers.
@ -136,7 +136,7 @@ async def resolve_impls(
return await instantiate_providers(sorted_providers, router_apis, dist_registry)
def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str, Dict[str, ProviderWithSpec]]:
def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
"""Generates specifications for automatically routed APIs."""
specs = {}
for info in builtin_automatically_routed_apis():
@ -178,10 +178,10 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str,
def validate_and_prepare_providers(
run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: Set[Api], router_apis: Set[Api]
) -> Dict[str, Dict[str, ProviderWithSpec]]:
run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: set[Api], router_apis: set[Api]
) -> dict[str, dict[str, ProviderWithSpec]]:
"""Validates providers, handles deprecations, and organizes them into a spec dictionary."""
providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]] = {}
providers_with_specs: dict[str, dict[str, ProviderWithSpec]] = {}
for api_str, providers in run_config.providers.items():
api = Api(api_str)
@ -222,10 +222,10 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR
def sort_providers_by_deps(
providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]], run_config: StackRunConfig
) -> List[Tuple[str, ProviderWithSpec]]:
providers_with_specs: dict[str, dict[str, ProviderWithSpec]], run_config: StackRunConfig
) -> list[tuple[str, ProviderWithSpec]]:
"""Sorts providers based on their dependencies."""
sorted_providers: List[Tuple[str, ProviderWithSpec]] = topological_sort(
sorted_providers: list[tuple[str, ProviderWithSpec]] = topological_sort(
{k: list(v.values()) for k, v in providers_with_specs.items()}
)
@ -236,11 +236,11 @@ def sort_providers_by_deps(
async def instantiate_providers(
sorted_providers: List[Tuple[str, ProviderWithSpec]], router_apis: Set[Api], dist_registry: DistributionRegistry
) -> Dict:
sorted_providers: list[tuple[str, ProviderWithSpec]], router_apis: set[Api], dist_registry: DistributionRegistry
) -> dict:
"""Instantiates providers asynchronously while managing dependencies."""
impls: Dict[Api, Any] = {}
inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
impls: dict[Api, Any] = {}
inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
for api_str, provider in sorted_providers:
deps = {a: impls[a] for a in provider.spec.api_dependencies}
for a in provider.spec.optional_api_dependencies:
@ -263,9 +263,9 @@ async def instantiate_providers(
def topological_sort(
providers_with_specs: Dict[str, List[ProviderWithSpec]],
) -> List[Tuple[str, ProviderWithSpec]]:
def dfs(kv, visited: Set[str], stack: List[str]):
providers_with_specs: dict[str, list[ProviderWithSpec]],
) -> list[tuple[str, ProviderWithSpec]]:
def dfs(kv, visited: set[str], stack: list[str]):
api_str, providers = kv
visited.add(api_str)
@ -280,8 +280,8 @@ def topological_sort(
stack.append(api_str)
visited: Set[str] = set()
stack: List[str] = []
visited: set[str] = set()
stack: list[str] = []
for api_str, providers in providers_with_specs.items():
if api_str not in visited:
@ -298,8 +298,8 @@ def topological_sort(
# returns a class implementing the protocol corresponding to the Api
async def instantiate_provider(
provider: ProviderWithSpec,
deps: Dict[Api, Any],
inner_impls: Dict[str, Any],
deps: dict[Api, Any],
inner_impls: dict[str, Any],
dist_registry: DistributionRegistry,
):
protocols = api_protocol_map()
@ -391,8 +391,8 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
async def resolve_remote_stack_impls(
config: RemoteProviderConfig,
apis: List[str],
) -> Dict[Api, Any]:
apis: list[str],
) -> dict[Api, Any]:
protocols = api_protocol_map()
additional_protocols = additional_protocols_map()


@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
from llama_stack.distribution.datatypes import RoutedProtocol
from llama_stack.distribution.store import DistributionRegistry
@ -23,7 +23,7 @@ from .routing_tables import (
async def get_routing_table_impl(
api: Api,
impls_by_provider_id: Dict[str, RoutedProtocol],
impls_by_provider_id: dict[str, RoutedProtocol],
_deps,
dist_registry: DistributionRegistry,
) -> Any:
@ -45,7 +45,7 @@ async def get_routing_table_impl(
return impl
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: dict[str, Any]) -> Any:
from .routers import (
DatasetIORouter,
EvalRouter,


@ -6,12 +6,12 @@
import asyncio
import time
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Annotated, Any
from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
from pydantic import Field, TypeAdapter
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import (
URL,
@ -100,9 +100,9 @@ class VectorIORouter(VectorIO):
self,
vector_db_id: str,
embedding_model: str,
embedding_dimension: Optional[int] = 384,
provider_id: Optional[str] = None,
provider_vector_db_id: Optional[str] = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> None:
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
await self.routing_table.register_vector_db(
@ -116,8 +116,8 @@ class VectorIORouter(VectorIO):
async def insert_chunks(
self,
vector_db_id: str,
chunks: List[Chunk],
ttl_seconds: Optional[int] = None,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
logger.debug(
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
@ -128,7 +128,7 @@ class VectorIORouter(VectorIO):
self,
vector_db_id: str,
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
@ -140,7 +140,7 @@ class InferenceRouter(Inference):
def __init__(
self,
routing_table: RoutingTable,
telemetry: Optional[Telemetry] = None,
telemetry: Telemetry | None = None,
) -> None:
logger.debug("Initializing InferenceRouter")
self.routing_table = routing_table
@ -160,10 +160,10 @@ class InferenceRouter(Inference):
async def register_model(
self,
model_id: str,
provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
provider_model_id: str | None = None,
provider_id: str | None = None,
metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None,
) -> None:
logger.debug(
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
@ -176,7 +176,7 @@ class InferenceRouter(Inference):
completion_tokens: int,
total_tokens: int,
model: Model,
) -> List[MetricEvent]:
) -> list[MetricEvent]:
"""Constructs a list of MetricEvent objects containing token usage metrics.
Args:
@ -221,7 +221,7 @@ class InferenceRouter(Inference):
completion_tokens: int,
total_tokens: int,
model: Model,
) -> List[MetricInResponse]:
) -> list[MetricInResponse]:
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
if self.telemetry:
for metric in metrics:
@ -230,9 +230,9 @@ class InferenceRouter(Inference):
async def _count_tokens(
self,
messages: List[Message] | InterleavedContent,
tool_prompt_format: Optional[ToolPromptFormat] = None,
) -> Optional[int]:
messages: list[Message] | InterleavedContent,
tool_prompt_format: ToolPromptFormat | None = None,
) -> int | None:
if isinstance(messages, list):
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
else:
@ -242,16 +242,16 @@ class InferenceRouter(Inference):
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = None,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = None,
tool_prompt_format: ToolPromptFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
logger.debug(
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
)
@ -351,12 +351,12 @@ class InferenceRouter(Inference):
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
messages_batch: list[list[Message]],
tools: list[ToolDefinition] | None = None,
tool_config: ToolConfig | None = None,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> BatchChatCompletionResponse:
logger.debug(
f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
@ -376,10 +376,10 @@ class InferenceRouter(Inference):
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
@ -439,10 +439,10 @@ class InferenceRouter(Inference):
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> BatchCompletionResponse:
logger.debug(
f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
@ -453,10 +453,10 @@ class InferenceRouter(Inference):
async def embeddings(
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
logger.debug(f"InferenceRouter.embeddings: {model_id}")
model = await self.routing_table.get_model(model_id)
@ -475,24 +475,24 @@ class InferenceRouter(Inference):
async def openai_completion(
self,
model: str,
prompt: Union[str, List[str], List[int], List[List[int]]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
) -> OpenAICompletion:
logger.debug(
f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
@ -531,29 +531,29 @@ class InferenceRouter(Inference):
async def openai_chat_completion(
self,
model: str,
messages: Annotated[List[OpenAIMessageParam], Field(..., min_length=1)],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
logger.debug(
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
)
@ -602,7 +602,7 @@ class InferenceRouter(Inference):
provider = self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.openai_chat_completion(**params)
async def health(self) -> Dict[str, HealthResponse]:
async def health(self) -> dict[str, HealthResponse]:
health_statuses = {}
timeout = 0.5
for provider_id, impl in self.routing_table.impls_by_provider_id.items():
@ -645,9 +645,9 @@ class SafetyRouter(Safety):
async def register_shield(
self,
shield_id: str,
provider_shield_id: Optional[str] = None,
provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
provider_shield_id: str | None = None,
provider_id: str | None = None,
params: dict[str, Any] | None = None,
) -> Shield:
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
@ -655,8 +655,8 @@ class SafetyRouter(Safety):
async def run_shield(
self,
shield_id: str,
messages: List[Message],
params: Dict[str, Any] = None,
messages: list[Message],
params: dict[str, Any] = None,
) -> RunShieldResponse:
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
return await self.routing_table.get_provider_impl(shield_id).run_shield(
@ -686,8 +686,8 @@ class DatasetIORouter(DatasetIO):
self,
purpose: DatasetPurpose,
source: DataSource,
metadata: Optional[Dict[str, Any]] = None,
dataset_id: Optional[str] = None,
metadata: dict[str, Any] | None = None,
dataset_id: str | None = None,
) -> None:
logger.debug(
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
@ -702,8 +702,8 @@ class DatasetIORouter(DatasetIO):
async def iterrows(
self,
dataset_id: str,
start_index: Optional[int] = None,
limit: Optional[int] = None,
start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
logger.debug(
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
@ -714,7 +714,7 @@ class DatasetIORouter(DatasetIO):
limit=limit,
)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
dataset_id=dataset_id,
@ -741,7 +741,7 @@ class ScoringRouter(Scoring):
async def score_batch(
self,
dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
scoring_functions: dict[str, ScoringFnParams | None] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
@ -762,8 +762,8 @@ class ScoringRouter(Scoring):
async def score(
self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
input_rows: list[dict[str, Any]],
scoring_functions: dict[str, ScoringFnParams | None] = None,
) -> ScoreResponse:
logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
res = {}
@ -808,8 +808,8 @@ class EvalRouter(Eval):
async def evaluate_rows(
self,
benchmark_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
input_rows: list[dict[str, Any]],
scoring_functions: list[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
@ -863,8 +863,8 @@ class ToolRuntimeRouter(ToolRuntime):
async def query(
self,
content: InterleavedContent,
vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None,
vector_db_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
return await self.routing_table.get_provider_impl("knowledge_search").query(
@ -873,7 +873,7 @@ class ToolRuntimeRouter(ToolRuntime):
async def insert(
self,
documents: List[RAGDocument],
documents: list[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
@ -904,7 +904,7 @@ class ToolRuntimeRouter(ToolRuntime):
logger.debug("ToolRuntimeRouter.shutdown")
pass
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
tool_name=tool_name,
@ -912,7 +912,7 @@ class ToolRuntimeRouter(ToolRuntime):
)
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)


@ -7,7 +7,7 @@
import logging
import time
import uuid
from typing import Any, Dict, List, Optional
from typing import Any
from pydantic import TypeAdapter
@ -106,20 +106,20 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
raise ValueError(f"Unregister not supported for {api}")
Registry = Dict[str, List[RoutableObjectWithProvider]]
Registry = dict[str, list[RoutableObjectWithProvider]]
class CommonRoutingTableImpl(RoutingTable):
def __init__(
self,
impls_by_provider_id: Dict[str, RoutedProtocol],
impls_by_provider_id: dict[str, RoutedProtocol],
dist_registry: DistributionRegistry,
) -> None:
self.impls_by_provider_id = impls_by_provider_id
self.dist_registry = dist_registry
async def initialize(self) -> None:
async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None:
async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
for obj in objs:
if cls is None:
obj.provider_id = provider_id
@ -154,7 +154,7 @@ class CommonRoutingTableImpl(RoutingTable):
for p in self.impls_by_provider_id.values():
await p.shutdown()
def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any:
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
def apiname_object():
if isinstance(self, ModelsRoutingTable):
return ("Inference", "model")
@ -192,7 +192,7 @@ class CommonRoutingTableImpl(RoutingTable):
raise ValueError(f"Provider not found for `{routing_key}`")
async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
# Get from disk registry
obj = await self.dist_registry.get(type, identifier)
if not obj:
@ -236,7 +236,7 @@ class CommonRoutingTableImpl(RoutingTable):
await self.dist_registry.register(obj)
return obj
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all()
filtered_objs = [obj for obj in objs if obj.type == type]
@ -277,10 +277,10 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def register_model(
self,
model_id: str,
provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
provider_model_id: str | None = None,
provider_id: str | None = None,
metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None,
) -> Model:
if provider_model_id is None:
provider_model_id = model_id
@ -328,9 +328,9 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def register_shield(
self,
shield_id: str,
provider_shield_id: Optional[str] = None,
provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
provider_shield_id: str | None = None,
provider_id: str | None = None,
params: dict[str, Any] | None = None,
) -> Shield:
if provider_shield_id is None:
provider_shield_id = shield_id
@ -368,9 +368,9 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
self,
vector_db_id: str,
embedding_model: str,
embedding_dimension: Optional[int] = 384,
provider_id: Optional[str] = None,
provider_vector_db_id: Optional[str] = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorDB:
if provider_vector_db_id is None:
provider_vector_db_id = vector_db_id
@ -423,8 +423,8 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
self,
purpose: DatasetPurpose,
source: DataSource,
metadata: Optional[Dict[str, Any]] = None,
dataset_id: Optional[str] = None,
metadata: dict[str, Any] | None = None,
dataset_id: str | None = None,
) -> Dataset:
if isinstance(source, dict):
if source["type"] == "uri":
@ -489,9 +489,9 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
scoring_fn_id: str,
description: str,
return_type: ParamType,
provider_scoring_fn_id: Optional[str] = None,
provider_id: Optional[str] = None,
params: Optional[ScoringFnParams] = None,
provider_scoring_fn_id: str | None = None,
provider_id: str | None = None,
params: ScoringFnParams | None = None,
) -> None:
if provider_scoring_fn_id is None:
provider_scoring_fn_id = scoring_fn_id
@ -528,10 +528,10 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
self,
benchmark_id: str,
dataset_id: str,
scoring_functions: List[str],
metadata: Optional[Dict[str, Any]] = None,
provider_benchmark_id: Optional[str] = None,
provider_id: Optional[str] = None,
scoring_functions: list[str],
metadata: dict[str, Any] | None = None,
provider_benchmark_id: str | None = None,
provider_id: str | None = None,
) -> None:
if metadata is None:
metadata = {}
@ -556,7 +556,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
tools = await self.get_all_with_type("tool")
if toolgroup_id:
tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
@ -578,8 +578,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
self,
toolgroup_id: str,
provider_id: str,
mcp_endpoint: Optional[URL] = None,
args: Optional[Dict[str, Any]] = None,
mcp_endpoint: URL | None = None,
args: dict[str, Any] | None = None,
) -> None:
tools = []
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)


@ -7,7 +7,6 @@
import json
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, List, Optional
from urllib.parse import parse_qs
import httpx
@ -22,7 +21,7 @@ logger = get_logger(name=__name__, category="auth")
class AuthResponse(BaseModel):
"""The format of the authentication response from the auth endpoint."""
access_attributes: Optional[AccessAttributes] = Field(
access_attributes: AccessAttributes | None = Field(
default=None,
description="""
Structured user attributes for attribute-based access control.
@ -44,7 +43,7 @@ class AuthResponse(BaseModel):
""",
)
message: Optional[str] = Field(
message: str | None = Field(
default=None, description="Optional message providing additional context about the authentication result."
)
@ -52,9 +51,9 @@ class AuthResponse(BaseModel):
class AuthRequestContext(BaseModel):
path: str = Field(description="The path of the request being authenticated")
headers: Dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
headers: dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
params: Dict[str, List[str]] = Field(
params: dict[str, list[str]] = Field(
description="Query parameters from the original request, parsed as dictionary of lists"
)
@ -76,14 +75,14 @@ class AuthProviderConfig(BaseModel):
"""Base configuration for authentication providers."""
provider_type: AuthProviderType = Field(..., description="Type of authentication provider")
config: Dict[str, str] = Field(..., description="Provider-specific configuration")
config: dict[str, str] = Field(..., description="Provider-specific configuration")
class AuthProvider(ABC):
"""Abstract base class for authentication providers."""
@abstractmethod
async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]:
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
"""Validate a token and return access attributes."""
pass
@ -96,7 +95,7 @@ class AuthProvider(ABC):
class KubernetesAuthProvider(AuthProvider):
"""Kubernetes authentication provider that validates tokens against the Kubernetes API server."""
def __init__(self, config: Dict[str, str]):
def __init__(self, config: dict[str, str]):
self.api_server_url = config["api_server_url"]
self.ca_cert_path = config.get("ca_cert_path")
self._client = None
@ -120,7 +119,7 @@ class KubernetesAuthProvider(AuthProvider):
self._client = ApiClient(configuration)
return self._client
async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]:
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
"""Validate a Kubernetes token and return access attributes."""
try:
client = await self._get_client()
@ -166,11 +165,11 @@ class KubernetesAuthProvider(AuthProvider):
class CustomAuthProvider(AuthProvider):
"""Custom authentication provider that uses an external endpoint."""
def __init__(self, config: Dict[str, str]):
def __init__(self, config: dict[str, str]):
self.endpoint = config["endpoint"]
self._client = None
async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]:
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
"""Validate a token using the custom authentication endpoint."""
if not self.endpoint:
raise ValueError("Authentication endpoint not configured")


@ -6,7 +6,6 @@
import inspect
import re
from typing import Dict, List
from pydantic import BaseModel
@ -29,7 +28,7 @@ def toolgroup_protocol_map():
}
def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
def get_all_api_endpoints() -> dict[Api, list[ApiEndpoint]]:
apis = {}
protocols = api_protocol_map()


@ -15,7 +15,7 @@ import warnings
from contextlib import asynccontextmanager
from importlib.metadata import version as parse_version
from pathlib import Path
from typing import Any, List, Optional, Union
from typing import Annotated, Any
import yaml
from fastapi import Body, FastAPI, HTTPException, Request
@ -24,7 +24,6 @@ from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError
from pydantic import BaseModel, ValidationError
from typing_extensions import Annotated
from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
@ -91,7 +90,7 @@ async def global_exception_handler(request: Request, exc: Exception):
return JSONResponse(status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}})
def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidationError]:
def translate_exception(exc: Exception) -> HTTPException | RequestValidationError:
if isinstance(exc, ValidationError):
exc = RequestValidationError(exc.errors())
@ -315,7 +314,7 @@ class ClientVersionMiddleware:
return await self.app(scope, receive, send)
def main(args: Optional[argparse.Namespace] = None):
def main(args: argparse.Namespace | None = None):
"""Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
parser.add_argument(
@ -385,7 +384,7 @@ def main(args: Optional[argparse.Namespace] = None):
raise ValueError("Either --yaml-config or --template must be provided")
logger_config = None
with open(config_file, "r") as fp:
with open(config_file) as fp:
config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg)
@ -517,7 +516,7 @@ def main(args: Optional[argparse.Namespace] = None):
uvicorn.run(**uvicorn_config)
def extract_path_params(route: str) -> List[str]:
def extract_path_params(route: str) -> list[str]:
segments = route.split("/")
params = [seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")]
# to handle path params like {param:path}


@ -8,7 +8,7 @@ import importlib.resources
import os
import re
import tempfile
from typing import Any, Dict, Optional
from typing import Any
import yaml
@ -90,7 +90,7 @@ RESOURCES = [
]
async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
for rsrc, api, register_method, list_method in RESOURCES:
objects = getattr(run_config, rsrc)
if api not in impls:
@ -197,7 +197,7 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
) from e
def add_internal_implementations(impls: Dict[Api, Any], run_config: StackRunConfig) -> None:
def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None:
"""Add internal implementations (inspect and providers) to the implementations dictionary.
Args:
@ -220,8 +220,8 @@ def add_internal_implementations(impls: Dict[Api, Any], run_config: StackRunConf
# Produces a stack of providers for the given run config. Not all APIs may be
# asked for in the run config.
async def construct_stack(
run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
) -> Dict[Api, Any]:
run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
) -> dict[Api, Any]:
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
@ -244,7 +244,7 @@ def get_stack_run_config_from_template(template: str) -> StackRunConfig:
def run_config_from_adhoc_config_spec(
adhoc_config_spec: str, provider_registry: Optional[ProviderRegistry] = None
adhoc_config_spec: str, provider_registry: ProviderRegistry | None = None
) -> StackRunConfig:
"""
Create an adhoc distribution from a list of API providers.


@ -6,7 +6,7 @@
import asyncio
from contextlib import asynccontextmanager
from typing import Dict, List, Optional, Protocol, Tuple
from typing import Protocol
import pydantic
@ -20,13 +20,13 @@ logger = get_logger(__name__, category="core")
class DistributionRegistry(Protocol):
async def get_all(self) -> List[RoutableObjectWithProvider]: ...
async def get_all(self) -> list[RoutableObjectWithProvider]: ...
async def initialize(self) -> None: ...
async def get(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ...
async def get(self, identifier: str) -> RoutableObjectWithProvider | None: ...
def get_cached(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ...
def get_cached(self, identifier: str) -> RoutableObjectWithProvider | None: ...
async def update(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: ...
@ -40,13 +40,13 @@ KEY_VERSION = "v8"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
def _get_registry_key_range() -> Tuple[str, str]:
def _get_registry_key_range() -> tuple[str, str]:
"""Returns the start and end keys for the registry range query."""
start_key = f"{REGISTER_PREFIX}:{KEY_VERSION}"
return start_key, f"{start_key}\xff"
def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider]:
def _parse_registry_values(values: list[str]) -> list[RoutableObjectWithProvider]:
"""Utility function to parse registry values into RoutableObjectWithProvider objects."""
all_objects = []
for value in values:
@ -67,16 +67,16 @@ class DiskDistributionRegistry(DistributionRegistry):
async def initialize(self) -> None:
pass
def get_cached(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
# Disk registry does not have a cache
raise NotImplementedError("Disk registry does not have a cache")
async def get_all(self) -> List[RoutableObjectWithProvider]:
async def get_all(self) -> list[RoutableObjectWithProvider]:
start_key, end_key = _get_registry_key_range()
values = await self.kvstore.range(start_key, end_key)
return _parse_registry_values(values)
async def get(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
json_str = await self.kvstore.get(KEY_FORMAT.format(type=type, identifier=identifier))
if not json_str:
return None
@ -113,7 +113,7 @@ class DiskDistributionRegistry(DistributionRegistry):
class CachedDiskDistributionRegistry(DiskDistributionRegistry):
def __init__(self, kvstore: KVStore):
super().__init__(kvstore)
self.cache: Dict[Tuple[str, str], RoutableObjectWithProvider] = {}
self.cache: dict[tuple[str, str], RoutableObjectWithProvider] = {}
self._initialized = False
self._initialize_lock = asyncio.Lock()
self._cache_lock = asyncio.Lock()
@ -147,15 +147,15 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
async def initialize(self) -> None:
await self._ensure_initialized()
def get_cached(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
return self.cache.get((type, identifier), None)
async def get_all(self) -> List[RoutableObjectWithProvider]:
async def get_all(self) -> list[RoutableObjectWithProvider]:
await self._ensure_initialized()
async with self._locked_cache() as cache:
return list(cache.values())
async def get(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
await self._ensure_initialized()
cache_key = (type, identifier)
@ -189,7 +189,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
async def create_dist_registry(
metadata_store: Optional[KVStoreConfig],
metadata_store: KVStoreConfig | None,
image_name: str,
) -> tuple[CachedDiskDistributionRegistry, KVStore]:
# instantiate kvstore for storing and retrieving distribution metadata


@ -5,7 +5,6 @@
# the root directory of this source tree.
import os
from typing import Optional
from llama_stack_client import LlamaStackClient
@ -23,7 +22,7 @@ class LlamaStackApi:
},
)
def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: Optional[dict]):
def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: dict | None):
"""Run scoring on a single row"""
if not scoring_params:
scoring_params = {fn_id: None for fn_id in scoring_function_ids}


@ -4,10 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from typing import Any
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
def redact_sensitive_fields(data: dict[str, Any]) -> dict[str, Any]:
"""Redact sensitive information from config before printing."""
sensitive_patterns = ["api_key", "api_token", "password", "secret"]
@ -18,7 +18,7 @@ def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
return [_redact_value(i) for i in v]
return v
def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
def _redact_dict(d: dict[str, Any]) -> dict[str, Any]:
result = {}
for k, v in d.items():
if any(pattern in k.lower() for pattern in sensitive_patterns):


@ -4,14 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from contextvars import ContextVar
from typing import AsyncGenerator, List, TypeVar
from typing import TypeVar
T = TypeVar("T")
def preserve_contexts_async_generator(
gen: AsyncGenerator[T, None], context_vars: List[ContextVar]
gen: AsyncGenerator[T, None], context_vars: list[ContextVar]
) -> AsyncGenerator[T, None]:
"""
Wraps an async generator to preserve context variables across iterations.


@ -8,12 +8,11 @@ import inspect
import json
import logging
from enum import Enum
from typing import Any, List, Literal, Optional, Type, Union, get_args, get_origin
from typing import Annotated, Any, Literal, Union, get_args, get_origin
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefinedType
from typing_extensions import Annotated
log = logging.getLogger(__name__)
@ -21,7 +20,7 @@ log = logging.getLogger(__name__)
def is_list_of_primitives(field_type):
"""Check if a field type is a List of primitive types."""
origin = get_origin(field_type)
if origin is List or origin is list:
if origin is list:
args = get_args(field_type)
if len(args) == 1 and args[0] in (int, float, str, bool):
return True
@ -53,7 +52,7 @@ def get_non_none_type(field_type):
return next(arg for arg in get_args(field_type) if arg is not type(None))
def manually_validate_field(model: Type[BaseModel], field_name: str, value: Any):
def manually_validate_field(model: type[BaseModel], field_name: str, value: Any):
validators = model.__pydantic_decorators__.field_validators
for _name, validator in validators.items():
if field_name in validator.info.fields:
@ -126,7 +125,7 @@ def prompt_for_discriminated_union(
#
# doesn't support List[nested_class] yet or Dicts of any kind. needs a bunch of
# unit tests for coverage.
def prompt_for_config(config_type: type[BaseModel], existing_config: Optional[BaseModel] = None) -> BaseModel:
def prompt_for_config(config_type: type[BaseModel], existing_config: BaseModel | None = None) -> BaseModel:
"""
Recursively prompt the user for configuration values based on a Pydantic BaseModel.