feat: Bring Your Own API (BYOA)

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-05-22 17:40:45 +02:00
parent cfee63bd0d
commit 9443cef577
No known key found for this signature in database
9 changed files with 607 additions and 28 deletions

View file

@ -33,6 +33,7 @@ from pydantic import BaseModel, ValidationError
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.external import ExternalApiSpec, load_external_apis
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.server.routes import (
@ -270,9 +271,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
class TracingMiddleware:
def __init__(self, app, impls):
def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
self.app = app
self.impls = impls
self.external_apis = external_apis
# FastAPI built-in paths that should bypass custom routing
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
@ -289,7 +291,7 @@ class TracingMiddleware:
return await self.app(scope, receive, send)
if not hasattr(self, "route_impls"):
self.route_impls = initialize_route_impls(self.impls)
self.route_impls = initialize_route_impls(self.impls, self.external_apis)
try:
_, _, trace_path = find_matching_route(scope.get("method", hdrs.METH_GET), path, self.route_impls)
@ -493,7 +495,9 @@ def main(args: argparse.Namespace | None = None):
else:
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
all_routes = get_all_api_routes()
# Load external APIs if configured
external_apis = load_external_apis(config)
all_routes = get_all_api_routes(external_apis)
if config.apis:
apis_to_serve = set(config.apis)
@ -512,7 +516,10 @@ def main(args: argparse.Namespace | None = None):
api = Api(api_str)
routes = all_routes[api]
impl = impls[api]
try:
impl = impls[api]
except KeyError as e:
raise ValueError(f"Could not find provider implementation for {api} API") from e
for route in routes:
if not hasattr(impl, route.name):
@ -543,7 +550,7 @@ def main(args: argparse.Namespace | None = None):
app.exception_handler(Exception)(global_exception_handler)
app.__llama_stack_impls__ = impls
app.add_middleware(TracingMiddleware, impls=impls)
app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
import uvicorn