feat: Bring Your Own API (BYOA)

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-05-22 17:40:45 +02:00
parent cfee63bd0d
commit 9443cef577
No known key found for this signature in database
9 changed files with 607 additions and 28 deletions

View file

@ -4,15 +4,83 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from enum import Enum, EnumMeta
from pydantic import BaseModel
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class DynamicApiMeta(EnumMeta):
def __new__(cls, name, bases, namespace):
# Store the original enum values
original_values = {k: v for k, v in namespace.items() if not k.startswith("_")}
# Create the enum class
cls = super().__new__(cls, name, bases, namespace)
# Store the original values for reference
cls._original_values = original_values
# Initialize _dynamic_values
cls._dynamic_values = {}
return cls
def __call__(cls, value):
try:
return super().__call__(value)
except ValueError as e:
# If the value doesn't exist, create a new enum member
# Create a new member name from the value
member_name = value.lower().replace("-", "_")
# If this value was already dynamically added, return it
if value in cls._dynamic_values:
return cls._dynamic_values[value]
# If this member name already exists in the enum, return the existing member
if member_name in cls._member_map_:
return cls._member_map_[member_name]
# Instead of creating a new member, raise ValueError to force users to use Api.add() to
# register new APIs explicitly
raise ValueError(f"API '{value}' does not exist. Use Api.add() to register new APIs.") from e
def __iter__(cls):
# Allow iteration over both static and dynamic members
yield from super().__iter__()
if hasattr(cls, "_dynamic_values"):
yield from cls._dynamic_values.values()
def add(cls, value):
"""
Add a new API to the enum.
Particulary useful for external APIs.
"""
member_name = value.lower().replace("-", "_")
# If this member name already exists in the enum, return it
if member_name in cls._member_map_:
return cls._member_map_[member_name]
# Create a new enum member
member = object.__new__(cls)
member._name_ = member_name
member._value_ = value
# Add it to the enum class
cls._member_map_[member_name] = member
cls._member_names_.append(member_name)
cls._member_type_ = str
# Store it in our dynamic values
cls._dynamic_values[value] = member
return member
@json_schema_type
class Api(Enum):
class Api(Enum, metaclass=DynamicApiMeta):
providers = "providers"
inference = "inference"
safety = "safety"
@ -54,3 +122,12 @@ class Error(BaseModel):
title: str
detail: str
instance: str | None = None
class ExternalApiSpec(BaseModel):
"""Specification for an external API implementation."""
module: str = Field(..., description="Python module containing the API implementation")
name: str = Field(..., description="Name of the API")
pip_packages: list[str] = Field(default=[], description="List of pip packages to install the API")
protocol: str = Field(..., description="Name of the protocol class for the API")

View file

@ -289,6 +289,11 @@ a default SQLite store will be used.""",
description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
)
external_apis_dir: Path | None = Field(
default=None,
description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
)
@field_validator("external_providers_dir")
@classmethod
def validate_external_providers_dir(cls, v):

View file

@ -12,6 +12,7 @@ from typing import Any
import yaml
from pydantic import BaseModel
from llama_stack.distribution.external import load_external_apis
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
AdapterSpec,
@ -133,16 +134,29 @@ def get_provider_registry(
ValueError: If any provider spec is invalid
"""
ret: dict[Api, dict[str, ProviderSpec]] = {}
registry: dict[Api, dict[str, ProviderSpec]] = {}
for api in providable_apis():
name = api.name.lower()
logger.debug(f"Importing module {name}")
try:
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
ret[api] = {a.provider_type: a for a in module.available_providers()}
registry[api] = {a.provider_type: a for a in module.available_providers()}
except ImportError as e:
logger.warning(f"Failed to import module {name}: {e}")
# Refresh providable APIs with external APIs if any
external_apis = load_external_apis(config)
for api, api_spec in external_apis.items():
name = api_spec.name.lower()
logger.info(f"Importing external API {name} module {api_spec.module}")
try:
module = importlib.import_module(api_spec.module)
registry[api] = {a.provider_type: a for a in module.available_providers()}
except ImportError as e:
raise ImportError(
f"Failed to import external API module {name}. Is the external API package installed? {e}"
) from e
# Check if config has the external_providers_dir attribute
if config and hasattr(config, "external_providers_dir") and config.external_providers_dir:
external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
@ -175,11 +189,9 @@ def get_provider_registry(
else:
spec = _load_inline_provider_spec(spec_data, api, provider_name)
provider_type_key = f"inline::{provider_name}"
logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
if provider_type_key in ret[api]:
if provider_type_key in registry[api]:
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
ret[api][provider_type_key] = spec
registry[api][provider_type_key] = spec
logger.info(f"Successfully loaded external provider {provider_type_key}")
except yaml.YAMLError as yaml_err:
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
@ -187,4 +199,4 @@ def get_provider_registry(
except Exception as e:
logger.error(f"Failed to load provider spec from {spec_path}: {e}")
raise e
return ret
return registry

View file

@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import yaml
from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="core")
def load_external_apis(config=None) -> dict[Api, ExternalApiSpec]:
"""Load external API specifications from the configured directory.
Args:
config: StackRunConfig containing the external APIs directory path
Returns:
A dictionary mapping API names to their specifications
"""
if not config:
return {}
if not hasattr(config, "external_apis_dir"):
return {}
if not config.external_apis_dir:
return {}
external_apis_dir = config.external_apis_dir.expanduser().resolve()
if not external_apis_dir.is_dir():
logger.error(f"External APIs directory is not a directory: {external_apis_dir}")
return {}
logger.info(f"Loading external APIs from {external_apis_dir}")
external_apis: dict[Api, ExternalApiSpec] = {}
# Look for YAML files in the external APIs directory
for yaml_path in external_apis_dir.glob("*.yaml"):
try:
with open(yaml_path) as f:
spec_data = yaml.safe_load(f)
spec = ExternalApiSpec(**spec_data)
api = Api.add(spec.name)
logger.info(f"Loaded external API spec for {spec.name} from {yaml_path}")
external_apis[api] = spec
except yaml.YAMLError as yaml_err:
logger.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}")
raise yaml_err
except Exception as e:
logger.error(f"Failed to load external API spec from {yaml_path}: {e}")
raise e
return external_apis

View file

@ -16,6 +16,7 @@ from llama_stack.apis.inspect import (
VersionInfo,
)
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.external import load_external_apis
from llama_stack.distribution.server.routes import get_all_api_routes
from llama_stack.providers.datatypes import HealthStatus
@ -42,7 +43,8 @@ class DistributionInspectImpl(Inspect):
run_config: StackRunConfig = self.config.run_config
ret = []
all_endpoints = get_all_api_routes()
external_apis = load_external_apis(run_config)
all_endpoints = get_all_api_routes(external_apis)
for api, endpoints in all_endpoints.items():
# Always include provider and inspect APIs, filter others based on run config
if api.value in ["providers", "inspect"]:

View file

@ -11,6 +11,7 @@ from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.datatypes import ExternalApiSpec
from llama_stack.apis.eval import Eval
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InferenceProvider
@ -35,6 +36,7 @@ from llama_stack.distribution.datatypes import (
StackRunConfig,
)
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.external import load_external_apis
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
@ -59,8 +61,16 @@ class InvalidProviderError(Exception):
pass
def api_protocol_map() -> dict[Api, Any]:
return {
def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) -> dict[Api, Any]:
"""Get a mapping of API types to their protocol classes.
Args:
external_apis: Optional dictionary of external API specifications
Returns:
Dictionary mapping API types to their protocol classes
"""
protocols = {
Api.providers: ProvidersAPI,
Api.agents: Agents,
Api.inference: Inference,
@ -83,10 +93,23 @@ def api_protocol_map() -> dict[Api, Any]:
Api.files: Files,
}
if external_apis:
for api, api_spec in external_apis.items():
try:
module = importlib.import_module(api_spec.module)
api_class = getattr(module, api_spec.protocol)
def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
protocols[api] = api_class
except (ImportError, AttributeError) as e:
logger.warning(f"Failed to load external API {api_spec.name}: {e}")
return protocols
def api_protocol_map_for_compliance_check(config: Any) -> dict[Api, Any]:
external_apis = load_external_apis(config)
return {
**api_protocol_map(),
**api_protocol_map(external_apis),
Api.inference: InferenceProvider,
}
@ -250,7 +273,7 @@ async def instantiate_providers(
dist_registry: DistributionRegistry,
run_config: StackRunConfig,
policy: list[AccessRule],
) -> dict:
) -> dict[Api, Any]:
"""Instantiates providers asynchronously while managing dependencies."""
impls: dict[Api, Any] = {}
inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
@ -356,7 +379,7 @@ async def instantiate_provider(
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
protocols = api_protocol_map_for_compliance_check()
protocols = api_protocol_map_for_compliance_check(run_config)
additional_protocols = additional_protocols_map()
# TODO: check compliance for special tool groups
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol

View file

@ -12,10 +12,9 @@ from typing import Any
from aiohttp import hdrs
from starlette.routing import Route
from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from llama_stack.distribution.resolver import api_protocol_map
from llama_stack.providers.datatypes import Api
EndpointFunc = Callable[..., Any]
PathParams = dict[str, str]
@ -31,10 +30,13 @@ def toolgroup_protocol_map():
}
def get_all_api_routes() -> dict[Api, list[Route]]:
def get_all_api_routes(external_apis: dict[Api, ExternalApiSpec] | None = None) -> dict[Api, list[Route]]:
apis = {}
protocols = api_protocol_map()
# Lazy import to avoid circular dependency
from llama_stack.distribution.resolver import api_protocol_map
protocols = api_protocol_map(external_apis)
toolgroup_protocols = toolgroup_protocol_map()
for api, protocol in protocols.items():
routes = []
@ -73,8 +75,8 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
return apis
def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
routes = get_all_api_routes()
def initialize_route_impls(impls, external_apis: dict[Api, ExternalApiSpec] | None = None) -> RouteImpls:
routes = get_all_api_routes(external_apis)
route_impls: RouteImpls = {}
def _convert_path_to_regex(path: str) -> str:

View file

@ -33,6 +33,7 @@ from pydantic import BaseModel, ValidationError
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.external import ExternalApiSpec, load_external_apis
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.server.routes import (
@ -270,9 +271,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
class TracingMiddleware:
def __init__(self, app, impls):
def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
self.app = app
self.impls = impls
self.external_apis = external_apis
# FastAPI built-in paths that should bypass custom routing
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
@ -289,7 +291,7 @@ class TracingMiddleware:
return await self.app(scope, receive, send)
if not hasattr(self, "route_impls"):
self.route_impls = initialize_route_impls(self.impls)
self.route_impls = initialize_route_impls(self.impls, self.external_apis)
try:
_, _, trace_path = find_matching_route(scope.get("method", hdrs.METH_GET), path, self.route_impls)
@ -493,7 +495,9 @@ def main(args: argparse.Namespace | None = None):
else:
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
all_routes = get_all_api_routes()
# Load external APIs if configured
external_apis = load_external_apis(config)
all_routes = get_all_api_routes(external_apis)
if config.apis:
apis_to_serve = set(config.apis)
@ -512,7 +516,10 @@ def main(args: argparse.Namespace | None = None):
api = Api(api_str)
routes = all_routes[api]
impl = impls[api]
try:
impl = impls[api]
except KeyError as e:
raise ValueError(f"Could not find provider implementation for {api} API") from e
for route in routes:
if not hasattr(impl, route.name):
@ -543,7 +550,7 @@ def main(args: argparse.Namespace | None = None):
app.exception_handler(Exception)(global_exception_handler)
app.__llama_stack_impls__ = impls
app.add_middleware(TracingMiddleware, impls=impls)
app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
import uvicorn