feat: Bring Your Own API (BYOA)

Signed-off-by: Sébastien Han <seb@redhat.com>
Sébastien Han 2025-05-22 17:40:45 +02:00
parent cfee63bd0d
commit 9443cef577
9 changed files with 607 additions and 28 deletions

@@ -0,0 +1,392 @@
# External APIs
Llama Stack supports external APIs that live outside of the main codebase. This allows you to:
- Create and maintain your own APIs independently
- Share APIs with others without contributing to the main codebase
- Keep API-specific code separate from the core Llama Stack code
## Configuration
To enable external APIs, you need to configure the `external_apis_dir` in your Llama Stack configuration. This directory should contain your external API specifications:
```yaml
external_apis_dir: ~/.llama/apis.d/
```
## Directory Structure
The external APIs directory should follow this structure:
```
apis.d/
  custom_api1.yaml
  custom_api2.yaml
```
Each YAML file in this directory defines an API specification.
## API Specification
Here's an example of an external API specification for a weather API:
```yaml
module: weather
api_dependencies:
- inference
protocol: WeatherAPI
name: weather
pip_packages:
- llama-stack-api-weather
```
### API Specification Fields
- `module`: Python module containing the API implementation
- `protocol`: Name of the protocol class for the API
- `name`: Name of the API
- `pip_packages`: List of pip packages to install the API, typically a single package
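Each specification file is validated against the `ExternalApiSpec` Pydantic model defined in `llama_stack.apis.datatypes`; the fields above map onto it directly:
```python
# Shape of the model each YAML spec is parsed into (from llama_stack.apis.datatypes).
from pydantic import BaseModel, Field


class ExternalApiSpec(BaseModel):
    """Specification for an external API implementation."""

    module: str = Field(..., description="Python module containing the API implementation")
    name: str = Field(..., description="Name of the API")
    pip_packages: list[str] = Field(default=[], description="List of pip packages to install the API")
    protocol: str = Field(..., description="Name of the protocol class for the API")
```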
## Required Implementation
External APIs must expose an `available_providers()` function in their module that returns a list of provider specifications (`ProviderSpec` objects):
```python
# llama_stack_api_weather/api.py
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec


def available_providers() -> list[ProviderSpec]:
    return [
        InlineProviderSpec(
            api=Api.weather,
            provider_type="inline::darksky",
            pip_packages=[],
            module="llama_stack_provider_darksky",
            config_class="llama_stack_provider_darksky.DarkSkyWeatherImplConfig",
        ),
    ]
```
The module must also define the protocol class named in the specification, like so:
```python
# llama_stack_api_weather/api.py
from typing import Protocol

from llama_stack.schema_utils import webmethod


class WeatherAPI(Protocol):
    """
    A protocol for the Weather API.
    """

    @webmethod(route="/locations", method="GET")
    async def get_available_locations(self) -> dict[str, list[str]]:
        """
        Get the available locations.
        """
        ...
```
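When the stack starts, it imports the configured module, looks up the protocol class named by `protocol`, and collects the provider specs the module advertises. A simplified sketch of that loading step (modeled on the resolver and provider-registry changes in this commit):
```python
# Simplified sketch of how an external API spec is resolved into a protocol and providers.
import importlib


def load_external_api(api_spec):
    module = importlib.import_module(api_spec.module)  # e.g. "weather"
    protocol = getattr(module, api_spec.protocol)      # e.g. WeatherAPI
    providers = module.available_providers()           # list[ProviderSpec]
    return protocol, providers
```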
## Example: Custom API
Here's a complete example of creating and using a custom API:
1. First, create the API package:
```bash
mkdir -p llama-stack-api-weather
cd llama-stack-api-weather
mkdir -p src/llama_stack_api_weather
git init
uv init
```
2. Edit `pyproject.toml`:
```toml
[project]
name = "llama-stack-api-weather"
version = "0.1.0"
description = "Weather API for Llama Stack"
readme = "README.md"
requires-python = ">=3.10"
dependencies = ["llama-stack", "pydantic"]

[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
where = ["src"]
include = ["llama_stack_api_weather", "llama_stack_api_weather.*"]
```
3. Create the initial files:
```bash
touch src/llama_stack_api_weather/__init__.py
touch src/llama_stack_api_weather/api.py
```
```python
# llama-stack-api-weather/src/llama_stack_api_weather/__init__.py
"""Weather API for Llama Stack."""

from .api import WeatherProvider, available_providers

__all__ = ["WeatherProvider", "available_providers"]
```
4. Create the API implementation:
```python
# llama-stack-api-weather/src/llama_stack_api_weather/api.py
from typing import Protocol

from llama_stack.providers.datatypes import (
    AdapterSpec,
    Api,
    ProviderSpec,
    RemoteProviderSpec,
)
from llama_stack.schema_utils import webmethod


def available_providers() -> list[ProviderSpec]:
    return [
        RemoteProviderSpec(
            api=Api.weather,
            provider_type="remote::kaze",
            config_class="llama_stack_provider_kaze.KazeProviderConfig",
            adapter=AdapterSpec(
                adapter_type="kaze",
                module="llama_stack_provider_kaze",
                pip_packages=["llama_stack_provider_kaze"],
                config_class="llama_stack_provider_kaze.KazeProviderConfig",
            ),
        ),
    ]


class WeatherProvider(Protocol):
    """
    A protocol for the Weather API.
    """

    @webmethod(route="/weather/locations", method="GET")
    async def get_available_locations(self) -> dict[str, list[str]]:
        """
        Get the available locations.
        """
        ...
```
5. Create the API specification:
```yaml
# ~/.llama/apis.d/weather.yaml
module: llama_stack_api_weather
name: weather
pip_packages: ["llama-stack-api-weather"]
protocol: WeatherProvider
```
6. Install the API package:
```bash
uv pip install -e .
```
7. Configure Llama Stack to use external APIs:
```yaml
version: "2"
image_name: "llama-stack-api-weather"
apis:
- weather
providers: {}
external_apis_dir: ~/.llama/apis.d
```
The API will now be available at `/v1/weather/locations`.
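Once a provider for the API is configured and the server is running (the next example walks through that), the route can be exercised with any HTTP client. A minimal sketch using only the standard library, assuming the default server address `127.0.0.1:8321`:
```python
# Query the external weather API over HTTP (sketch).
import json
import urllib.request

with urllib.request.urlopen("http://127.0.0.1:8321/v1/weather/locations") as resp:
    print(json.loads(resp.read()))  # e.g. {"locations": ["Paris", "Tokyo"]}
```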
## Example: Custom provider for the weather API
1. Create the provider package:
```bash
mkdir -p llama-stack-provider-kaze
cd llama-stack-provider-kaze
mkdir -p src/llama_stack_provider_kaze
uv init
```
2. Edit `pyproject.toml`:
```toml
[project]
name = "llama-stack-provider-kaze"
version = "0.1.0"
description = "Kaze weather provider for Llama Stack"
readme = "README.md"
requires-python = ">=3.10"
dependencies = ["llama-stack", "pydantic", "aiohttp"]

[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
where = ["src"]
include = ["llama_stack_provider_kaze", "llama_stack_provider_kaze.*"]
```
3. Create the initial files:
```bash
touch src/llama_stack_provider_kaze/__init__.py
touch src/llama_stack_provider_kaze/kaze.py
```
4. Create the provider implementation:
Initialization function:
```python
# llama-stack-provider-kaze/src/llama_stack_provider_kaze/__init__.py
"""Kaze weather provider for Llama Stack."""

from .config import KazeProviderConfig
from .kaze import WeatherKazeAdapter

__all__ = ["KazeProviderConfig", "WeatherKazeAdapter"]


async def get_adapter_impl(config: KazeProviderConfig, _deps):
    from .kaze import WeatherKazeAdapter

    impl = WeatherKazeAdapter(config)
    await impl.initialize()
    return impl
```
Configuration:
```python
# llama-stack-provider-kaze/src/llama_stack_provider_kaze/config.py
from pydantic import BaseModel, Field


class KazeProviderConfig(BaseModel):
    """Configuration for the Kaze weather provider."""

    base_url: str = Field(
        "https://api.kaze.io/v1",
        description="Base URL for the Kaze weather API",
    )
```
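The `config:` mapping from the run configuration is loaded into this class, so an empty mapping falls back to the defaults shown above. A small illustration (assuming the package is installed):
```python
# How the provider config resolves its defaults (sketch).
from llama_stack_provider_kaze.config import KazeProviderConfig

cfg = KazeProviderConfig()                                    # config: {}
print(cfg.base_url)                                           # https://api.kaze.io/v1
cfg = KazeProviderConfig(base_url="https://example.test/v1")  # explicit override
print(cfg.base_url)
```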
Main implementation:
```python
# llama-stack-provider-kaze/src/llama_stack_provider_kaze/kaze.py
from llama_stack_api_weather.api import WeatherProvider

from .config import KazeProviderConfig


class WeatherKazeAdapter(WeatherProvider):
    """Kaze weather provider implementation."""

    def __init__(
        self,
        config: KazeProviderConfig,
    ) -> None:
        self.config = config

    async def initialize(self) -> None:
        pass

    async def get_available_locations(self) -> dict[str, list[str]]:
        """Get available weather locations."""
        return {"locations": ["Paris", "Tokyo"]}
```
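The stub above returns hard-coded data to keep the example short. A sketch of what an HTTP-backed version could look like, using `aiohttp` (already listed in the dependencies) and an assumed `/locations` endpoint on the configured `base_url`:
```python
# llama-stack-provider-kaze/src/llama_stack_provider_kaze/kaze.py (alternative sketch)
import aiohttp

from llama_stack_api_weather.api import WeatherProvider

from .config import KazeProviderConfig


class WeatherKazeAdapter(WeatherProvider):
    """Kaze weather provider that queries the remote service over HTTP."""

    def __init__(self, config: KazeProviderConfig) -> None:
        self.config = config

    async def initialize(self) -> None:
        pass

    async def get_available_locations(self) -> dict[str, list[str]]:
        # `/locations` relative to base_url is an assumed endpoint of the
        # hypothetical Kaze service, not something defined by Llama Stack.
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{self.config.base_url}/locations") as resp:
                resp.raise_for_status()
                data = await resp.json()
        return {"locations": data.get("locations", [])}
```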
5. Create the provider specification:
```yaml
# ~/.llama/providers.d/remote/weather/kaze.yaml
adapter:
  adapter_type: kaze
  pip_packages: ["llama_stack_provider_kaze"]
  config_class: llama_stack_provider_kaze.config.KazeProviderConfig
  module: llama_stack_provider_kaze
optional_api_dependencies: []
```
6. Install the provider package:
```bash
uv pip install -e .
```
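With the package installed, the adapter factory can be exercised directly before wiring it into a full stack; a quick smoke test (a sketch, using the example package names):
```python
# Manual smoke test for the provider entry point (sketch).
import asyncio

from llama_stack_provider_kaze import KazeProviderConfig, get_adapter_impl


async def main():
    impl = await get_adapter_impl(KazeProviderConfig(), {})
    print(await impl.get_available_locations())  # {"locations": ["Paris", "Tokyo"]}


asyncio.run(main())
```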
7. Configure Llama Stack to use the provider:
```yaml
# ~/.llama/run-byoa.yaml
version: "2"
image_name: "llama-stack-api-weather"
apis:
- weather
providers:
  weather:
  - provider_id: kaze
    provider_type: remote::kaze
    config: {}
external_apis_dir: ~/.llama/apis.d
external_providers_dir: ~/.llama/providers.d
server:
  port: 8321
```
8. Run the server:
```bash
python -m llama_stack.distribution.server.server --yaml-config ~/.llama/run-byoa.yaml
```
9. Test the API:
```bash
curl -s http://127.0.0.1:8321/v1/weather/locations
{"locations":["Paris","Tokyo"]}%
```
## Best Practices
1. **Package Naming**: Use a clear and descriptive name for your API package.
2. **Version Management**: Keep your API package versioned and compatible with the Llama Stack version you're using.
3. **Dependencies**: Only include the minimum required dependencies in your API package.
4. **Documentation**: Include clear documentation in your API package about:
- Installation requirements
- Configuration options
- API endpoints and usage
- Any limitations or known issues
5. **Testing**: Include tests in your API package to ensure it works correctly with Llama Stack.
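As a starting point for the last item, a minimal test (a sketch, assuming `pytest` and the `llama-stack-api-weather` package from the example above) can assert that the module exposes what Llama Stack expects:
```python
# tests/test_weather_api.py (sketch)
from llama_stack.providers.datatypes import ProviderSpec

import llama_stack_api_weather


def test_exposes_available_providers():
    specs = llama_stack_api_weather.available_providers()
    assert specs, "expected at least one provider spec"
    assert all(isinstance(s, ProviderSpec) for s in specs)


def test_protocol_is_exported():
    # The protocol's methods drive the generated /v1/... routes.
    assert hasattr(llama_stack_api_weather.WeatherProvider, "get_available_locations")
```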
## Troubleshooting
If your external API isn't being loaded:
1. Check that the `external_apis_dir` path is correct and accessible.
2. Verify that the YAML files are properly formatted.
3. Ensure all required Python packages are installed.
4. Check the Llama Stack server logs for any error messages - turn on debug logging to get more information using `LLAMA_STACK_LOGGING=all=debug`.
5. Verify that the API package is installed in your Python environment.
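For steps 3 and 5 in particular, a quick check from the same Python environment that runs the server mirrors what the loader does (a sketch, using the weather example's names):
```python
# Verify the external API package is importable and complete (sketch).
import importlib

module = importlib.import_module("llama_stack_api_weather")    # `module` field of the YAML spec
print(module.__file__)                                          # which environment it was loaded from
print([s.provider_type for s in module.available_providers()])
print(getattr(module, "WeatherProvider"))                       # `protocol` field of the YAML spec
```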

@@ -4,15 +4,83 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from enum import Enum
+from enum import Enum, EnumMeta

-from pydantic import BaseModel
+from pydantic import BaseModel, Field

 from llama_stack.schema_utils import json_schema_type


+class DynamicApiMeta(EnumMeta):
+    def __new__(cls, name, bases, namespace):
+        # Store the original enum values
+        original_values = {k: v for k, v in namespace.items() if not k.startswith("_")}
+
+        # Create the enum class
+        cls = super().__new__(cls, name, bases, namespace)
+
+        # Store the original values for reference
+        cls._original_values = original_values
+        # Initialize _dynamic_values
+        cls._dynamic_values = {}
+
+        return cls
+
+    def __call__(cls, value):
+        try:
+            return super().__call__(value)
+        except ValueError as e:
+            # If the value doesn't exist, create a new enum member
+            # Create a new member name from the value
+            member_name = value.lower().replace("-", "_")
+
+            # If this value was already dynamically added, return it
+            if value in cls._dynamic_values:
+                return cls._dynamic_values[value]
+
+            # If this member name already exists in the enum, return the existing member
+            if member_name in cls._member_map_:
+                return cls._member_map_[member_name]
+
+            # Instead of creating a new member, raise ValueError to force users to use Api.add() to
+            # register new APIs explicitly
+            raise ValueError(f"API '{value}' does not exist. Use Api.add() to register new APIs.") from e
+
+    def __iter__(cls):
+        # Allow iteration over both static and dynamic members
+        yield from super().__iter__()
+        if hasattr(cls, "_dynamic_values"):
+            yield from cls._dynamic_values.values()
+
+    def add(cls, value):
+        """
+        Add a new API to the enum.
+        Particularly useful for external APIs.
+        """
+        member_name = value.lower().replace("-", "_")
+
+        # If this member name already exists in the enum, return it
+        if member_name in cls._member_map_:
+            return cls._member_map_[member_name]
+
+        # Create a new enum member
+        member = object.__new__(cls)
+        member._name_ = member_name
+        member._value_ = value
+
+        # Add it to the enum class
+        cls._member_map_[member_name] = member
+        cls._member_names_.append(member_name)
+        cls._member_type_ = str
+
+        # Store it in our dynamic values
+        cls._dynamic_values[value] = member
+
+        return member
+
+
 @json_schema_type
-class Api(Enum):
+class Api(Enum, metaclass=DynamicApiMeta):
     providers = "providers"
     inference = "inference"
     safety = "safety"
@@ -54,3 +122,12 @@ class Error(BaseModel):
     title: str
     detail: str
     instance: str | None = None
+
+
+class ExternalApiSpec(BaseModel):
+    """Specification for an external API implementation."""
+
+    module: str = Field(..., description="Python module containing the API implementation")
+    name: str = Field(..., description="Name of the API")
+    pip_packages: list[str] = Field(default=[], description="List of pip packages to install the API")
+    protocol: str = Field(..., description="Name of the protocol class for the API")

@@ -289,6 +289,11 @@ a default SQLite store will be used.""",
         description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
     )

+    external_apis_dir: Path | None = Field(
+        default=None,
+        description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
+    )
+
     @field_validator("external_providers_dir")
     @classmethod
     def validate_external_providers_dir(cls, v):

@@ -12,6 +12,7 @@ from typing import Any
 import yaml
 from pydantic import BaseModel

+from llama_stack.distribution.external import load_external_apis
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import (
     AdapterSpec,
@@ -133,16 +134,29 @@ def get_provider_registry(
         ValueError: If any provider spec is invalid
     """
-    ret: dict[Api, dict[str, ProviderSpec]] = {}
+    registry: dict[Api, dict[str, ProviderSpec]] = {}
     for api in providable_apis():
         name = api.name.lower()
         logger.debug(f"Importing module {name}")
         try:
             module = importlib.import_module(f"llama_stack.providers.registry.{name}")
-            ret[api] = {a.provider_type: a for a in module.available_providers()}
+            registry[api] = {a.provider_type: a for a in module.available_providers()}
         except ImportError as e:
             logger.warning(f"Failed to import module {name}: {e}")
+
+    # Refresh providable APIs with external APIs if any
+    external_apis = load_external_apis(config)
+    for api, api_spec in external_apis.items():
+        name = api_spec.name.lower()
+        logger.info(f"Importing external API {name} module {api_spec.module}")
+        try:
+            module = importlib.import_module(api_spec.module)
+            registry[api] = {a.provider_type: a for a in module.available_providers()}
+        except ImportError as e:
+            raise ImportError(
+                f"Failed to import external API module {name}. Is the external API package installed? {e}"
+            ) from e
+
     # Check if config has the external_providers_dir attribute
     if config and hasattr(config, "external_providers_dir") and config.external_providers_dir:
         external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
@@ -175,11 +189,9 @@ def get_provider_registry(
                     else:
                         spec = _load_inline_provider_spec(spec_data, api, provider_name)
                         provider_type_key = f"inline::{provider_name}"
-                    logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
-                    if provider_type_key in ret[api]:
+                    if provider_type_key in registry[api]:
                         logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
-                    ret[api][provider_type_key] = spec
+                    registry[api][provider_type_key] = spec
                     logger.info(f"Successfully loaded external provider {provider_type_key}")
                 except yaml.YAMLError as yaml_err:
                     logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
@@ -187,4 +199,4 @@ def get_provider_registry(
                 except Exception as e:
                     logger.error(f"Failed to load provider spec from {spec_path}: {e}")
                     raise e
-    return ret
+    return registry

@@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import yaml

from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.log import get_logger

logger = get_logger(name=__name__, category="core")


def load_external_apis(config=None) -> dict[Api, ExternalApiSpec]:
    """Load external API specifications from the configured directory.

    Args:
        config: StackRunConfig containing the external APIs directory path

    Returns:
        A dictionary mapping API names to their specifications
    """
    if not config:
        return {}
    if not hasattr(config, "external_apis_dir"):
        return {}
    if not config.external_apis_dir:
        return {}

    external_apis_dir = config.external_apis_dir.expanduser().resolve()
    if not external_apis_dir.is_dir():
        logger.error(f"External APIs directory is not a directory: {external_apis_dir}")
        return {}

    logger.info(f"Loading external APIs from {external_apis_dir}")
    external_apis: dict[Api, ExternalApiSpec] = {}

    # Look for YAML files in the external APIs directory
    for yaml_path in external_apis_dir.glob("*.yaml"):
        try:
            with open(yaml_path) as f:
                spec_data = yaml.safe_load(f)

            spec = ExternalApiSpec(**spec_data)
            api = Api.add(spec.name)
            logger.info(f"Loaded external API spec for {spec.name} from {yaml_path}")
            external_apis[api] = spec
        except yaml.YAMLError as yaml_err:
            logger.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}")
            raise yaml_err
        except Exception as e:
            logger.error(f"Failed to load external API spec from {yaml_path}: {e}")
            raise e

    return external_apis

@@ -16,6 +16,7 @@ from llama_stack.apis.inspect import (
     VersionInfo,
 )
 from llama_stack.distribution.datatypes import StackRunConfig
+from llama_stack.distribution.external import load_external_apis
 from llama_stack.distribution.server.routes import get_all_api_routes
 from llama_stack.providers.datatypes import HealthStatus
@@ -42,7 +43,8 @@ class DistributionInspectImpl(Inspect):
         run_config: StackRunConfig = self.config.run_config

         ret = []
-        all_endpoints = get_all_api_routes()
+        external_apis = load_external_apis(run_config)
+        all_endpoints = get_all_api_routes(external_apis)
         for api, endpoints in all_endpoints.items():
             # Always include provider and inspect APIs, filter others based on run config
             if api.value in ["providers", "inspect"]:

@@ -11,6 +11,7 @@ from llama_stack.apis.agents import Agents
 from llama_stack.apis.benchmarks import Benchmarks
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.datatypes import ExternalApiSpec
 from llama_stack.apis.eval import Eval
 from llama_stack.apis.files import Files
 from llama_stack.apis.inference import Inference, InferenceProvider
@@ -35,6 +36,7 @@ from llama_stack.distribution.datatypes import (
     StackRunConfig,
 )
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
+from llama_stack.distribution.external import load_external_apis
 from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
 from llama_stack.log import get_logger
@@ -59,8 +61,16 @@ class InvalidProviderError(Exception):
     pass


-def api_protocol_map() -> dict[Api, Any]:
-    return {
+def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) -> dict[Api, Any]:
+    """Get a mapping of API types to their protocol classes.
+
+    Args:
+        external_apis: Optional dictionary of external API specifications
+
+    Returns:
+        Dictionary mapping API types to their protocol classes
+    """
+    protocols = {
         Api.providers: ProvidersAPI,
         Api.agents: Agents,
         Api.inference: Inference,
@@ -83,10 +93,23 @@ def api_protocol_map() -> dict[Api, Any]:
         Api.files: Files,
     }

+    if external_apis:
+        for api, api_spec in external_apis.items():
+            try:
+                module = importlib.import_module(api_spec.module)
+                api_class = getattr(module, api_spec.protocol)
+
+                protocols[api] = api_class
+            except (ImportError, AttributeError) as e:
+                logger.warning(f"Failed to load external API {api_spec.name}: {e}")
+
+    return protocols
+

-def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
+def api_protocol_map_for_compliance_check(config: Any) -> dict[Api, Any]:
+    external_apis = load_external_apis(config)
     return {
-        **api_protocol_map(),
+        **api_protocol_map(external_apis),
         Api.inference: InferenceProvider,
     }
@@ -250,7 +273,7 @@ async def instantiate_providers(
     dist_registry: DistributionRegistry,
     run_config: StackRunConfig,
     policy: list[AccessRule],
-) -> dict:
+) -> dict[Api, Any]:
     """Instantiates providers asynchronously while managing dependencies."""
     impls: dict[Api, Any] = {}
     inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
@@ -356,7 +379,7 @@ async def instantiate_provider(
     impl.__provider_spec__ = provider_spec
     impl.__provider_config__ = config

-    protocols = api_protocol_map_for_compliance_check()
+    protocols = api_protocol_map_for_compliance_check(run_config)
     additional_protocols = additional_protocols_map()
     # TODO: check compliance for special tool groups
     # the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol

@@ -12,10 +12,9 @@ from typing import Any
 from aiohttp import hdrs
 from starlette.routing import Route

+from llama_stack.apis.datatypes import Api, ExternalApiSpec
 from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
 from llama_stack.apis.version import LLAMA_STACK_API_VERSION
-from llama_stack.distribution.resolver import api_protocol_map
-from llama_stack.providers.datatypes import Api

 EndpointFunc = Callable[..., Any]
 PathParams = dict[str, str]
@@ -31,10 +30,13 @@ def toolgroup_protocol_map():
     }


-def get_all_api_routes() -> dict[Api, list[Route]]:
+def get_all_api_routes(external_apis: dict[Api, ExternalApiSpec] | None = None) -> dict[Api, list[Route]]:
     apis = {}

-    protocols = api_protocol_map()
+    # Lazy import to avoid circular dependency
+    from llama_stack.distribution.resolver import api_protocol_map
+
+    protocols = api_protocol_map(external_apis)
     toolgroup_protocols = toolgroup_protocol_map()
     for api, protocol in protocols.items():
         routes = []
@@ -73,8 +75,8 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
     return apis


-def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
-    routes = get_all_api_routes()
+def initialize_route_impls(impls, external_apis: dict[Api, ExternalApiSpec] | None = None) -> RouteImpls:
+    routes = get_all_api_routes(external_apis)
     route_impls: RouteImpls = {}

     def _convert_path_to_regex(path: str) -> str:

@@ -33,6 +33,7 @@ from pydantic import BaseModel, ValidationError
 from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
+from llama_stack.distribution.external import ExternalApiSpec, load_external_apis
 from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
 from llama_stack.distribution.resolver import InvalidProviderError
 from llama_stack.distribution.server.routes import (
@@ -270,9 +271,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:

 class TracingMiddleware:
-    def __init__(self, app, impls):
+    def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
         self.app = app
         self.impls = impls
+        self.external_apis = external_apis
         # FastAPI built-in paths that should bypass custom routing
         self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
@@ -289,7 +291,7 @@ class TracingMiddleware:
             return await self.app(scope, receive, send)

         if not hasattr(self, "route_impls"):
-            self.route_impls = initialize_route_impls(self.impls)
+            self.route_impls = initialize_route_impls(self.impls, self.external_apis)

         try:
             _, _, trace_path = find_matching_route(scope.get("method", hdrs.METH_GET), path, self.route_impls)
@@ -493,7 +495,9 @@ def main(args: argparse.Namespace | None = None):
     else:
         setup_logger(TelemetryAdapter(TelemetryConfig(), {}))

-    all_routes = get_all_api_routes()
+    # Load external APIs if configured
+    external_apis = load_external_apis(config)
+    all_routes = get_all_api_routes(external_apis)

     if config.apis:
         apis_to_serve = set(config.apis)
@@ -512,7 +516,10 @@ def main(args: argparse.Namespace | None = None):
         api = Api(api_str)
         routes = all_routes[api]
-        impl = impls[api]
+        try:
+            impl = impls[api]
+        except KeyError as e:
+            raise ValueError(f"Could not find provider implementation for {api} API") from e

         for route in routes:
             if not hasattr(impl, route.name):
@@ -543,7 +550,7 @@ def main(args: argparse.Namespace | None = None):
     app.exception_handler(Exception)(global_exception_handler)

     app.__llama_stack_impls__ = impls
-    app.add_middleware(TracingMiddleware, impls=impls)
+    app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)

     import uvicorn