mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
feat: Bring Your Own API (BYOA)
Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
cfee63bd0d
commit
9443cef577
9 changed files with 607 additions and 28 deletions
392
docs/source/apis/external.md
Normal file
392
docs/source/apis/external.md
Normal file
|
@ -0,0 +1,392 @@
|
||||||
|
# External APIs
|
||||||
|
|
||||||
|
Llama Stack supports external APIs that live outside of the main codebase. This allows you to:
|
||||||
|
- Create and maintain your own APIs independently
|
||||||
|
- Share APIs with others without contributing to the main codebase
|
||||||
|
- Keep API-specific code separate from the core Llama Stack code
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
To enable external APIs, you need to configure the `external_apis_dir` in your Llama Stack configuration. This directory should contain your external API specifications:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
external_apis_dir: ~/.llama/apis.d/
|
||||||
|
```
|
||||||
|
|
||||||
|
## Directory Structure
|
||||||
|
|
||||||
|
The external APIs directory should follow this structure:
|
||||||
|
|
||||||
|
```
|
||||||
|
apis.d/
|
||||||
|
custom_api1.yaml
|
||||||
|
custom_api2.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
Each YAML file in these directories defines an API specification.
|
||||||
|
|
||||||
|
## API Specification
|
||||||
|
|
||||||
|
Here's an example of an external API specification for a weather API:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
module: weather
|
||||||
|
api_dependencies:
|
||||||
|
- inference
|
||||||
|
protocol: WeatherAPI
|
||||||
|
name: weather
|
||||||
|
pip_packages:
|
||||||
|
- llama-stack-api-weather
|
||||||
|
```
|
||||||
|
|
||||||
|
### API Specification Fields
|
||||||
|
|
||||||
|
- `module`: Python module containing the API implementation
|
||||||
|
- `protocol`: Name of the protocol class for the API
|
||||||
|
- `name`: Name of the API
|
||||||
|
- `pip_packages`: List of pip packages to install the API, typically a single package
|
||||||
|
|
||||||
|
## Required Implementation
|
||||||
|
|
||||||
|
External APIs must expose a `available_providers()` function in their module that returns a list of provider names:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# llama_stack_api_weather/api.py
|
||||||
|
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
|
||||||
|
|
||||||
|
|
||||||
|
def available_providers() -> list[ProviderSpec]:
|
||||||
|
return [
|
||||||
|
InlineProviderSpec(
|
||||||
|
api=Api.weather,
|
||||||
|
provider_type="inline::darksky",
|
||||||
|
pip_packages=[],
|
||||||
|
module="llama_stack_provider_darksky",
|
||||||
|
config_class="llama_stack_provider_darksky.DarkSkyWeatherImplConfig",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
A Protocol class like so:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# llama_stack_api_weather/api.py
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
from llama_stack.schema_utils import webmethod
|
||||||
|
|
||||||
|
|
||||||
|
class WeatherAPI(Protocol):
|
||||||
|
"""
|
||||||
|
A protocol for the Weather API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@webmethod(route="/locations", method="GET")
|
||||||
|
async def get_available_locations() -> dict[str, list[str]]:
|
||||||
|
"""
|
||||||
|
Get the available locations.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Example: Custom API
|
||||||
|
|
||||||
|
Here's a complete example of creating and using a custom API:
|
||||||
|
|
||||||
|
1. First, create the API package:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
mkdir -p llama-stack-api-weather
|
||||||
|
cd llama-stack-api-weather
|
||||||
|
mkdir src/llama_stack_api_weather
|
||||||
|
git init
|
||||||
|
uv init
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Edit `pyproject.toml`:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[project]
|
||||||
|
name = "llama-stack-api-weather"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Weather API for Llama Stack"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.10"
|
||||||
|
dependencies = ["llama-stack", "pydantic"]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["setuptools"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[tool.setuptools.packages.find]
|
||||||
|
where = ["src"]
|
||||||
|
include = ["llama_stack_api_weather", "llama_stack_api_weather.*"]
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Create the initial files:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
touch src/llama_stack_api_weather/__init__.py
|
||||||
|
touch src/llama_stack_api_weather/api.py
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
# llama-stack-api-weather/src/llama_stack_api_weather/__init__.py
|
||||||
|
"""Weather API for Llama Stack."""
|
||||||
|
|
||||||
|
from .api import WeatherAPI, available_providers
|
||||||
|
|
||||||
|
__all__ = ["WeatherAPI", "available_providers"]
|
||||||
|
```
|
||||||
|
|
||||||
|
4. Create the API implementation:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# llama-stack-api-weather/src/llama_stack_api_weather/weather.py
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
from llama_stack.providers.datatypes import (
|
||||||
|
AdapterSpec,
|
||||||
|
Api,
|
||||||
|
ProviderSpec,
|
||||||
|
RemoteProviderSpec,
|
||||||
|
)
|
||||||
|
from llama_stack.schema_utils import webmethod
|
||||||
|
|
||||||
|
|
||||||
|
def available_providers() -> list[ProviderSpec]:
|
||||||
|
return [
|
||||||
|
RemoteProviderSpec(
|
||||||
|
api=Api.weather,
|
||||||
|
provider_type="remote::kaze",
|
||||||
|
config_class="llama_stack_provider_kaze.KazeProviderConfig",
|
||||||
|
adapter=AdapterSpec(
|
||||||
|
adapter_type="kaze",
|
||||||
|
module="llama_stack_provider_kaze",
|
||||||
|
pip_packages=["llama_stack_provider_kaze"],
|
||||||
|
config_class="llama_stack_provider_kaze.KazeProviderConfig",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class WeatherProvider(Protocol):
|
||||||
|
"""
|
||||||
|
A protocol for the Weather API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@webmethod(route="/weather/locations", method="GET")
|
||||||
|
async def get_available_locations() -> dict[str, list[str]]:
|
||||||
|
"""
|
||||||
|
Get the available locations.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
5. Create the API specification:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# ~/.llama/apis.d/weather.yaml
|
||||||
|
module: llama_stack_api_weather
|
||||||
|
name: weather
|
||||||
|
pip_packages: ["llama-stack-api-weather"]
|
||||||
|
protocol: WeatherProvider
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
6. Install the API package:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
7. Configure Llama Stack to use external APIs:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
version: "2"
|
||||||
|
image_name: "llama-stack-api-weather"
|
||||||
|
apis:
|
||||||
|
- weather
|
||||||
|
providers: {}
|
||||||
|
external_apis_dir: ~/.llama/apis.d
|
||||||
|
```
|
||||||
|
|
||||||
|
The API will now be available at `/v1/weather/locations`.
|
||||||
|
|
||||||
|
## Example: custom provider for the weather API
|
||||||
|
|
||||||
|
1. Create the provider package:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
mkdir -p llama-stack-provider-kaze
|
||||||
|
cd llama-stack-provider-kaze
|
||||||
|
uv init
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Edit `pyproject.toml`:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[project]
|
||||||
|
name = "llama-stack-provider-kaze"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Kaze weather provider for Llama Stack"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.10"
|
||||||
|
dependencies = ["llama-stack", "pydantic", "aiohttp"]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["setuptools"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[tool.setuptools.packages.find]
|
||||||
|
where = ["src"]
|
||||||
|
include = ["llama_stack_provider_kaze", "llama_stack_provider_kaze.*"]
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Create the initial files:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
touch src/llama_stack_provider_kaze/__init__.py
|
||||||
|
touch src/llama_stack_provider_kaze/kaze.py
|
||||||
|
```
|
||||||
|
|
||||||
|
4. Create the provider implementation:
|
||||||
|
|
||||||
|
|
||||||
|
Initialization function:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# llama-stack-provider-kaze/src/llama_stack_provider_kaze/__init__.py
|
||||||
|
"""Kaze weather provider for Llama Stack."""
|
||||||
|
|
||||||
|
from .config import KazeProviderConfig
|
||||||
|
from .kaze import WeatherKazeAdapter
|
||||||
|
|
||||||
|
__all__ = ["KazeProviderConfig", "WeatherKazeAdapter"]
|
||||||
|
|
||||||
|
|
||||||
|
async def get_adapter_impl(config: KazeProviderConfig, _deps):
|
||||||
|
from .kaze import WeatherKazeAdapter
|
||||||
|
|
||||||
|
impl = WeatherKazeAdapter(config)
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
||||||
|
```
|
||||||
|
|
||||||
|
Configuration:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# llama-stack-provider-kaze/src/llama_stack_provider_kaze/config.py
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class KazeProviderConfig(BaseModel):
|
||||||
|
"""Configuration for the Kaze weather provider."""
|
||||||
|
|
||||||
|
base_url: str = Field(
|
||||||
|
"https://api.kaze.io/v1",
|
||||||
|
description="Base URL for the Kaze weather API",
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Main implementation:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# llama-stack-provider-kaze/src/llama_stack_provider_kaze/kaze.py
|
||||||
|
from llama_stack_api_weather.api import WeatherProvider
|
||||||
|
|
||||||
|
from .config import KazeProviderConfig
|
||||||
|
|
||||||
|
|
||||||
|
class WeatherKazeAdapter(WeatherProvider):
|
||||||
|
"""Kaze weather provider implementation."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: KazeProviderConfig,
|
||||||
|
) -> None:
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def get_available_locations(self) -> dict[str, list[str]]:
|
||||||
|
"""Get available weather locations."""
|
||||||
|
return {"locations": ["Paris", "Tokyo"]}
|
||||||
|
```
|
||||||
|
|
||||||
|
5. Create the provider specification:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# ~/.llama/providers.d/remote/weather/kaze.yaml
|
||||||
|
adapter:
|
||||||
|
adapter_type: kaze
|
||||||
|
pip_packages: ["llama_stack_provider_kaze"]
|
||||||
|
config_class: llama_stack_provider_kaze.config.KazeProviderConfig
|
||||||
|
module: llama_stack_provider_kaze
|
||||||
|
optional_api_dependencies: []
|
||||||
|
```
|
||||||
|
|
||||||
|
6. Install the provider package:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
7. Configure Llama Stack to use the provider:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# ~/.llama/run-byoa.yaml
|
||||||
|
version: "2"
|
||||||
|
image_name: "llama-stack-api-weather"
|
||||||
|
apis:
|
||||||
|
- weather
|
||||||
|
providers:
|
||||||
|
weather:
|
||||||
|
- provider_id: kaze
|
||||||
|
provider_type: remote::kaze
|
||||||
|
config: {}
|
||||||
|
external_apis_dir: ~/.llama/apis.d
|
||||||
|
external_providers_dir: ~/.llama/providers.d
|
||||||
|
server:
|
||||||
|
port: 8321
|
||||||
|
```
|
||||||
|
|
||||||
|
8. Run the server:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m llama_stack.distribution.server.server --yaml-config ~/.llama/run-byoa.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
9. Test the API:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -s http://127.0.0.1:8321/v1/weather/locations
|
||||||
|
{"locations":["Paris","Tokyo"]}%
|
||||||
|
```
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
1. **Package Naming**: Use a clear and descriptive name for your API package.
|
||||||
|
|
||||||
|
2. **Version Management**: Keep your API package versioned and compatible with the Llama Stack version you're using.
|
||||||
|
|
||||||
|
3. **Dependencies**: Only include the minimum required dependencies in your API package.
|
||||||
|
|
||||||
|
4. **Documentation**: Include clear documentation in your API package about:
|
||||||
|
- Installation requirements
|
||||||
|
- Configuration options
|
||||||
|
- API endpoints and usage
|
||||||
|
- Any limitations or known issues
|
||||||
|
|
||||||
|
5. **Testing**: Include tests in your API package to ensure it works correctly with Llama Stack.
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
If your external API isn't being loaded:
|
||||||
|
|
||||||
|
1. Check that the `external_apis_dir` path is correct and accessible.
|
||||||
|
2. Verify that the YAML files are properly formatted.
|
||||||
|
3. Ensure all required Python packages are installed.
|
||||||
|
4. Check the Llama Stack server logs for any error messages - turn on debug logging to get more information using `LLAMA_STACK_LOGGING=all=debug`.
|
||||||
|
5. Verify that the API package is installed in your Python environment.
|
|
@ -4,15 +4,83 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum, EnumMeta
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from llama_stack.schema_utils import json_schema_type
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicApiMeta(EnumMeta):
|
||||||
|
def __new__(cls, name, bases, namespace):
|
||||||
|
# Store the original enum values
|
||||||
|
original_values = {k: v for k, v in namespace.items() if not k.startswith("_")}
|
||||||
|
|
||||||
|
# Create the enum class
|
||||||
|
cls = super().__new__(cls, name, bases, namespace)
|
||||||
|
|
||||||
|
# Store the original values for reference
|
||||||
|
cls._original_values = original_values
|
||||||
|
# Initialize _dynamic_values
|
||||||
|
cls._dynamic_values = {}
|
||||||
|
|
||||||
|
return cls
|
||||||
|
|
||||||
|
def __call__(cls, value):
|
||||||
|
try:
|
||||||
|
return super().__call__(value)
|
||||||
|
except ValueError as e:
|
||||||
|
# If the value doesn't exist, create a new enum member
|
||||||
|
# Create a new member name from the value
|
||||||
|
member_name = value.lower().replace("-", "_")
|
||||||
|
|
||||||
|
# If this value was already dynamically added, return it
|
||||||
|
if value in cls._dynamic_values:
|
||||||
|
return cls._dynamic_values[value]
|
||||||
|
|
||||||
|
# If this member name already exists in the enum, return the existing member
|
||||||
|
if member_name in cls._member_map_:
|
||||||
|
return cls._member_map_[member_name]
|
||||||
|
|
||||||
|
# Instead of creating a new member, raise ValueError to force users to use Api.add() to
|
||||||
|
# register new APIs explicitly
|
||||||
|
raise ValueError(f"API '{value}' does not exist. Use Api.add() to register new APIs.") from e
|
||||||
|
|
||||||
|
def __iter__(cls):
|
||||||
|
# Allow iteration over both static and dynamic members
|
||||||
|
yield from super().__iter__()
|
||||||
|
if hasattr(cls, "_dynamic_values"):
|
||||||
|
yield from cls._dynamic_values.values()
|
||||||
|
|
||||||
|
def add(cls, value):
|
||||||
|
"""
|
||||||
|
Add a new API to the enum.
|
||||||
|
Particulary useful for external APIs.
|
||||||
|
"""
|
||||||
|
member_name = value.lower().replace("-", "_")
|
||||||
|
|
||||||
|
# If this member name already exists in the enum, return it
|
||||||
|
if member_name in cls._member_map_:
|
||||||
|
return cls._member_map_[member_name]
|
||||||
|
|
||||||
|
# Create a new enum member
|
||||||
|
member = object.__new__(cls)
|
||||||
|
member._name_ = member_name
|
||||||
|
member._value_ = value
|
||||||
|
|
||||||
|
# Add it to the enum class
|
||||||
|
cls._member_map_[member_name] = member
|
||||||
|
cls._member_names_.append(member_name)
|
||||||
|
cls._member_type_ = str
|
||||||
|
|
||||||
|
# Store it in our dynamic values
|
||||||
|
cls._dynamic_values[value] = member
|
||||||
|
|
||||||
|
return member
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class Api(Enum):
|
class Api(Enum, metaclass=DynamicApiMeta):
|
||||||
providers = "providers"
|
providers = "providers"
|
||||||
inference = "inference"
|
inference = "inference"
|
||||||
safety = "safety"
|
safety = "safety"
|
||||||
|
@ -54,3 +122,12 @@ class Error(BaseModel):
|
||||||
title: str
|
title: str
|
||||||
detail: str
|
detail: str
|
||||||
instance: str | None = None
|
instance: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ExternalApiSpec(BaseModel):
|
||||||
|
"""Specification for an external API implementation."""
|
||||||
|
|
||||||
|
module: str = Field(..., description="Python module containing the API implementation")
|
||||||
|
name: str = Field(..., description="Name of the API")
|
||||||
|
pip_packages: list[str] = Field(default=[], description="List of pip packages to install the API")
|
||||||
|
protocol: str = Field(..., description="Name of the protocol class for the API")
|
||||||
|
|
|
@ -289,6 +289,11 @@ a default SQLite store will be used.""",
|
||||||
description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
|
description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
external_apis_dir: Path | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
|
||||||
|
)
|
||||||
|
|
||||||
@field_validator("external_providers_dir")
|
@field_validator("external_providers_dir")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_external_providers_dir(cls, v):
|
def validate_external_providers_dir(cls, v):
|
||||||
|
|
|
@ -12,6 +12,7 @@ from typing import Any
|
||||||
import yaml
|
import yaml
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from llama_stack.distribution.external import load_external_apis
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import (
|
from llama_stack.providers.datatypes import (
|
||||||
AdapterSpec,
|
AdapterSpec,
|
||||||
|
@ -133,16 +134,29 @@ def get_provider_registry(
|
||||||
ValueError: If any provider spec is invalid
|
ValueError: If any provider spec is invalid
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ret: dict[Api, dict[str, ProviderSpec]] = {}
|
registry: dict[Api, dict[str, ProviderSpec]] = {}
|
||||||
for api in providable_apis():
|
for api in providable_apis():
|
||||||
name = api.name.lower()
|
name = api.name.lower()
|
||||||
logger.debug(f"Importing module {name}")
|
logger.debug(f"Importing module {name}")
|
||||||
try:
|
try:
|
||||||
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
|
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
|
||||||
ret[api] = {a.provider_type: a for a in module.available_providers()}
|
registry[api] = {a.provider_type: a for a in module.available_providers()}
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning(f"Failed to import module {name}: {e}")
|
logger.warning(f"Failed to import module {name}: {e}")
|
||||||
|
|
||||||
|
# Refresh providable APIs with external APIs if any
|
||||||
|
external_apis = load_external_apis(config)
|
||||||
|
for api, api_spec in external_apis.items():
|
||||||
|
name = api_spec.name.lower()
|
||||||
|
logger.info(f"Importing external API {name} module {api_spec.module}")
|
||||||
|
try:
|
||||||
|
module = importlib.import_module(api_spec.module)
|
||||||
|
registry[api] = {a.provider_type: a for a in module.available_providers()}
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
f"Failed to import external API module {name}. Is the external API package installed? {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
# Check if config has the external_providers_dir attribute
|
# Check if config has the external_providers_dir attribute
|
||||||
if config and hasattr(config, "external_providers_dir") and config.external_providers_dir:
|
if config and hasattr(config, "external_providers_dir") and config.external_providers_dir:
|
||||||
external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
|
external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
|
||||||
|
@ -175,11 +189,9 @@ def get_provider_registry(
|
||||||
else:
|
else:
|
||||||
spec = _load_inline_provider_spec(spec_data, api, provider_name)
|
spec = _load_inline_provider_spec(spec_data, api, provider_name)
|
||||||
provider_type_key = f"inline::{provider_name}"
|
provider_type_key = f"inline::{provider_name}"
|
||||||
|
if provider_type_key in registry[api]:
|
||||||
logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
|
|
||||||
if provider_type_key in ret[api]:
|
|
||||||
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
|
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
|
||||||
ret[api][provider_type_key] = spec
|
registry[api][provider_type_key] = spec
|
||||||
logger.info(f"Successfully loaded external provider {provider_type_key}")
|
logger.info(f"Successfully loaded external provider {provider_type_key}")
|
||||||
except yaml.YAMLError as yaml_err:
|
except yaml.YAMLError as yaml_err:
|
||||||
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
|
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
|
||||||
|
@ -187,4 +199,4 @@ def get_provider_registry(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to load provider spec from {spec_path}: {e}")
|
logger.error(f"Failed to load provider spec from {spec_path}: {e}")
|
||||||
raise e
|
raise e
|
||||||
return ret
|
return registry
|
||||||
|
|
59
llama_stack/distribution/external.py
Normal file
59
llama_stack/distribution/external.py
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from llama_stack.apis.datatypes import Api, ExternalApiSpec
|
||||||
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
|
|
||||||
|
def load_external_apis(config=None) -> dict[Api, ExternalApiSpec]:
|
||||||
|
"""Load external API specifications from the configured directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: StackRunConfig containing the external APIs directory path
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary mapping API names to their specifications
|
||||||
|
"""
|
||||||
|
if not config:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
if not hasattr(config, "external_apis_dir"):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
if not config.external_apis_dir:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
external_apis_dir = config.external_apis_dir.expanduser().resolve()
|
||||||
|
if not external_apis_dir.is_dir():
|
||||||
|
logger.error(f"External APIs directory is not a directory: {external_apis_dir}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
logger.info(f"Loading external APIs from {external_apis_dir}")
|
||||||
|
external_apis: dict[Api, ExternalApiSpec] = {}
|
||||||
|
|
||||||
|
# Look for YAML files in the external APIs directory
|
||||||
|
for yaml_path in external_apis_dir.glob("*.yaml"):
|
||||||
|
try:
|
||||||
|
with open(yaml_path) as f:
|
||||||
|
spec_data = yaml.safe_load(f)
|
||||||
|
|
||||||
|
spec = ExternalApiSpec(**spec_data)
|
||||||
|
api = Api.add(spec.name)
|
||||||
|
logger.info(f"Loaded external API spec for {spec.name} from {yaml_path}")
|
||||||
|
external_apis[api] = spec
|
||||||
|
except yaml.YAMLError as yaml_err:
|
||||||
|
logger.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}")
|
||||||
|
raise yaml_err
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load external API spec from {yaml_path}: {e}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
return external_apis
|
|
@ -16,6 +16,7 @@ from llama_stack.apis.inspect import (
|
||||||
VersionInfo,
|
VersionInfo,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.datatypes import StackRunConfig
|
from llama_stack.distribution.datatypes import StackRunConfig
|
||||||
|
from llama_stack.distribution.external import load_external_apis
|
||||||
from llama_stack.distribution.server.routes import get_all_api_routes
|
from llama_stack.distribution.server.routes import get_all_api_routes
|
||||||
from llama_stack.providers.datatypes import HealthStatus
|
from llama_stack.providers.datatypes import HealthStatus
|
||||||
|
|
||||||
|
@ -42,7 +43,8 @@ class DistributionInspectImpl(Inspect):
|
||||||
run_config: StackRunConfig = self.config.run_config
|
run_config: StackRunConfig = self.config.run_config
|
||||||
|
|
||||||
ret = []
|
ret = []
|
||||||
all_endpoints = get_all_api_routes()
|
external_apis = load_external_apis(run_config)
|
||||||
|
all_endpoints = get_all_api_routes(external_apis)
|
||||||
for api, endpoints in all_endpoints.items():
|
for api, endpoints in all_endpoints.items():
|
||||||
# Always include provider and inspect APIs, filter others based on run config
|
# Always include provider and inspect APIs, filter others based on run config
|
||||||
if api.value in ["providers", "inspect"]:
|
if api.value in ["providers", "inspect"]:
|
||||||
|
|
|
@ -11,6 +11,7 @@ from llama_stack.apis.agents import Agents
|
||||||
from llama_stack.apis.benchmarks import Benchmarks
|
from llama_stack.apis.benchmarks import Benchmarks
|
||||||
from llama_stack.apis.datasetio import DatasetIO
|
from llama_stack.apis.datasetio import DatasetIO
|
||||||
from llama_stack.apis.datasets import Datasets
|
from llama_stack.apis.datasets import Datasets
|
||||||
|
from llama_stack.apis.datatypes import ExternalApiSpec
|
||||||
from llama_stack.apis.eval import Eval
|
from llama_stack.apis.eval import Eval
|
||||||
from llama_stack.apis.files import Files
|
from llama_stack.apis.files import Files
|
||||||
from llama_stack.apis.inference import Inference, InferenceProvider
|
from llama_stack.apis.inference import Inference, InferenceProvider
|
||||||
|
@ -35,6 +36,7 @@ from llama_stack.distribution.datatypes import (
|
||||||
StackRunConfig,
|
StackRunConfig,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||||
|
from llama_stack.distribution.external import load_external_apis
|
||||||
from llama_stack.distribution.store import DistributionRegistry
|
from llama_stack.distribution.store import DistributionRegistry
|
||||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
@ -59,8 +61,16 @@ class InvalidProviderError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def api_protocol_map() -> dict[Api, Any]:
|
def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) -> dict[Api, Any]:
|
||||||
return {
|
"""Get a mapping of API types to their protocol classes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
external_apis: Optional dictionary of external API specifications
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping API types to their protocol classes
|
||||||
|
"""
|
||||||
|
protocols = {
|
||||||
Api.providers: ProvidersAPI,
|
Api.providers: ProvidersAPI,
|
||||||
Api.agents: Agents,
|
Api.agents: Agents,
|
||||||
Api.inference: Inference,
|
Api.inference: Inference,
|
||||||
|
@ -83,10 +93,23 @@ def api_protocol_map() -> dict[Api, Any]:
|
||||||
Api.files: Files,
|
Api.files: Files,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if external_apis:
|
||||||
|
for api, api_spec in external_apis.items():
|
||||||
|
try:
|
||||||
|
module = importlib.import_module(api_spec.module)
|
||||||
|
api_class = getattr(module, api_spec.protocol)
|
||||||
|
|
||||||
def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
|
protocols[api] = api_class
|
||||||
|
except (ImportError, AttributeError) as e:
|
||||||
|
logger.warning(f"Failed to load external API {api_spec.name}: {e}")
|
||||||
|
|
||||||
|
return protocols
|
||||||
|
|
||||||
|
|
||||||
|
def api_protocol_map_for_compliance_check(config: Any) -> dict[Api, Any]:
|
||||||
|
external_apis = load_external_apis(config)
|
||||||
return {
|
return {
|
||||||
**api_protocol_map(),
|
**api_protocol_map(external_apis),
|
||||||
Api.inference: InferenceProvider,
|
Api.inference: InferenceProvider,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -250,7 +273,7 @@ async def instantiate_providers(
|
||||||
dist_registry: DistributionRegistry,
|
dist_registry: DistributionRegistry,
|
||||||
run_config: StackRunConfig,
|
run_config: StackRunConfig,
|
||||||
policy: list[AccessRule],
|
policy: list[AccessRule],
|
||||||
) -> dict:
|
) -> dict[Api, Any]:
|
||||||
"""Instantiates providers asynchronously while managing dependencies."""
|
"""Instantiates providers asynchronously while managing dependencies."""
|
||||||
impls: dict[Api, Any] = {}
|
impls: dict[Api, Any] = {}
|
||||||
inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
|
inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
|
||||||
|
@ -356,7 +379,7 @@ async def instantiate_provider(
|
||||||
impl.__provider_spec__ = provider_spec
|
impl.__provider_spec__ = provider_spec
|
||||||
impl.__provider_config__ = config
|
impl.__provider_config__ = config
|
||||||
|
|
||||||
protocols = api_protocol_map_for_compliance_check()
|
protocols = api_protocol_map_for_compliance_check(run_config)
|
||||||
additional_protocols = additional_protocols_map()
|
additional_protocols = additional_protocols_map()
|
||||||
# TODO: check compliance for special tool groups
|
# TODO: check compliance for special tool groups
|
||||||
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
|
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
|
||||||
|
|
|
@ -12,10 +12,9 @@ from typing import Any
|
||||||
from aiohttp import hdrs
|
from aiohttp import hdrs
|
||||||
from starlette.routing import Route
|
from starlette.routing import Route
|
||||||
|
|
||||||
|
from llama_stack.apis.datatypes import Api, ExternalApiSpec
|
||||||
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
|
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
|
||||||
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
|
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
|
||||||
from llama_stack.distribution.resolver import api_protocol_map
|
|
||||||
from llama_stack.providers.datatypes import Api
|
|
||||||
|
|
||||||
EndpointFunc = Callable[..., Any]
|
EndpointFunc = Callable[..., Any]
|
||||||
PathParams = dict[str, str]
|
PathParams = dict[str, str]
|
||||||
|
@ -31,10 +30,13 @@ def toolgroup_protocol_map():
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_all_api_routes() -> dict[Api, list[Route]]:
|
def get_all_api_routes(external_apis: dict[Api, ExternalApiSpec] | None = None) -> dict[Api, list[Route]]:
|
||||||
apis = {}
|
apis = {}
|
||||||
|
|
||||||
protocols = api_protocol_map()
|
# Lazy import to avoid circular dependency
|
||||||
|
from llama_stack.distribution.resolver import api_protocol_map
|
||||||
|
|
||||||
|
protocols = api_protocol_map(external_apis)
|
||||||
toolgroup_protocols = toolgroup_protocol_map()
|
toolgroup_protocols = toolgroup_protocol_map()
|
||||||
for api, protocol in protocols.items():
|
for api, protocol in protocols.items():
|
||||||
routes = []
|
routes = []
|
||||||
|
@ -73,8 +75,8 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
|
||||||
return apis
|
return apis
|
||||||
|
|
||||||
|
|
||||||
def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
|
def initialize_route_impls(impls, external_apis: dict[Api, ExternalApiSpec] | None = None) -> RouteImpls:
|
||||||
routes = get_all_api_routes()
|
routes = get_all_api_routes(external_apis)
|
||||||
route_impls: RouteImpls = {}
|
route_impls: RouteImpls = {}
|
||||||
|
|
||||||
def _convert_path_to_regex(path: str) -> str:
|
def _convert_path_to_regex(path: str) -> str:
|
||||||
|
|
|
@ -33,6 +33,7 @@ from pydantic import BaseModel, ValidationError
|
||||||
from llama_stack.apis.common.responses import PaginatedResponse
|
from llama_stack.apis.common.responses import PaginatedResponse
|
||||||
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
|
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
|
||||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||||
|
from llama_stack.distribution.external import ExternalApiSpec, load_external_apis
|
||||||
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
|
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
|
||||||
from llama_stack.distribution.resolver import InvalidProviderError
|
from llama_stack.distribution.resolver import InvalidProviderError
|
||||||
from llama_stack.distribution.server.routes import (
|
from llama_stack.distribution.server.routes import (
|
||||||
|
@ -270,9 +271,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
||||||
|
|
||||||
|
|
||||||
class TracingMiddleware:
|
class TracingMiddleware:
|
||||||
def __init__(self, app, impls):
|
def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
|
||||||
self.app = app
|
self.app = app
|
||||||
self.impls = impls
|
self.impls = impls
|
||||||
|
self.external_apis = external_apis
|
||||||
# FastAPI built-in paths that should bypass custom routing
|
# FastAPI built-in paths that should bypass custom routing
|
||||||
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
|
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
|
||||||
|
|
||||||
|
@ -289,7 +291,7 @@ class TracingMiddleware:
|
||||||
return await self.app(scope, receive, send)
|
return await self.app(scope, receive, send)
|
||||||
|
|
||||||
if not hasattr(self, "route_impls"):
|
if not hasattr(self, "route_impls"):
|
||||||
self.route_impls = initialize_route_impls(self.impls)
|
self.route_impls = initialize_route_impls(self.impls, self.external_apis)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
_, _, trace_path = find_matching_route(scope.get("method", hdrs.METH_GET), path, self.route_impls)
|
_, _, trace_path = find_matching_route(scope.get("method", hdrs.METH_GET), path, self.route_impls)
|
||||||
|
@ -493,7 +495,9 @@ def main(args: argparse.Namespace | None = None):
|
||||||
else:
|
else:
|
||||||
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
|
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
|
||||||
|
|
||||||
all_routes = get_all_api_routes()
|
# Load external APIs if configured
|
||||||
|
external_apis = load_external_apis(config)
|
||||||
|
all_routes = get_all_api_routes(external_apis)
|
||||||
|
|
||||||
if config.apis:
|
if config.apis:
|
||||||
apis_to_serve = set(config.apis)
|
apis_to_serve = set(config.apis)
|
||||||
|
@ -512,7 +516,10 @@ def main(args: argparse.Namespace | None = None):
|
||||||
api = Api(api_str)
|
api = Api(api_str)
|
||||||
|
|
||||||
routes = all_routes[api]
|
routes = all_routes[api]
|
||||||
|
try:
|
||||||
impl = impls[api]
|
impl = impls[api]
|
||||||
|
except KeyError as e:
|
||||||
|
raise ValueError(f"Could not find provider implementation for {api} API") from e
|
||||||
|
|
||||||
for route in routes:
|
for route in routes:
|
||||||
if not hasattr(impl, route.name):
|
if not hasattr(impl, route.name):
|
||||||
|
@ -543,7 +550,7 @@ def main(args: argparse.Namespace | None = None):
|
||||||
app.exception_handler(Exception)(global_exception_handler)
|
app.exception_handler(Exception)(global_exception_handler)
|
||||||
|
|
||||||
app.__llama_stack_impls__ = impls
|
app.__llama_stack_impls__ = impls
|
||||||
app.add_middleware(TracingMiddleware, impls=impls)
|
app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue