diff --git a/docs/source/apis/external.md b/docs/source/apis/external.md new file mode 100644 index 000000000..a1fe70bdd --- /dev/null +++ b/docs/source/apis/external.md @@ -0,0 +1,392 @@ +# External APIs + +Llama Stack supports external APIs that live outside of the main codebase. This allows you to: +- Create and maintain your own APIs independently +- Share APIs with others without contributing to the main codebase +- Keep API-specific code separate from the core Llama Stack code + +## Configuration + +To enable external APIs, you need to configure the `external_apis_dir` in your Llama Stack configuration. This directory should contain your external API specifications: + +```yaml +external_apis_dir: ~/.llama/apis.d/ +``` + +## Directory Structure + +The external APIs directory should follow this structure: + +``` +apis.d/ + custom_api1.yaml + custom_api2.yaml +``` + +Each YAML file in these directories defines an API specification. + +## API Specification + +Here's an example of an external API specification for a weather API: + +```yaml +module: weather +api_dependencies: + - inference +protocol: WeatherAPI +name: weather +pip_packages: + - llama-stack-api-weather +``` + +### API Specification Fields + +- `module`: Python module containing the API implementation +- `protocol`: Name of the protocol class for the API +- `name`: Name of the API +- `pip_packages`: List of pip packages to install the API, typically a single package + +## Required Implementation + +External APIs must expose a `available_providers()` function in their module that returns a list of provider names: + +```python +# llama_stack_api_weather/api.py +from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec + + +def available_providers() -> list[ProviderSpec]: + return [ + InlineProviderSpec( + api=Api.weather, + provider_type="inline::darksky", + pip_packages=[], + module="llama_stack_provider_darksky", + config_class="llama_stack_provider_darksky.DarkSkyWeatherImplConfig", + ), + ] +``` + +A Protocol class like so: + +```python +# llama_stack_api_weather/api.py +from typing import Protocol + +from llama_stack.schema_utils import webmethod + + +class WeatherAPI(Protocol): + """ + A protocol for the Weather API. + """ + + @webmethod(route="/locations", method="GET") + async def get_available_locations() -> dict[str, list[str]]: + """ + Get the available locations. + """ + ... +``` + +## Example: Custom API + +Here's a complete example of creating and using a custom API: + +1. First, create the API package: + +```bash +mkdir -p llama-stack-api-weather +cd llama-stack-api-weather +mkdir src/llama_stack_api_weather +git init +uv init +``` + +2. Edit `pyproject.toml`: + +```toml +[project] +name = "llama-stack-api-weather" +version = "0.1.0" +description = "Weather API for Llama Stack" +readme = "README.md" +requires-python = ">=3.10" +dependencies = ["llama-stack", "pydantic"] + +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["src"] +include = ["llama_stack_api_weather", "llama_stack_api_weather.*"] +``` + +3. Create the initial files: + +```bash +touch src/llama_stack_api_weather/__init__.py +touch src/llama_stack_api_weather/api.py +``` + +```python +# llama-stack-api-weather/src/llama_stack_api_weather/__init__.py +"""Weather API for Llama Stack.""" + +from .api import WeatherAPI, available_providers + +__all__ = ["WeatherAPI", "available_providers"] +``` + +4. Create the API implementation: + +```python +# llama-stack-api-weather/src/llama_stack_api_weather/weather.py +from typing import Protocol + +from llama_stack.providers.datatypes import ( + AdapterSpec, + Api, + ProviderSpec, + RemoteProviderSpec, +) +from llama_stack.schema_utils import webmethod + + +def available_providers() -> list[ProviderSpec]: + return [ + RemoteProviderSpec( + api=Api.weather, + provider_type="remote::kaze", + config_class="llama_stack_provider_kaze.KazeProviderConfig", + adapter=AdapterSpec( + adapter_type="kaze", + module="llama_stack_provider_kaze", + pip_packages=["llama_stack_provider_kaze"], + config_class="llama_stack_provider_kaze.KazeProviderConfig", + ), + ), + ] + + +class WeatherProvider(Protocol): + """ + A protocol for the Weather API. + """ + + @webmethod(route="/weather/locations", method="GET") + async def get_available_locations() -> dict[str, list[str]]: + """ + Get the available locations. + """ + ... +``` + +5. Create the API specification: + +```yaml +# ~/.llama/apis.d/weather.yaml +module: llama_stack_api_weather +name: weather +pip_packages: ["llama-stack-api-weather"] +protocol: WeatherProvider + +``` + +6. Install the API package: + +```bash +uv pip install -e . +``` + +7. Configure Llama Stack to use external APIs: + +```yaml +version: "2" +image_name: "llama-stack-api-weather" +apis: + - weather +providers: {} +external_apis_dir: ~/.llama/apis.d +``` + +The API will now be available at `/v1/weather/locations`. + +## Example: custom provider for the weather API + +1. Create the provider package: + +```bash +mkdir -p llama-stack-provider-kaze +cd llama-stack-provider-kaze +uv init +``` + +2. Edit `pyproject.toml`: + +```toml +[project] +name = "llama-stack-provider-kaze" +version = "0.1.0" +description = "Kaze weather provider for Llama Stack" +readme = "README.md" +requires-python = ">=3.10" +dependencies = ["llama-stack", "pydantic", "aiohttp"] + +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["src"] +include = ["llama_stack_provider_kaze", "llama_stack_provider_kaze.*"] +``` + +3. Create the initial files: + +```bash +touch src/llama_stack_provider_kaze/__init__.py +touch src/llama_stack_provider_kaze/kaze.py +``` + +4. Create the provider implementation: + + +Initialization function: + +```python +# llama-stack-provider-kaze/src/llama_stack_provider_kaze/__init__.py +"""Kaze weather provider for Llama Stack.""" + +from .config import KazeProviderConfig +from .kaze import WeatherKazeAdapter + +__all__ = ["KazeProviderConfig", "WeatherKazeAdapter"] + + +async def get_adapter_impl(config: KazeProviderConfig, _deps): + from .kaze import WeatherKazeAdapter + + impl = WeatherKazeAdapter(config) + await impl.initialize() + return impl +``` + +Configuration: + +```python +# llama-stack-provider-kaze/src/llama_stack_provider_kaze/config.py +from pydantic import BaseModel, Field + + +class KazeProviderConfig(BaseModel): + """Configuration for the Kaze weather provider.""" + + base_url: str = Field( + "https://api.kaze.io/v1", + description="Base URL for the Kaze weather API", + ) +``` + +Main implementation: + +```python +# llama-stack-provider-kaze/src/llama_stack_provider_kaze/kaze.py +from llama_stack_api_weather.api import WeatherProvider + +from .config import KazeProviderConfig + + +class WeatherKazeAdapter(WeatherProvider): + """Kaze weather provider implementation.""" + + def __init__( + self, + config: KazeProviderConfig, + ) -> None: + self.config = config + + async def initialize(self) -> None: + pass + + async def get_available_locations(self) -> dict[str, list[str]]: + """Get available weather locations.""" + return {"locations": ["Paris", "Tokyo"]} +``` + +5. Create the provider specification: + +```yaml +# ~/.llama/providers.d/remote/weather/kaze.yaml +adapter: + adapter_type: kaze + pip_packages: ["llama_stack_provider_kaze"] + config_class: llama_stack_provider_kaze.config.KazeProviderConfig + module: llama_stack_provider_kaze +optional_api_dependencies: [] +``` + +6. Install the provider package: + +```bash +uv pip install -e . +``` + +7. Configure Llama Stack to use the provider: + +```yaml +# ~/.llama/run-byoa.yaml +version: "2" +image_name: "llama-stack-api-weather" +apis: + - weather +providers: + weather: + - provider_id: kaze + provider_type: remote::kaze + config: {} +external_apis_dir: ~/.llama/apis.d +external_providers_dir: ~/.llama/providers.d +server: + port: 8321 +``` + +8. Run the server: + +```bash +python -m llama_stack.distribution.server.server --yaml-config ~/.llama/run-byoa.yaml +``` + +9. Test the API: + +```bash +curl -s http://127.0.0.1:8321/v1/weather/locations +{"locations":["Paris","Tokyo"]}% +``` + +## Best Practices + +1. **Package Naming**: Use a clear and descriptive name for your API package. + +2. **Version Management**: Keep your API package versioned and compatible with the Llama Stack version you're using. + +3. **Dependencies**: Only include the minimum required dependencies in your API package. + +4. **Documentation**: Include clear documentation in your API package about: + - Installation requirements + - Configuration options + - API endpoints and usage + - Any limitations or known issues + +5. **Testing**: Include tests in your API package to ensure it works correctly with Llama Stack. + +## Troubleshooting + +If your external API isn't being loaded: + +1. Check that the `external_apis_dir` path is correct and accessible. +2. Verify that the YAML files are properly formatted. +3. Ensure all required Python packages are installed. +4. Check the Llama Stack server logs for any error messages - turn on debug logging to get more information using `LLAMA_STACK_LOGGING=all=debug`. +5. Verify that the API package is installed in your Python environment. diff --git a/llama_stack/apis/datatypes.py b/llama_stack/apis/datatypes.py index 63a764725..6c4ebb449 100644 --- a/llama_stack/apis/datatypes.py +++ b/llama_stack/apis/datatypes.py @@ -4,15 +4,83 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from enum import Enum +from enum import Enum, EnumMeta -from pydantic import BaseModel +from pydantic import BaseModel, Field from llama_stack.schema_utils import json_schema_type +class DynamicApiMeta(EnumMeta): + def __new__(cls, name, bases, namespace): + # Store the original enum values + original_values = {k: v for k, v in namespace.items() if not k.startswith("_")} + + # Create the enum class + cls = super().__new__(cls, name, bases, namespace) + + # Store the original values for reference + cls._original_values = original_values + # Initialize _dynamic_values + cls._dynamic_values = {} + + return cls + + def __call__(cls, value): + try: + return super().__call__(value) + except ValueError as e: + # If the value doesn't exist, create a new enum member + # Create a new member name from the value + member_name = value.lower().replace("-", "_") + + # If this value was already dynamically added, return it + if value in cls._dynamic_values: + return cls._dynamic_values[value] + + # If this member name already exists in the enum, return the existing member + if member_name in cls._member_map_: + return cls._member_map_[member_name] + + # Instead of creating a new member, raise ValueError to force users to use Api.add() to + # register new APIs explicitly + raise ValueError(f"API '{value}' does not exist. Use Api.add() to register new APIs.") from e + + def __iter__(cls): + # Allow iteration over both static and dynamic members + yield from super().__iter__() + if hasattr(cls, "_dynamic_values"): + yield from cls._dynamic_values.values() + + def add(cls, value): + """ + Add a new API to the enum. + Particulary useful for external APIs. + """ + member_name = value.lower().replace("-", "_") + + # If this member name already exists in the enum, return it + if member_name in cls._member_map_: + return cls._member_map_[member_name] + + # Create a new enum member + member = object.__new__(cls) + member._name_ = member_name + member._value_ = value + + # Add it to the enum class + cls._member_map_[member_name] = member + cls._member_names_.append(member_name) + cls._member_type_ = str + + # Store it in our dynamic values + cls._dynamic_values[value] = member + + return member + + @json_schema_type -class Api(Enum): +class Api(Enum, metaclass=DynamicApiMeta): providers = "providers" inference = "inference" safety = "safety" @@ -54,3 +122,12 @@ class Error(BaseModel): title: str detail: str instance: str | None = None + + +class ExternalApiSpec(BaseModel): + """Specification for an external API implementation.""" + + module: str = Field(..., description="Python module containing the API implementation") + name: str = Field(..., description="Name of the API") + pip_packages: list[str] = Field(default=[], description="List of pip packages to install the API") + protocol: str = Field(..., description="Name of the protocol class for the API") diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index abc3f0065..5b4eb9733 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -289,6 +289,11 @@ a default SQLite store will be used.""", description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.", ) + external_apis_dir: Path | None = Field( + default=None, + description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.", + ) + @field_validator("external_providers_dir") @classmethod def validate_external_providers_dir(cls, v): diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py index e37b2c443..1280b1d42 100644 --- a/llama_stack/distribution/distribution.py +++ b/llama_stack/distribution/distribution.py @@ -12,6 +12,7 @@ from typing import Any import yaml from pydantic import BaseModel +from llama_stack.distribution.external import load_external_apis from llama_stack.log import get_logger from llama_stack.providers.datatypes import ( AdapterSpec, @@ -133,16 +134,29 @@ def get_provider_registry( ValueError: If any provider spec is invalid """ - ret: dict[Api, dict[str, ProviderSpec]] = {} + registry: dict[Api, dict[str, ProviderSpec]] = {} for api in providable_apis(): name = api.name.lower() logger.debug(f"Importing module {name}") try: module = importlib.import_module(f"llama_stack.providers.registry.{name}") - ret[api] = {a.provider_type: a for a in module.available_providers()} + registry[api] = {a.provider_type: a for a in module.available_providers()} except ImportError as e: logger.warning(f"Failed to import module {name}: {e}") + # Refresh providable APIs with external APIs if any + external_apis = load_external_apis(config) + for api, api_spec in external_apis.items(): + name = api_spec.name.lower() + logger.info(f"Importing external API {name} module {api_spec.module}") + try: + module = importlib.import_module(api_spec.module) + registry[api] = {a.provider_type: a for a in module.available_providers()} + except ImportError as e: + raise ImportError( + f"Failed to import external API module {name}. Is the external API package installed? {e}" + ) from e + # Check if config has the external_providers_dir attribute if config and hasattr(config, "external_providers_dir") and config.external_providers_dir: external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir)) @@ -175,11 +189,9 @@ def get_provider_registry( else: spec = _load_inline_provider_spec(spec_data, api, provider_name) provider_type_key = f"inline::{provider_name}" - - logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}") - if provider_type_key in ret[api]: + if provider_type_key in registry[api]: logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}") - ret[api][provider_type_key] = spec + registry[api][provider_type_key] = spec logger.info(f"Successfully loaded external provider {provider_type_key}") except yaml.YAMLError as yaml_err: logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}") @@ -187,4 +199,4 @@ def get_provider_registry( except Exception as e: logger.error(f"Failed to load provider spec from {spec_path}: {e}") raise e - return ret + return registry diff --git a/llama_stack/distribution/external.py b/llama_stack/distribution/external.py new file mode 100644 index 000000000..d59a01d33 --- /dev/null +++ b/llama_stack/distribution/external.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +import yaml + +from llama_stack.apis.datatypes import Api, ExternalApiSpec +from llama_stack.log import get_logger + +logger = get_logger(name=__name__, category="core") + + +def load_external_apis(config=None) -> dict[Api, ExternalApiSpec]: + """Load external API specifications from the configured directory. + + Args: + config: StackRunConfig containing the external APIs directory path + + Returns: + A dictionary mapping API names to their specifications + """ + if not config: + return {} + + if not hasattr(config, "external_apis_dir"): + return {} + + if not config.external_apis_dir: + return {} + + external_apis_dir = config.external_apis_dir.expanduser().resolve() + if not external_apis_dir.is_dir(): + logger.error(f"External APIs directory is not a directory: {external_apis_dir}") + return {} + + logger.info(f"Loading external APIs from {external_apis_dir}") + external_apis: dict[Api, ExternalApiSpec] = {} + + # Look for YAML files in the external APIs directory + for yaml_path in external_apis_dir.glob("*.yaml"): + try: + with open(yaml_path) as f: + spec_data = yaml.safe_load(f) + + spec = ExternalApiSpec(**spec_data) + api = Api.add(spec.name) + logger.info(f"Loaded external API spec for {spec.name} from {yaml_path}") + external_apis[api] = spec + except yaml.YAMLError as yaml_err: + logger.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}") + raise yaml_err + except Exception as e: + logger.error(f"Failed to load external API spec from {yaml_path}: {e}") + raise e + + return external_apis diff --git a/llama_stack/distribution/inspect.py b/llama_stack/distribution/inspect.py index 5822070ad..7f7ab06ab 100644 --- a/llama_stack/distribution/inspect.py +++ b/llama_stack/distribution/inspect.py @@ -16,6 +16,7 @@ from llama_stack.apis.inspect import ( VersionInfo, ) from llama_stack.distribution.datatypes import StackRunConfig +from llama_stack.distribution.external import load_external_apis from llama_stack.distribution.server.routes import get_all_api_routes from llama_stack.providers.datatypes import HealthStatus @@ -42,7 +43,8 @@ class DistributionInspectImpl(Inspect): run_config: StackRunConfig = self.config.run_config ret = [] - all_endpoints = get_all_api_routes() + external_apis = load_external_apis(run_config) + all_endpoints = get_all_api_routes(external_apis) for api, endpoints in all_endpoints.items(): # Always include provider and inspect APIs, filter others based on run config if api.value in ["providers", "inspect"]: diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 3726bb3a5..c2a0b9fae 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -11,6 +11,7 @@ from llama_stack.apis.agents import Agents from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets +from llama_stack.apis.datatypes import ExternalApiSpec from llama_stack.apis.eval import Eval from llama_stack.apis.files import Files from llama_stack.apis.inference import Inference, InferenceProvider @@ -35,6 +36,7 @@ from llama_stack.distribution.datatypes import ( StackRunConfig, ) from llama_stack.distribution.distribution import builtin_automatically_routed_apis +from llama_stack.distribution.external import load_external_apis from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.log import get_logger @@ -59,8 +61,16 @@ class InvalidProviderError(Exception): pass -def api_protocol_map() -> dict[Api, Any]: - return { +def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) -> dict[Api, Any]: + """Get a mapping of API types to their protocol classes. + + Args: + external_apis: Optional dictionary of external API specifications + + Returns: + Dictionary mapping API types to their protocol classes + """ + protocols = { Api.providers: ProvidersAPI, Api.agents: Agents, Api.inference: Inference, @@ -83,10 +93,23 @@ def api_protocol_map() -> dict[Api, Any]: Api.files: Files, } + if external_apis: + for api, api_spec in external_apis.items(): + try: + module = importlib.import_module(api_spec.module) + api_class = getattr(module, api_spec.protocol) -def api_protocol_map_for_compliance_check() -> dict[Api, Any]: + protocols[api] = api_class + except (ImportError, AttributeError) as e: + logger.warning(f"Failed to load external API {api_spec.name}: {e}") + + return protocols + + +def api_protocol_map_for_compliance_check(config: Any) -> dict[Api, Any]: + external_apis = load_external_apis(config) return { - **api_protocol_map(), + **api_protocol_map(external_apis), Api.inference: InferenceProvider, } @@ -250,7 +273,7 @@ async def instantiate_providers( dist_registry: DistributionRegistry, run_config: StackRunConfig, policy: list[AccessRule], -) -> dict: +) -> dict[Api, Any]: """Instantiates providers asynchronously while managing dependencies.""" impls: dict[Api, Any] = {} inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis} @@ -356,7 +379,7 @@ async def instantiate_provider( impl.__provider_spec__ = provider_spec impl.__provider_config__ = config - protocols = api_protocol_map_for_compliance_check() + protocols = api_protocol_map_for_compliance_check(run_config) additional_protocols = additional_protocols_map() # TODO: check compliance for special tool groups # the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol diff --git a/llama_stack/distribution/server/routes.py b/llama_stack/distribution/server/routes.py index ea66fec5a..682ef56c6 100644 --- a/llama_stack/distribution/server/routes.py +++ b/llama_stack/distribution/server/routes.py @@ -12,10 +12,9 @@ from typing import Any from aiohttp import hdrs from starlette.routing import Route +from llama_stack.apis.datatypes import Api, ExternalApiSpec from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup from llama_stack.apis.version import LLAMA_STACK_API_VERSION -from llama_stack.distribution.resolver import api_protocol_map -from llama_stack.providers.datatypes import Api EndpointFunc = Callable[..., Any] PathParams = dict[str, str] @@ -31,10 +30,13 @@ def toolgroup_protocol_map(): } -def get_all_api_routes() -> dict[Api, list[Route]]: +def get_all_api_routes(external_apis: dict[Api, ExternalApiSpec] | None = None) -> dict[Api, list[Route]]: apis = {} - protocols = api_protocol_map() + # Lazy import to avoid circular dependency + from llama_stack.distribution.resolver import api_protocol_map + + protocols = api_protocol_map(external_apis) toolgroup_protocols = toolgroup_protocol_map() for api, protocol in protocols.items(): routes = [] @@ -73,8 +75,8 @@ def get_all_api_routes() -> dict[Api, list[Route]]: return apis -def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls: - routes = get_all_api_routes() +def initialize_route_impls(impls, external_apis: dict[Api, ExternalApiSpec] | None = None) -> RouteImpls: + routes = get_all_api_routes(external_apis) route_impls: RouteImpls = {} def _convert_path_to_regex(path: str) -> str: diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 83407a25f..9c4eb0e65 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -33,6 +33,7 @@ from pydantic import BaseModel, ValidationError from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig from llama_stack.distribution.distribution import builtin_automatically_routed_apis +from llama_stack.distribution.external import ExternalApiSpec, load_external_apis from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.server.routes import ( @@ -270,9 +271,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable: class TracingMiddleware: - def __init__(self, app, impls): + def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]): self.app = app self.impls = impls + self.external_apis = external_apis # FastAPI built-in paths that should bypass custom routing self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static") @@ -289,7 +291,7 @@ class TracingMiddleware: return await self.app(scope, receive, send) if not hasattr(self, "route_impls"): - self.route_impls = initialize_route_impls(self.impls) + self.route_impls = initialize_route_impls(self.impls, self.external_apis) try: _, _, trace_path = find_matching_route(scope.get("method", hdrs.METH_GET), path, self.route_impls) @@ -493,7 +495,9 @@ def main(args: argparse.Namespace | None = None): else: setup_logger(TelemetryAdapter(TelemetryConfig(), {})) - all_routes = get_all_api_routes() + # Load external APIs if configured + external_apis = load_external_apis(config) + all_routes = get_all_api_routes(external_apis) if config.apis: apis_to_serve = set(config.apis) @@ -512,7 +516,10 @@ def main(args: argparse.Namespace | None = None): api = Api(api_str) routes = all_routes[api] - impl = impls[api] + try: + impl = impls[api] + except KeyError as e: + raise ValueError(f"Could not find provider implementation for {api} API") from e for route in routes: if not hasattr(impl, route.name): @@ -543,7 +550,7 @@ def main(args: argparse.Namespace | None = None): app.exception_handler(Exception)(global_exception_handler) app.__llama_stack_impls__ = impls - app.add_middleware(TracingMiddleware, impls=impls) + app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis) import uvicorn