forked from phoenix-oss/llama-stack-mirror
chore: enable pyupgrade fixes (#1806)
# What does this PR do? The goal of this PR is code base modernization. Schema reflection code needed a minor adjustment to handle UnionTypes and collections.abc.AsyncIterator. (Both are preferred for latest Python releases.) Note to reviewers: almost all changes here are automatically generated by pyupgrade. Some additional unused imports were cleaned up. The only change worth of note can be found under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py` where reflection code was updated to deal with "newer" types. Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
ffe3d0b2cd
commit
9e6561a1ec
319 changed files with 2843 additions and 3033 deletions
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.log import get_logger
|
||||
|
@ -14,8 +14,8 @@ logger = get_logger(__name__, category="core")
|
|||
|
||||
def check_access(
|
||||
obj_identifier: str,
|
||||
obj_attributes: Optional[AccessAttributes],
|
||||
user_attributes: Optional[Dict[str, Any]] = None,
|
||||
obj_attributes: AccessAttributes | None,
|
||||
user_attributes: dict[str, Any] | None = None,
|
||||
) -> bool:
|
||||
"""Check if the current user has access to the given object, based on access attributes.
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ import inspect
|
|||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from enum import Enum
|
||||
from typing import Any, Type, Union, get_args, get_origin
|
||||
from typing import Any, Union, get_args, get_origin
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, parse_obj_as
|
||||
|
@ -27,7 +27,7 @@ async def get_client_impl(protocol, config: RemoteProviderConfig, _deps: Any):
|
|||
return impl
|
||||
|
||||
|
||||
def create_api_client_class(protocol) -> Type:
|
||||
def create_api_client_class(protocol) -> type:
|
||||
if protocol in _CLIENT_CLASSES:
|
||||
return _CLIENT_CLASSES[protocol]
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
import logging
|
||||
import textwrap
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import (
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION,
|
||||
|
@ -24,7 +24,7 @@ from llama_stack.providers.datatypes import Api, ProviderSpec
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def configure_single_provider(registry: Dict[str, ProviderSpec], provider: Provider) -> Provider:
|
||||
def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider:
|
||||
provider_spec = registry[provider.provider_type]
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
try:
|
||||
|
@ -120,8 +120,8 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
|
|||
|
||||
|
||||
def upgrade_from_routing_table(
|
||||
config_dict: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
config_dict: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
def get_providers(entries):
|
||||
return [
|
||||
Provider(
|
||||
|
@ -163,7 +163,7 @@ def upgrade_from_routing_table(
|
|||
return config_dict
|
||||
|
||||
|
||||
def parse_and_maybe_upgrade_config(config_dict: Dict[str, Any]) -> StackRunConfig:
|
||||
def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig:
|
||||
version = config_dict.get("version", None)
|
||||
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
|
||||
return StackRunConfig(**config_dict)
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Dict, List, Optional, Union
|
||||
from typing import Annotated, Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
@ -30,7 +30,7 @@ LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
|
|||
LLAMA_STACK_RUN_CONFIG_VERSION = "2"
|
||||
|
||||
|
||||
RoutingKey = Union[str, List[str]]
|
||||
RoutingKey = str | list[str]
|
||||
|
||||
|
||||
class AccessAttributes(BaseModel):
|
||||
|
@ -47,17 +47,17 @@ class AccessAttributes(BaseModel):
|
|||
"""
|
||||
|
||||
# Standard attribute categories - the minimal set we need now
|
||||
roles: Optional[List[str]] = Field(
|
||||
roles: list[str] | None = Field(
|
||||
default=None, description="Role-based attributes (e.g., 'admin', 'data-scientist', 'user')"
|
||||
)
|
||||
|
||||
teams: Optional[List[str]] = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
|
||||
teams: list[str] | None = Field(default=None, description="Team-based attributes (e.g., 'ml-team', 'nlp-team')")
|
||||
|
||||
projects: Optional[List[str]] = Field(
|
||||
projects: list[str] | None = Field(
|
||||
default=None, description="Project-based access attributes (e.g., 'llama-3', 'customer-insights')"
|
||||
)
|
||||
|
||||
namespaces: Optional[List[str]] = Field(
|
||||
namespaces: list[str] | None = Field(
|
||||
default=None, description="Namespace-based access control for resource isolation"
|
||||
)
|
||||
|
||||
|
@ -106,7 +106,7 @@ class ResourceWithACL(Resource):
|
|||
# ^ User must have access to the customer-insights project AND have confidential namespace
|
||||
"""
|
||||
|
||||
access_attributes: Optional[AccessAttributes] = None
|
||||
access_attributes: AccessAttributes | None = None
|
||||
|
||||
|
||||
# Use the extended Resource for all routable objects
|
||||
|
@ -142,41 +142,21 @@ class ToolGroupWithACL(ToolGroup, ResourceWithACL):
|
|||
pass
|
||||
|
||||
|
||||
RoutableObject = Union[
|
||||
Model,
|
||||
Shield,
|
||||
VectorDB,
|
||||
Dataset,
|
||||
ScoringFn,
|
||||
Benchmark,
|
||||
Tool,
|
||||
ToolGroup,
|
||||
]
|
||||
|
||||
RoutableObject = Model | Shield | VectorDB | Dataset | ScoringFn | Benchmark | Tool | ToolGroup
|
||||
|
||||
RoutableObjectWithProvider = Annotated[
|
||||
Union[
|
||||
ModelWithACL,
|
||||
ShieldWithACL,
|
||||
VectorDBWithACL,
|
||||
DatasetWithACL,
|
||||
ScoringFnWithACL,
|
||||
BenchmarkWithACL,
|
||||
ToolWithACL,
|
||||
ToolGroupWithACL,
|
||||
],
|
||||
ModelWithACL
|
||||
| ShieldWithACL
|
||||
| VectorDBWithACL
|
||||
| DatasetWithACL
|
||||
| ScoringFnWithACL
|
||||
| BenchmarkWithACL
|
||||
| ToolWithACL
|
||||
| ToolGroupWithACL,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
RoutedProtocol = Union[
|
||||
Inference,
|
||||
Safety,
|
||||
VectorIO,
|
||||
DatasetIO,
|
||||
Scoring,
|
||||
Eval,
|
||||
ToolRuntime,
|
||||
]
|
||||
RoutedProtocol = Inference | Safety | VectorIO | DatasetIO | Scoring | Eval | ToolRuntime
|
||||
|
||||
|
||||
# Example: /inference, /safety
|
||||
|
@ -184,15 +164,15 @@ class AutoRoutedProviderSpec(ProviderSpec):
|
|||
provider_type: str = "router"
|
||||
config_class: str = ""
|
||||
|
||||
container_image: Optional[str] = None
|
||||
container_image: str | None = None
|
||||
routing_table_api: Api
|
||||
module: str
|
||||
provider_data_validator: Optional[str] = Field(
|
||||
provider_data_validator: str | None = Field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> List[str]:
|
||||
def pip_packages(self) -> list[str]:
|
||||
raise AssertionError("Should not be called on AutoRoutedProviderSpec")
|
||||
|
||||
|
||||
|
@ -200,20 +180,20 @@ class AutoRoutedProviderSpec(ProviderSpec):
|
|||
class RoutingTableProviderSpec(ProviderSpec):
|
||||
provider_type: str = "routing_table"
|
||||
config_class: str = ""
|
||||
container_image: Optional[str] = None
|
||||
container_image: str | None = None
|
||||
|
||||
router_api: Api
|
||||
module: str
|
||||
pip_packages: List[str] = Field(default_factory=list)
|
||||
pip_packages: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DistributionSpec(BaseModel):
|
||||
description: Optional[str] = Field(
|
||||
description: str | None = Field(
|
||||
default="",
|
||||
description="Description of the distribution",
|
||||
)
|
||||
container_image: Optional[str] = None
|
||||
providers: Dict[str, Union[str, List[str]]] = Field(
|
||||
container_image: str | None = None
|
||||
providers: dict[str, str | list[str]] = Field(
|
||||
default_factory=dict,
|
||||
description="""
|
||||
Provider Types for each of the APIs provided by this distribution. If you
|
||||
|
@ -225,12 +205,12 @@ in the runtime configuration to help route to the correct provider.""",
|
|||
class Provider(BaseModel):
|
||||
provider_id: str
|
||||
provider_type: str
|
||||
config: Dict[str, Any]
|
||||
config: dict[str, Any]
|
||||
|
||||
|
||||
class LoggingConfig(BaseModel):
|
||||
category_levels: Dict[str, str] = Field(
|
||||
default_factory=Dict,
|
||||
category_levels: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="""
|
||||
Dictionary of different logging configurations for different portions (ex: core, server) of llama stack""",
|
||||
)
|
||||
|
@ -248,7 +228,7 @@ class AuthenticationConfig(BaseModel):
|
|||
...,
|
||||
description="Type of authentication provider (e.g., 'kubernetes', 'custom')",
|
||||
)
|
||||
config: Dict[str, str] = Field(
|
||||
config: dict[str, str] = Field(
|
||||
...,
|
||||
description="Provider-specific configuration",
|
||||
)
|
||||
|
@ -261,15 +241,15 @@ class ServerConfig(BaseModel):
|
|||
ge=1024,
|
||||
le=65535,
|
||||
)
|
||||
tls_certfile: Optional[str] = Field(
|
||||
tls_certfile: str | None = Field(
|
||||
default=None,
|
||||
description="Path to TLS certificate file for HTTPS",
|
||||
)
|
||||
tls_keyfile: Optional[str] = Field(
|
||||
tls_keyfile: str | None = Field(
|
||||
default=None,
|
||||
description="Path to TLS key file for HTTPS",
|
||||
)
|
||||
auth: Optional[AuthenticationConfig] = Field(
|
||||
auth: AuthenticationConfig | None = Field(
|
||||
default=None,
|
||||
description="Authentication configuration for the server",
|
||||
)
|
||||
|
@ -285,23 +265,23 @@ Reference to the distribution this package refers to. For unregistered (adhoc) p
|
|||
this could be just a hash
|
||||
""",
|
||||
)
|
||||
container_image: Optional[str] = Field(
|
||||
container_image: str | None = Field(
|
||||
default=None,
|
||||
description="Reference to the container image if this package refers to a container",
|
||||
)
|
||||
apis: List[str] = Field(
|
||||
apis: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="""
|
||||
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
|
||||
)
|
||||
|
||||
providers: Dict[str, List[Provider]] = Field(
|
||||
providers: dict[str, list[Provider]] = Field(
|
||||
description="""
|
||||
One or more providers to use for each API. The same provider_type (e.g., meta-reference)
|
||||
can be instantiated multiple times (with different configs) if necessary.
|
||||
""",
|
||||
)
|
||||
metadata_store: Optional[KVStoreConfig] = Field(
|
||||
metadata_store: KVStoreConfig | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Configuration for the persistence store used by the distribution registry. If not specified,
|
||||
|
@ -309,22 +289,22 @@ a default SQLite store will be used.""",
|
|||
)
|
||||
|
||||
# registry of "resources" in the distribution
|
||||
models: List[ModelInput] = Field(default_factory=list)
|
||||
shields: List[ShieldInput] = Field(default_factory=list)
|
||||
vector_dbs: List[VectorDBInput] = Field(default_factory=list)
|
||||
datasets: List[DatasetInput] = Field(default_factory=list)
|
||||
scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
|
||||
benchmarks: List[BenchmarkInput] = Field(default_factory=list)
|
||||
tool_groups: List[ToolGroupInput] = Field(default_factory=list)
|
||||
models: list[ModelInput] = Field(default_factory=list)
|
||||
shields: list[ShieldInput] = Field(default_factory=list)
|
||||
vector_dbs: list[VectorDBInput] = Field(default_factory=list)
|
||||
datasets: list[DatasetInput] = Field(default_factory=list)
|
||||
scoring_fns: list[ScoringFnInput] = Field(default_factory=list)
|
||||
benchmarks: list[BenchmarkInput] = Field(default_factory=list)
|
||||
tool_groups: list[ToolGroupInput] = Field(default_factory=list)
|
||||
|
||||
logging: Optional[LoggingConfig] = Field(default=None, description="Configuration for Llama Stack Logging")
|
||||
logging: LoggingConfig | None = Field(default=None, description="Configuration for Llama Stack Logging")
|
||||
|
||||
server: ServerConfig = Field(
|
||||
default_factory=ServerConfig,
|
||||
description="Configuration for the HTTP(S) server",
|
||||
)
|
||||
|
||||
external_providers_dir: Optional[str] = Field(
|
||||
external_providers_dir: str | None = Field(
|
||||
default=None,
|
||||
description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
|
||||
)
|
||||
|
@ -338,11 +318,11 @@ class BuildConfig(BaseModel):
|
|||
default="conda",
|
||||
description="Type of package to build (conda | container | venv)",
|
||||
)
|
||||
image_name: Optional[str] = Field(
|
||||
image_name: str | None = Field(
|
||||
default=None,
|
||||
description="Name of the distribution to build",
|
||||
)
|
||||
external_providers_dir: Optional[str] = Field(
|
||||
external_providers_dir: str | None = Field(
|
||||
default=None,
|
||||
description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. "
|
||||
"pip_packages MUST contain the provider package name.",
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
import glob
|
||||
import importlib
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel
|
||||
|
@ -24,7 +24,7 @@ from llama_stack.providers.datatypes import (
|
|||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
def stack_apis() -> List[Api]:
|
||||
def stack_apis() -> list[Api]:
|
||||
return list(Api)
|
||||
|
||||
|
||||
|
@ -33,7 +33,7 @@ class AutoRoutedApiInfo(BaseModel):
|
|||
router_api: Api
|
||||
|
||||
|
||||
def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
|
||||
def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
|
||||
return [
|
||||
AutoRoutedApiInfo(
|
||||
routing_table_api=Api.models,
|
||||
|
@ -66,12 +66,12 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
|
|||
]
|
||||
|
||||
|
||||
def providable_apis() -> List[Api]:
|
||||
def providable_apis() -> list[Api]:
|
||||
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
|
||||
return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]
|
||||
|
||||
|
||||
def _load_remote_provider_spec(spec_data: Dict[str, Any], api: Api) -> ProviderSpec:
|
||||
def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec:
|
||||
adapter = AdapterSpec(**spec_data["adapter"])
|
||||
spec = remote_provider_spec(
|
||||
api=api,
|
||||
|
@ -81,7 +81,7 @@ def _load_remote_provider_spec(spec_data: Dict[str, Any], api: Api) -> ProviderS
|
|||
return spec
|
||||
|
||||
|
||||
def _load_inline_provider_spec(spec_data: Dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
|
||||
def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
|
||||
spec = InlineProviderSpec(
|
||||
api=api,
|
||||
provider_type=f"inline::{provider_name}",
|
||||
|
@ -98,7 +98,7 @@ def _load_inline_provider_spec(spec_data: Dict[str, Any], api: Api, provider_nam
|
|||
|
||||
def get_provider_registry(
|
||||
config=None,
|
||||
) -> Dict[Api, Dict[str, ProviderSpec]]:
|
||||
) -> dict[Api, dict[str, ProviderSpec]]:
|
||||
"""Get the provider registry, optionally including external providers.
|
||||
|
||||
This function loads both built-in providers and external providers from YAML files.
|
||||
|
@ -133,7 +133,7 @@ def get_provider_registry(
|
|||
ValueError: If any provider spec is invalid
|
||||
"""
|
||||
|
||||
ret: Dict[Api, Dict[str, ProviderSpec]] = {}
|
||||
ret: dict[Api, dict[str, ProviderSpec]] = {}
|
||||
for api in providable_apis():
|
||||
name = api.name.lower()
|
||||
logger.debug(f"Importing module {name}")
|
||||
|
|
|
@ -12,7 +12,7 @@ import os
|
|||
from concurrent.futures import ThreadPoolExecutor
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, TypeVar, Union, get_args, get_origin
|
||||
from typing import Any, TypeVar, Union, get_args, get_origin
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
|
@ -119,8 +119,8 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
|||
self,
|
||||
config_path_or_template_name: str,
|
||||
skip_logger_removal: bool = False,
|
||||
custom_provider_registry: Optional[ProviderRegistry] = None,
|
||||
provider_data: Optional[dict[str, Any]] = None,
|
||||
custom_provider_registry: ProviderRegistry | None = None,
|
||||
provider_data: dict[str, Any] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.async_client = AsyncLlamaStackAsLibraryClient(
|
||||
|
@ -181,8 +181,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
def __init__(
|
||||
self,
|
||||
config_path_or_template_name: str,
|
||||
custom_provider_registry: Optional[ProviderRegistry] = None,
|
||||
provider_data: Optional[dict[str, Any]] = None,
|
||||
custom_provider_registry: ProviderRegistry | None = None,
|
||||
provider_data: dict[str, Any] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
# when using the library client, we should not log to console since many
|
||||
|
@ -371,7 +371,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
)
|
||||
return await response.parse()
|
||||
|
||||
def _convert_body(self, path: str, method: str, body: Optional[dict] = None) -> dict:
|
||||
def _convert_body(self, path: str, method: str, body: dict | None = None) -> dict:
|
||||
if not body:
|
||||
return {}
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -73,14 +73,14 @@ class ProviderImpl(Providers):
|
|||
|
||||
raise ValueError(f"Provider {provider_id} not found")
|
||||
|
||||
async def get_providers_health(self) -> Dict[str, Dict[str, HealthResponse]]:
|
||||
async def get_providers_health(self) -> dict[str, dict[str, HealthResponse]]:
|
||||
"""Get health status for all providers.
|
||||
|
||||
Returns:
|
||||
Dict[str, Dict[str, HealthResponse]]: A dictionary mapping API names to provider health statuses.
|
||||
Each API maps to a dictionary of provider IDs to their health responses.
|
||||
"""
|
||||
providers_health: Dict[str, Dict[str, HealthResponse]] = {}
|
||||
providers_health: dict[str, dict[str, HealthResponse]] = {}
|
||||
timeout = 1.0
|
||||
|
||||
async def check_provider_health(impl: Any) -> tuple[str, HealthResponse] | None:
|
||||
|
|
|
@ -7,7 +7,8 @@
|
|||
import contextvars
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, ContextManager, Dict, List, Optional
|
||||
from contextlib import AbstractContextManager
|
||||
from typing import Any
|
||||
|
||||
from .utils.dynamic import instantiate_class_type
|
||||
|
||||
|
@ -17,11 +18,11 @@ log = logging.getLogger(__name__)
|
|||
PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None)
|
||||
|
||||
|
||||
class RequestProviderDataContext(ContextManager):
|
||||
class RequestProviderDataContext(AbstractContextManager):
|
||||
"""Context manager for request provider data"""
|
||||
|
||||
def __init__(
|
||||
self, provider_data: Optional[Dict[str, Any]] = None, auth_attributes: Optional[Dict[str, List[str]]] = None
|
||||
self, provider_data: dict[str, Any] | None = None, auth_attributes: dict[str, list[str]] | None = None
|
||||
):
|
||||
self.provider_data = provider_data or {}
|
||||
if auth_attributes:
|
||||
|
@ -63,7 +64,7 @@ class NeedsRequestProviderData:
|
|||
return None
|
||||
|
||||
|
||||
def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, Any]]:
|
||||
def parse_request_provider_data(headers: dict[str, str]) -> dict[str, Any] | None:
|
||||
"""Parse provider data from request headers"""
|
||||
keys = [
|
||||
"X-LlamaStack-Provider-Data",
|
||||
|
@ -86,14 +87,14 @@ def parse_request_provider_data(headers: Dict[str, str]) -> Optional[Dict[str, A
|
|||
|
||||
|
||||
def request_provider_data_context(
|
||||
headers: Dict[str, str], auth_attributes: Optional[Dict[str, List[str]]] = None
|
||||
) -> ContextManager:
|
||||
headers: dict[str, str], auth_attributes: dict[str, list[str]] | None = None
|
||||
) -> AbstractContextManager:
|
||||
"""Context manager that sets request provider data from headers and auth attributes for the duration of the context"""
|
||||
provider_data = parse_request_provider_data(headers)
|
||||
return RequestProviderDataContext(provider_data, auth_attributes)
|
||||
|
||||
|
||||
def get_auth_attributes() -> Optional[Dict[str, List[str]]]:
|
||||
def get_auth_attributes() -> dict[str, list[str]] | None:
|
||||
"""Helper to retrieve auth attributes from the provider data context"""
|
||||
provider_data = PROVIDER_DATA_VAR.get()
|
||||
if not provider_data:
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
import importlib
|
||||
import inspect
|
||||
from typing import Any, Dict, List, Set, Tuple
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.benchmarks import Benchmarks
|
||||
|
@ -58,7 +58,7 @@ class InvalidProviderError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
def api_protocol_map() -> Dict[Api, Any]:
|
||||
def api_protocol_map() -> dict[Api, Any]:
|
||||
return {
|
||||
Api.providers: ProvidersAPI,
|
||||
Api.agents: Agents,
|
||||
|
@ -83,7 +83,7 @@ def api_protocol_map() -> Dict[Api, Any]:
|
|||
}
|
||||
|
||||
|
||||
def additional_protocols_map() -> Dict[Api, Any]:
|
||||
def additional_protocols_map() -> dict[Api, Any]:
|
||||
return {
|
||||
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
|
||||
Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups),
|
||||
|
@ -104,14 +104,14 @@ class ProviderWithSpec(Provider):
|
|||
spec: ProviderSpec
|
||||
|
||||
|
||||
ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]]
|
||||
ProviderRegistry = dict[Api, dict[str, ProviderSpec]]
|
||||
|
||||
|
||||
async def resolve_impls(
|
||||
run_config: StackRunConfig,
|
||||
provider_registry: ProviderRegistry,
|
||||
dist_registry: DistributionRegistry,
|
||||
) -> Dict[Api, Any]:
|
||||
) -> dict[Api, Any]:
|
||||
"""
|
||||
Resolves provider implementations by:
|
||||
1. Validating and organizing providers.
|
||||
|
@ -136,7 +136,7 @@ async def resolve_impls(
|
|||
return await instantiate_providers(sorted_providers, router_apis, dist_registry)
|
||||
|
||||
|
||||
def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str, Dict[str, ProviderWithSpec]]:
|
||||
def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
|
||||
"""Generates specifications for automatically routed APIs."""
|
||||
specs = {}
|
||||
for info in builtin_automatically_routed_apis():
|
||||
|
@ -178,10 +178,10 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str,
|
|||
|
||||
|
||||
def validate_and_prepare_providers(
|
||||
run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: Set[Api], router_apis: Set[Api]
|
||||
) -> Dict[str, Dict[str, ProviderWithSpec]]:
|
||||
run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: set[Api], router_apis: set[Api]
|
||||
) -> dict[str, dict[str, ProviderWithSpec]]:
|
||||
"""Validates providers, handles deprecations, and organizes them into a spec dictionary."""
|
||||
providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]] = {}
|
||||
providers_with_specs: dict[str, dict[str, ProviderWithSpec]] = {}
|
||||
|
||||
for api_str, providers in run_config.providers.items():
|
||||
api = Api(api_str)
|
||||
|
@ -222,10 +222,10 @@ def validate_provider(provider: Provider, api: Api, provider_registry: ProviderR
|
|||
|
||||
|
||||
def sort_providers_by_deps(
|
||||
providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]], run_config: StackRunConfig
|
||||
) -> List[Tuple[str, ProviderWithSpec]]:
|
||||
providers_with_specs: dict[str, dict[str, ProviderWithSpec]], run_config: StackRunConfig
|
||||
) -> list[tuple[str, ProviderWithSpec]]:
|
||||
"""Sorts providers based on their dependencies."""
|
||||
sorted_providers: List[Tuple[str, ProviderWithSpec]] = topological_sort(
|
||||
sorted_providers: list[tuple[str, ProviderWithSpec]] = topological_sort(
|
||||
{k: list(v.values()) for k, v in providers_with_specs.items()}
|
||||
)
|
||||
|
||||
|
@ -236,11 +236,11 @@ def sort_providers_by_deps(
|
|||
|
||||
|
||||
async def instantiate_providers(
|
||||
sorted_providers: List[Tuple[str, ProviderWithSpec]], router_apis: Set[Api], dist_registry: DistributionRegistry
|
||||
) -> Dict:
|
||||
sorted_providers: list[tuple[str, ProviderWithSpec]], router_apis: set[Api], dist_registry: DistributionRegistry
|
||||
) -> dict:
|
||||
"""Instantiates providers asynchronously while managing dependencies."""
|
||||
impls: Dict[Api, Any] = {}
|
||||
inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
|
||||
impls: dict[Api, Any] = {}
|
||||
inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
|
||||
for api_str, provider in sorted_providers:
|
||||
deps = {a: impls[a] for a in provider.spec.api_dependencies}
|
||||
for a in provider.spec.optional_api_dependencies:
|
||||
|
@ -263,9 +263,9 @@ async def instantiate_providers(
|
|||
|
||||
|
||||
def topological_sort(
|
||||
providers_with_specs: Dict[str, List[ProviderWithSpec]],
|
||||
) -> List[Tuple[str, ProviderWithSpec]]:
|
||||
def dfs(kv, visited: Set[str], stack: List[str]):
|
||||
providers_with_specs: dict[str, list[ProviderWithSpec]],
|
||||
) -> list[tuple[str, ProviderWithSpec]]:
|
||||
def dfs(kv, visited: set[str], stack: list[str]):
|
||||
api_str, providers = kv
|
||||
visited.add(api_str)
|
||||
|
||||
|
@ -280,8 +280,8 @@ def topological_sort(
|
|||
|
||||
stack.append(api_str)
|
||||
|
||||
visited: Set[str] = set()
|
||||
stack: List[str] = []
|
||||
visited: set[str] = set()
|
||||
stack: list[str] = []
|
||||
|
||||
for api_str, providers in providers_with_specs.items():
|
||||
if api_str not in visited:
|
||||
|
@ -298,8 +298,8 @@ def topological_sort(
|
|||
# returns a class implementing the protocol corresponding to the Api
|
||||
async def instantiate_provider(
|
||||
provider: ProviderWithSpec,
|
||||
deps: Dict[Api, Any],
|
||||
inner_impls: Dict[str, Any],
|
||||
deps: dict[Api, Any],
|
||||
inner_impls: dict[str, Any],
|
||||
dist_registry: DistributionRegistry,
|
||||
):
|
||||
protocols = api_protocol_map()
|
||||
|
@ -391,8 +391,8 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
|
|||
|
||||
async def resolve_remote_stack_impls(
|
||||
config: RemoteProviderConfig,
|
||||
apis: List[str],
|
||||
) -> Dict[Api, Any]:
|
||||
apis: list[str],
|
||||
) -> dict[Api, Any]:
|
||||
protocols = api_protocol_map()
|
||||
additional_protocols = additional_protocols_map()
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.distribution.datatypes import RoutedProtocol
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
|
@ -23,7 +23,7 @@ from .routing_tables import (
|
|||
|
||||
async def get_routing_table_impl(
|
||||
api: Api,
|
||||
impls_by_provider_id: Dict[str, RoutedProtocol],
|
||||
impls_by_provider_id: dict[str, RoutedProtocol],
|
||||
_deps,
|
||||
dist_registry: DistributionRegistry,
|
||||
) -> Any:
|
||||
|
@ -45,7 +45,7 @@ async def get_routing_table_impl(
|
|||
return impl
|
||||
|
||||
|
||||
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any:
|
||||
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: dict[str, Any]) -> Any:
|
||||
from .routers import (
|
||||
DatasetIORouter,
|
||||
EvalRouter,
|
||||
|
|
|
@ -6,12 +6,12 @@
|
|||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from typing import Annotated, Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
|
||||
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
|
||||
from pydantic import Field, TypeAdapter
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
URL,
|
||||
|
@ -100,9 +100,9 @@ class VectorIORouter(VectorIO):
|
|||
self,
|
||||
vector_db_id: str,
|
||||
embedding_model: str,
|
||||
embedding_dimension: Optional[int] = 384,
|
||||
provider_id: Optional[str] = None,
|
||||
provider_vector_db_id: Optional[str] = None,
|
||||
embedding_dimension: int | None = 384,
|
||||
provider_id: str | None = None,
|
||||
provider_vector_db_id: str | None = None,
|
||||
) -> None:
|
||||
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
|
||||
await self.routing_table.register_vector_db(
|
||||
|
@ -116,8 +116,8 @@ class VectorIORouter(VectorIO):
|
|||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
chunks: List[Chunk],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
|
||||
|
@ -128,7 +128,7 @@ class VectorIORouter(VectorIO):
|
|||
self,
|
||||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
|
||||
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
|
||||
|
@ -140,7 +140,7 @@ class InferenceRouter(Inference):
|
|||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
telemetry: Optional[Telemetry] = None,
|
||||
telemetry: Telemetry | None = None,
|
||||
) -> None:
|
||||
logger.debug("Initializing InferenceRouter")
|
||||
self.routing_table = routing_table
|
||||
|
@ -160,10 +160,10 @@ class InferenceRouter(Inference):
|
|||
async def register_model(
|
||||
self,
|
||||
model_id: str,
|
||||
provider_model_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
provider_model_id: str | None = None,
|
||||
provider_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
model_type: ModelType | None = None,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
|
||||
|
@ -176,7 +176,7 @@ class InferenceRouter(Inference):
|
|||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
model: Model,
|
||||
) -> List[MetricEvent]:
|
||||
) -> list[MetricEvent]:
|
||||
"""Constructs a list of MetricEvent objects containing token usage metrics.
|
||||
|
||||
Args:
|
||||
|
@ -221,7 +221,7 @@ class InferenceRouter(Inference):
|
|||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
model: Model,
|
||||
) -> List[MetricInResponse]:
|
||||
) -> list[MetricInResponse]:
|
||||
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
|
||||
if self.telemetry:
|
||||
for metric in metrics:
|
||||
|
@ -230,9 +230,9 @@ class InferenceRouter(Inference):
|
|||
|
||||
async def _count_tokens(
|
||||
self,
|
||||
messages: List[Message] | InterleavedContent,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
) -> Optional[int]:
|
||||
messages: list[Message] | InterleavedContent,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
) -> int | None:
|
||||
if isinstance(messages, list):
|
||||
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
|
||||
else:
|
||||
|
@ -242,16 +242,16 @@ class InferenceRouter(Inference):
|
|||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = None,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = None,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
logger.debug(
|
||||
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
|
||||
)
|
||||
|
@ -351,12 +351,12 @@ class InferenceRouter(Inference):
|
|||
async def batch_chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages_batch: List[List[Message]],
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
messages_batch: list[list[Message]],
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> BatchChatCompletionResponse:
|
||||
logger.debug(
|
||||
f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
|
||||
|
@ -376,10 +376,10 @@ class InferenceRouter(Inference):
|
|||
self,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
|
@ -439,10 +439,10 @@ class InferenceRouter(Inference):
|
|||
async def batch_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content_batch: List[InterleavedContent],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
content_batch: list[InterleavedContent],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> BatchCompletionResponse:
|
||||
logger.debug(
|
||||
f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
|
||||
|
@ -453,10 +453,10 @@ class InferenceRouter(Inference):
|
|||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||
output_dimension: Optional[int] = None,
|
||||
task_type: Optional[EmbeddingTaskType] = None,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
logger.debug(f"InferenceRouter.embeddings: {model_id}")
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
|
@ -475,24 +475,24 @@ class InferenceRouter(Inference):
|
|||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: Union[str, List[str], List[int], List[List[int]]],
|
||||
best_of: Optional[int] = None,
|
||||
echo: Optional[bool] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
logit_bias: Optional[Dict[str, float]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stream_options: Optional[Dict[str, Any]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
guided_choice: Optional[List[str]] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
) -> OpenAICompletion:
|
||||
logger.debug(
|
||||
f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
|
||||
|
@ -531,29 +531,29 @@ class InferenceRouter(Inference):
|
|||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: Annotated[List[OpenAIMessageParam], Field(..., min_length=1)],
|
||||
frequency_penalty: Optional[float] = None,
|
||||
function_call: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
functions: Optional[List[Dict[str, Any]]] = None,
|
||||
logit_bias: Optional[Dict[str, float]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
max_completion_tokens: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stream_options: Optional[Dict[str, Any]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
logger.debug(
|
||||
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
|
||||
)
|
||||
|
@ -602,7 +602,7 @@ class InferenceRouter(Inference):
|
|||
provider = self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
return await provider.openai_chat_completion(**params)
|
||||
|
||||
async def health(self) -> Dict[str, HealthResponse]:
|
||||
async def health(self) -> dict[str, HealthResponse]:
|
||||
health_statuses = {}
|
||||
timeout = 0.5
|
||||
for provider_id, impl in self.routing_table.impls_by_provider_id.items():
|
||||
|
@ -645,9 +645,9 @@ class SafetyRouter(Safety):
|
|||
async def register_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
provider_shield_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
provider_shield_id: str | None = None,
|
||||
provider_id: str | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> Shield:
|
||||
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
|
||||
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
|
||||
|
@ -655,8 +655,8 @@ class SafetyRouter(Safety):
|
|||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
messages: list[Message],
|
||||
params: dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
|
||||
return await self.routing_table.get_provider_impl(shield_id).run_shield(
|
||||
|
@ -686,8 +686,8 @@ class DatasetIORouter(DatasetIO):
|
|||
self,
|
||||
purpose: DatasetPurpose,
|
||||
source: DataSource,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
dataset_id: Optional[str] = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
dataset_id: str | None = None,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
|
||||
|
@ -702,8 +702,8 @@ class DatasetIORouter(DatasetIO):
|
|||
async def iterrows(
|
||||
self,
|
||||
dataset_id: str,
|
||||
start_index: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
start_index: int | None = None,
|
||||
limit: int | None = None,
|
||||
) -> PaginatedResponse:
|
||||
logger.debug(
|
||||
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
|
||||
|
@ -714,7 +714,7 @@ class DatasetIORouter(DatasetIO):
|
|||
limit=limit,
|
||||
)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
|
||||
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
|
||||
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
|
||||
dataset_id=dataset_id,
|
||||
|
@ -741,7 +741,7 @@ class ScoringRouter(Scoring):
|
|||
async def score_batch(
|
||||
self,
|
||||
dataset_id: str,
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
scoring_functions: dict[str, ScoringFnParams | None] = None,
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse:
|
||||
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
|
||||
|
@ -762,8 +762,8 @@ class ScoringRouter(Scoring):
|
|||
|
||||
async def score(
|
||||
self,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
|
||||
input_rows: list[dict[str, Any]],
|
||||
scoring_functions: dict[str, ScoringFnParams | None] = None,
|
||||
) -> ScoreResponse:
|
||||
logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
|
||||
res = {}
|
||||
|
@ -808,8 +808,8 @@ class EvalRouter(Eval):
|
|||
async def evaluate_rows(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: List[str],
|
||||
input_rows: list[dict[str, Any]],
|
||||
scoring_functions: list[str],
|
||||
benchmark_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
|
||||
|
@ -863,8 +863,8 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
async def query(
|
||||
self,
|
||||
content: InterleavedContent,
|
||||
vector_db_ids: List[str],
|
||||
query_config: Optional[RAGQueryConfig] = None,
|
||||
vector_db_ids: list[str],
|
||||
query_config: RAGQueryConfig | None = None,
|
||||
) -> RAGQueryResult:
|
||||
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
|
||||
return await self.routing_table.get_provider_impl("knowledge_search").query(
|
||||
|
@ -873,7 +873,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
|
||||
async def insert(
|
||||
self,
|
||||
documents: List[RAGDocument],
|
||||
documents: list[RAGDocument],
|
||||
vector_db_id: str,
|
||||
chunk_size_in_tokens: int = 512,
|
||||
) -> None:
|
||||
|
@ -904,7 +904,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
logger.debug("ToolRuntimeRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
|
||||
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
|
||||
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
|
||||
tool_name=tool_name,
|
||||
|
@ -912,7 +912,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
)
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
) -> ListToolDefsResponse:
|
||||
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
||||
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
|
@ -106,20 +106,20 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
|||
raise ValueError(f"Unregister not supported for {api}")
|
||||
|
||||
|
||||
Registry = Dict[str, List[RoutableObjectWithProvider]]
|
||||
Registry = dict[str, list[RoutableObjectWithProvider]]
|
||||
|
||||
|
||||
class CommonRoutingTableImpl(RoutingTable):
|
||||
def __init__(
|
||||
self,
|
||||
impls_by_provider_id: Dict[str, RoutedProtocol],
|
||||
impls_by_provider_id: dict[str, RoutedProtocol],
|
||||
dist_registry: DistributionRegistry,
|
||||
) -> None:
|
||||
self.impls_by_provider_id = impls_by_provider_id
|
||||
self.dist_registry = dist_registry
|
||||
|
||||
async def initialize(self) -> None:
|
||||
async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None:
|
||||
async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
|
||||
for obj in objs:
|
||||
if cls is None:
|
||||
obj.provider_id = provider_id
|
||||
|
@ -154,7 +154,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
for p in self.impls_by_provider_id.values():
|
||||
await p.shutdown()
|
||||
|
||||
def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any:
|
||||
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
|
||||
def apiname_object():
|
||||
if isinstance(self, ModelsRoutingTable):
|
||||
return ("Inference", "model")
|
||||
|
@ -192,7 +192,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
raise ValueError(f"Provider not found for `{routing_key}`")
|
||||
|
||||
async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
|
||||
async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
|
||||
# Get from disk registry
|
||||
obj = await self.dist_registry.get(type, identifier)
|
||||
if not obj:
|
||||
|
@ -236,7 +236,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
await self.dist_registry.register(obj)
|
||||
return obj
|
||||
|
||||
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
|
||||
async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
|
||||
objs = await self.dist_registry.get_all()
|
||||
filtered_objs = [obj for obj in objs if obj.type == type]
|
||||
|
||||
|
@ -277,10 +277,10 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
async def register_model(
|
||||
self,
|
||||
model_id: str,
|
||||
provider_model_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
provider_model_id: str | None = None,
|
||||
provider_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
model_type: ModelType | None = None,
|
||||
) -> Model:
|
||||
if provider_model_id is None:
|
||||
provider_model_id = model_id
|
||||
|
@ -328,9 +328,9 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
async def register_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
provider_shield_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
provider_shield_id: str | None = None,
|
||||
provider_id: str | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> Shield:
|
||||
if provider_shield_id is None:
|
||||
provider_shield_id = shield_id
|
||||
|
@ -368,9 +368,9 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
self,
|
||||
vector_db_id: str,
|
||||
embedding_model: str,
|
||||
embedding_dimension: Optional[int] = 384,
|
||||
provider_id: Optional[str] = None,
|
||||
provider_vector_db_id: Optional[str] = None,
|
||||
embedding_dimension: int | None = 384,
|
||||
provider_id: str | None = None,
|
||||
provider_vector_db_id: str | None = None,
|
||||
) -> VectorDB:
|
||||
if provider_vector_db_id is None:
|
||||
provider_vector_db_id = vector_db_id
|
||||
|
@ -423,8 +423,8 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
self,
|
||||
purpose: DatasetPurpose,
|
||||
source: DataSource,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
dataset_id: Optional[str] = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
dataset_id: str | None = None,
|
||||
) -> Dataset:
|
||||
if isinstance(source, dict):
|
||||
if source["type"] == "uri":
|
||||
|
@ -489,9 +489,9 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
|||
scoring_fn_id: str,
|
||||
description: str,
|
||||
return_type: ParamType,
|
||||
provider_scoring_fn_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
params: Optional[ScoringFnParams] = None,
|
||||
provider_scoring_fn_id: str | None = None,
|
||||
provider_id: str | None = None,
|
||||
params: ScoringFnParams | None = None,
|
||||
) -> None:
|
||||
if provider_scoring_fn_id is None:
|
||||
provider_scoring_fn_id = scoring_fn_id
|
||||
|
@ -528,10 +528,10 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
|||
self,
|
||||
benchmark_id: str,
|
||||
dataset_id: str,
|
||||
scoring_functions: List[str],
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
provider_benchmark_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
scoring_functions: list[str],
|
||||
metadata: dict[str, Any] | None = None,
|
||||
provider_benchmark_id: str | None = None,
|
||||
provider_id: str | None = None,
|
||||
) -> None:
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
|
@ -556,7 +556,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
|||
|
||||
|
||||
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||
async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
|
||||
tools = await self.get_all_with_type("tool")
|
||||
if toolgroup_id:
|
||||
tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
|
||||
|
@ -578,8 +578,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
self,
|
||||
toolgroup_id: str,
|
||||
provider_id: str,
|
||||
mcp_endpoint: Optional[URL] = None,
|
||||
args: Optional[Dict[str, Any]] = None,
|
||||
mcp_endpoint: URL | None = None,
|
||||
args: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
tools = []
|
||||
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
|
||||
|
|
|
@ -7,7 +7,6 @@
|
|||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import httpx
|
||||
|
@ -22,7 +21,7 @@ logger = get_logger(name=__name__, category="auth")
|
|||
class AuthResponse(BaseModel):
|
||||
"""The format of the authentication response from the auth endpoint."""
|
||||
|
||||
access_attributes: Optional[AccessAttributes] = Field(
|
||||
access_attributes: AccessAttributes | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Structured user attributes for attribute-based access control.
|
||||
|
@ -44,7 +43,7 @@ class AuthResponse(BaseModel):
|
|||
""",
|
||||
)
|
||||
|
||||
message: Optional[str] = Field(
|
||||
message: str | None = Field(
|
||||
default=None, description="Optional message providing additional context about the authentication result."
|
||||
)
|
||||
|
||||
|
@ -52,9 +51,9 @@ class AuthResponse(BaseModel):
|
|||
class AuthRequestContext(BaseModel):
|
||||
path: str = Field(description="The path of the request being authenticated")
|
||||
|
||||
headers: Dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
|
||||
headers: dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
|
||||
|
||||
params: Dict[str, List[str]] = Field(
|
||||
params: dict[str, list[str]] = Field(
|
||||
description="Query parameters from the original request, parsed as dictionary of lists"
|
||||
)
|
||||
|
||||
|
@ -76,14 +75,14 @@ class AuthProviderConfig(BaseModel):
|
|||
"""Base configuration for authentication providers."""
|
||||
|
||||
provider_type: AuthProviderType = Field(..., description="Type of authentication provider")
|
||||
config: Dict[str, str] = Field(..., description="Provider-specific configuration")
|
||||
config: dict[str, str] = Field(..., description="Provider-specific configuration")
|
||||
|
||||
|
||||
class AuthProvider(ABC):
|
||||
"""Abstract base class for authentication providers."""
|
||||
|
||||
@abstractmethod
|
||||
async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]:
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
|
||||
"""Validate a token and return access attributes."""
|
||||
pass
|
||||
|
||||
|
@ -96,7 +95,7 @@ class AuthProvider(ABC):
|
|||
class KubernetesAuthProvider(AuthProvider):
|
||||
"""Kubernetes authentication provider that validates tokens against the Kubernetes API server."""
|
||||
|
||||
def __init__(self, config: Dict[str, str]):
|
||||
def __init__(self, config: dict[str, str]):
|
||||
self.api_server_url = config["api_server_url"]
|
||||
self.ca_cert_path = config.get("ca_cert_path")
|
||||
self._client = None
|
||||
|
@ -120,7 +119,7 @@ class KubernetesAuthProvider(AuthProvider):
|
|||
self._client = ApiClient(configuration)
|
||||
return self._client
|
||||
|
||||
async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]:
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
|
||||
"""Validate a Kubernetes token and return access attributes."""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
|
@ -166,11 +165,11 @@ class KubernetesAuthProvider(AuthProvider):
|
|||
class CustomAuthProvider(AuthProvider):
|
||||
"""Custom authentication provider that uses an external endpoint."""
|
||||
|
||||
def __init__(self, config: Dict[str, str]):
|
||||
def __init__(self, config: dict[str, str]):
|
||||
self.endpoint = config["endpoint"]
|
||||
self._client = None
|
||||
|
||||
async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]:
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None:
|
||||
"""Validate a token using the custom authentication endpoint."""
|
||||
if not self.endpoint:
|
||||
raise ValueError("Authentication endpoint not configured")
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
|
||||
import inspect
|
||||
import re
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -29,7 +28,7 @@ def toolgroup_protocol_map():
|
|||
}
|
||||
|
||||
|
||||
def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
||||
def get_all_api_endpoints() -> dict[Api, list[ApiEndpoint]]:
|
||||
apis = {}
|
||||
|
||||
protocols = api_protocol_map()
|
||||
|
|
|
@ -15,7 +15,7 @@ import warnings
|
|||
from contextlib import asynccontextmanager
|
||||
from importlib.metadata import version as parse_version
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional, Union
|
||||
from typing import Annotated, Any
|
||||
|
||||
import yaml
|
||||
from fastapi import Body, FastAPI, HTTPException, Request
|
||||
|
@ -24,7 +24,6 @@ from fastapi.exceptions import RequestValidationError
|
|||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from openai import BadRequestError
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
|
@ -91,7 +90,7 @@ async def global_exception_handler(request: Request, exc: Exception):
|
|||
return JSONResponse(status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}})
|
||||
|
||||
|
||||
def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidationError]:
|
||||
def translate_exception(exc: Exception) -> HTTPException | RequestValidationError:
|
||||
if isinstance(exc, ValidationError):
|
||||
exc = RequestValidationError(exc.errors())
|
||||
|
||||
|
@ -315,7 +314,7 @@ class ClientVersionMiddleware:
|
|||
return await self.app(scope, receive, send)
|
||||
|
||||
|
||||
def main(args: Optional[argparse.Namespace] = None):
|
||||
def main(args: argparse.Namespace | None = None):
|
||||
"""Start the LlamaStack server."""
|
||||
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
|
||||
parser.add_argument(
|
||||
|
@ -385,7 +384,7 @@ def main(args: Optional[argparse.Namespace] = None):
|
|||
raise ValueError("Either --yaml-config or --template must be provided")
|
||||
|
||||
logger_config = None
|
||||
with open(config_file, "r") as fp:
|
||||
with open(config_file) as fp:
|
||||
config_contents = yaml.safe_load(fp)
|
||||
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
|
||||
logger_config = LoggingConfig(**cfg)
|
||||
|
@ -517,7 +516,7 @@ def main(args: Optional[argparse.Namespace] = None):
|
|||
uvicorn.run(**uvicorn_config)
|
||||
|
||||
|
||||
def extract_path_params(route: str) -> List[str]:
|
||||
def extract_path_params(route: str) -> list[str]:
|
||||
segments = route.split("/")
|
||||
params = [seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")]
|
||||
# to handle path params like {param:path}
|
||||
|
|
|
@ -8,7 +8,7 @@ import importlib.resources
|
|||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
|
@ -90,7 +90,7 @@ RESOURCES = [
|
|||
]
|
||||
|
||||
|
||||
async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
|
||||
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
||||
for rsrc, api, register_method, list_method in RESOURCES:
|
||||
objects = getattr(run_config, rsrc)
|
||||
if api not in impls:
|
||||
|
@ -197,7 +197,7 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
|
|||
) from e
|
||||
|
||||
|
||||
def add_internal_implementations(impls: Dict[Api, Any], run_config: StackRunConfig) -> None:
|
||||
def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConfig) -> None:
|
||||
"""Add internal implementations (inspect and providers) to the implementations dictionary.
|
||||
|
||||
Args:
|
||||
|
@ -220,8 +220,8 @@ def add_internal_implementations(impls: Dict[Api, Any], run_config: StackRunConf
|
|||
# Produces a stack of providers for the given run config. Not all APIs may be
|
||||
# asked for in the run config.
|
||||
async def construct_stack(
|
||||
run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
|
||||
) -> Dict[Api, Any]:
|
||||
run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
|
||||
) -> dict[Api, Any]:
|
||||
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
|
||||
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
|
||||
|
||||
|
@ -244,7 +244,7 @@ def get_stack_run_config_from_template(template: str) -> StackRunConfig:
|
|||
|
||||
|
||||
def run_config_from_adhoc_config_spec(
|
||||
adhoc_config_spec: str, provider_registry: Optional[ProviderRegistry] = None
|
||||
adhoc_config_spec: str, provider_registry: ProviderRegistry | None = None
|
||||
) -> StackRunConfig:
|
||||
"""
|
||||
Create an adhoc distribution from a list of API providers.
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Dict, List, Optional, Protocol, Tuple
|
||||
from typing import Protocol
|
||||
|
||||
import pydantic
|
||||
|
||||
|
@ -20,13 +20,13 @@ logger = get_logger(__name__, category="core")
|
|||
|
||||
|
||||
class DistributionRegistry(Protocol):
|
||||
async def get_all(self) -> List[RoutableObjectWithProvider]: ...
|
||||
async def get_all(self) -> list[RoutableObjectWithProvider]: ...
|
||||
|
||||
async def initialize(self) -> None: ...
|
||||
|
||||
async def get(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ...
|
||||
async def get(self, identifier: str) -> RoutableObjectWithProvider | None: ...
|
||||
|
||||
def get_cached(self, identifier: str) -> Optional[RoutableObjectWithProvider]: ...
|
||||
def get_cached(self, identifier: str) -> RoutableObjectWithProvider | None: ...
|
||||
|
||||
async def update(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: ...
|
||||
|
||||
|
@ -40,13 +40,13 @@ KEY_VERSION = "v8"
|
|||
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
|
||||
|
||||
|
||||
def _get_registry_key_range() -> Tuple[str, str]:
|
||||
def _get_registry_key_range() -> tuple[str, str]:
|
||||
"""Returns the start and end keys for the registry range query."""
|
||||
start_key = f"{REGISTER_PREFIX}:{KEY_VERSION}"
|
||||
return start_key, f"{start_key}\xff"
|
||||
|
||||
|
||||
def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider]:
|
||||
def _parse_registry_values(values: list[str]) -> list[RoutableObjectWithProvider]:
|
||||
"""Utility function to parse registry values into RoutableObjectWithProvider objects."""
|
||||
all_objects = []
|
||||
for value in values:
|
||||
|
@ -67,16 +67,16 @@ class DiskDistributionRegistry(DistributionRegistry):
|
|||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
def get_cached(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
|
||||
def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
|
||||
# Disk registry does not have a cache
|
||||
raise NotImplementedError("Disk registry does not have a cache")
|
||||
|
||||
async def get_all(self) -> List[RoutableObjectWithProvider]:
|
||||
async def get_all(self) -> list[RoutableObjectWithProvider]:
|
||||
start_key, end_key = _get_registry_key_range()
|
||||
values = await self.kvstore.range(start_key, end_key)
|
||||
return _parse_registry_values(values)
|
||||
|
||||
async def get(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
|
||||
async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
|
||||
json_str = await self.kvstore.get(KEY_FORMAT.format(type=type, identifier=identifier))
|
||||
if not json_str:
|
||||
return None
|
||||
|
@ -113,7 +113,7 @@ class DiskDistributionRegistry(DistributionRegistry):
|
|||
class CachedDiskDistributionRegistry(DiskDistributionRegistry):
|
||||
def __init__(self, kvstore: KVStore):
|
||||
super().__init__(kvstore)
|
||||
self.cache: Dict[Tuple[str, str], RoutableObjectWithProvider] = {}
|
||||
self.cache: dict[tuple[str, str], RoutableObjectWithProvider] = {}
|
||||
self._initialized = False
|
||||
self._initialize_lock = asyncio.Lock()
|
||||
self._cache_lock = asyncio.Lock()
|
||||
|
@ -147,15 +147,15 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
|
|||
async def initialize(self) -> None:
|
||||
await self._ensure_initialized()
|
||||
|
||||
def get_cached(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
|
||||
def get_cached(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
|
||||
return self.cache.get((type, identifier), None)
|
||||
|
||||
async def get_all(self) -> List[RoutableObjectWithProvider]:
|
||||
async def get_all(self) -> list[RoutableObjectWithProvider]:
|
||||
await self._ensure_initialized()
|
||||
async with self._locked_cache() as cache:
|
||||
return list(cache.values())
|
||||
|
||||
async def get(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
|
||||
async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
|
||||
await self._ensure_initialized()
|
||||
cache_key = (type, identifier)
|
||||
|
||||
|
@ -189,7 +189,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
|
|||
|
||||
|
||||
async def create_dist_registry(
|
||||
metadata_store: Optional[KVStoreConfig],
|
||||
metadata_store: KVStoreConfig | None,
|
||||
image_name: str,
|
||||
) -> tuple[CachedDiskDistributionRegistry, KVStore]:
|
||||
# instantiate kvstore for storing and retrieving distribution metadata
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from llama_stack_client import LlamaStackClient
|
||||
|
||||
|
@ -23,7 +22,7 @@ class LlamaStackApi:
|
|||
},
|
||||
)
|
||||
|
||||
def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: Optional[dict]):
|
||||
def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: dict | None):
|
||||
"""Run scoring on a single row"""
|
||||
if not scoring_params:
|
||||
scoring_params = {fn_id: None for fn_id in scoring_function_ids}
|
||||
|
|
|
@ -4,10 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
|
||||
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def redact_sensitive_fields(data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Redact sensitive information from config before printing."""
|
||||
sensitive_patterns = ["api_key", "api_token", "password", "secret"]
|
||||
|
||||
|
@ -18,7 +18,7 @@ def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
|
|||
return [_redact_value(i) for i in v]
|
||||
return v
|
||||
|
||||
def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def _redact_dict(d: dict[str, Any]) -> dict[str, Any]:
|
||||
result = {}
|
||||
for k, v in d.items():
|
||||
if any(pattern in k.lower() for pattern in sensitive_patterns):
|
||||
|
|
|
@ -4,14 +4,15 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextvars import ContextVar
|
||||
from typing import AsyncGenerator, List, TypeVar
|
||||
from typing import TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def preserve_contexts_async_generator(
|
||||
gen: AsyncGenerator[T, None], context_vars: List[ContextVar]
|
||||
gen: AsyncGenerator[T, None], context_vars: list[ContextVar]
|
||||
) -> AsyncGenerator[T, None]:
|
||||
"""
|
||||
Wraps an async generator to preserve context variables across iterations.
|
||||
|
|
|
@ -8,12 +8,11 @@ import inspect
|
|||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, List, Literal, Optional, Type, Union, get_args, get_origin
|
||||
from typing import Annotated, Any, Literal, Union, get_args, get_origin
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefinedType
|
||||
from typing_extensions import Annotated
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
@ -21,7 +20,7 @@ log = logging.getLogger(__name__)
|
|||
def is_list_of_primitives(field_type):
|
||||
"""Check if a field type is a List of primitive types."""
|
||||
origin = get_origin(field_type)
|
||||
if origin is List or origin is list:
|
||||
if origin is list or origin is list:
|
||||
args = get_args(field_type)
|
||||
if len(args) == 1 and args[0] in (int, float, str, bool):
|
||||
return True
|
||||
|
@ -53,7 +52,7 @@ def get_non_none_type(field_type):
|
|||
return next(arg for arg in get_args(field_type) if arg is not type(None))
|
||||
|
||||
|
||||
def manually_validate_field(model: Type[BaseModel], field_name: str, value: Any):
|
||||
def manually_validate_field(model: type[BaseModel], field_name: str, value: Any):
|
||||
validators = model.__pydantic_decorators__.field_validators
|
||||
for _name, validator in validators.items():
|
||||
if field_name in validator.info.fields:
|
||||
|
@ -126,7 +125,7 @@ def prompt_for_discriminated_union(
|
|||
#
|
||||
# doesn't support List[nested_class] yet or Dicts of any kind. needs a bunch of
|
||||
# unit tests for coverage.
|
||||
def prompt_for_config(config_type: type[BaseModel], existing_config: Optional[BaseModel] = None) -> BaseModel:
|
||||
def prompt_for_config(config_type: type[BaseModel], existing_config: BaseModel | None = None) -> BaseModel:
|
||||
"""
|
||||
Recursively prompt the user for configuration values based on a Pydantic BaseModel.
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue