feat: Implement FastAPI router system (#4191)

# What does this PR do?

This commit introduces a new FastAPI router-based system for defining
API endpoints, enabling a migration path away from the legacy @webmethod
decorator system. The implementation includes router infrastructure,
migration of the Batches API as the first example, and updates to
server, OpenAPI generation, and inspection systems to support both
routing approaches.

The router infrastructure consists of a router registry system that
allows APIs to register FastAPI router factories, which are then
automatically discovered and included in the server application.
Standard error responses are centralized in router_utils to ensure
consistent OpenAPI specification generation with proper $ref references
to component responses.

The Batches API has been migrated to demonstrate the new pattern. The
protocol definition and models remain in llama_stack_api/batches,
maintaining clear separation between API contracts and server
implementation. The FastAPI router implementation lives in
llama_stack/core/server/routers/batches, following the established
pattern where API contracts are defined in llama_stack_api and server
routing logic lives in
llama_stack/core/server.

The server now checks for registered routers before falling back to the
legacy webmethod-based route discovery, ensuring backward compatibility
during the migration period. The OpenAPI generator has been updated to
handle both router-based and webmethod-based routes, correctly
extracting metadata from FastAPI route decorators and Pydantic Field
descriptions. The inspect endpoint now includes routes from both
systems, with proper filtering for deprecated routes and API levels.

Response descriptions are now explicitly defined in router decorators,
ensuring the generated OpenAPI specification matches the previous
format. Error responses use $ref references to component responses
(BadRequest400, TooManyRequests429, etc.) as required by the
specification. This is neat and will allow us to remove a lot of boiler
plate code from our generator once the migration is done.

This implementation provides a foundation for incrementally migrating
other APIs to the router system while maintaining full backward
compatibility with existing webmethod-based APIs.

Closes: https://github.com/llamastack/llama-stack/issues/4188

## Test Plan

CI, the server should start, same routes should be visible.

```
curl http://localhost:8321/v1/inspect/routes | jq '.data[] | select(.route | contains("batches"))'
```

Also:

```
 uv run pytest tests/integration/batches/ -vv --stack-config=http://localhost:8321
================================================== test session starts ==================================================
platform darwin -- Python 3.12.8, pytest-8.4.2, pluggy-1.6.0 -- /Users/leseb/Documents/AI/llama-stack/.venv/bin/python3
cachedir: .pytest_cache
metadata: {'Python': '3.12.8', 'Platform': 'macOS-26.0.1-arm64-arm-64bit', 'Packages': {'pytest': '8.4.2', 'pluggy': '1.6.0'}, 'Plugins': {'anyio': '4.9.0', 'html': '4.1.1', 'socket': '0.7.0', 'asyncio': '1.1.0', 'json-report': '1.5.0', 'timeout': '2.4.0', 'metadata': '3.1.1', 'cov': '6.2.1', 'nbval': '0.11.0'}}
rootdir: /Users/leseb/Documents/AI/llama-stack
configfile: pyproject.toml
plugins: anyio-4.9.0, html-4.1.1, socket-0.7.0, asyncio-1.1.0, json-report-1.5.0, timeout-2.4.0, metadata-3.1.1, cov-6.2.1, nbval-0.11.0
asyncio: mode=Mode.AUTO, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collected 24 items                                                                                                      

tests/integration/batches/test_batches.py::TestBatchesIntegration::test_batch_creation_and_retrieval[None] SKIPPED [  4%]
tests/integration/batches/test_batches.py::TestBatchesIntegration::test_batch_listing[None] SKIPPED               [  8%]
tests/integration/batches/test_batches.py::TestBatchesIntegration::test_batch_immediate_cancellation[None] SKIPPED [ 12%]
tests/integration/batches/test_batches.py::TestBatchesIntegration::test_batch_e2e_chat_completions[None] SKIPPED  [ 16%]
tests/integration/batches/test_batches.py::TestBatchesIntegration::test_batch_e2e_completions[None] SKIPPED       [ 20%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_invalid_endpoint[None] SKIPPED [ 25%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_cancel_completed[None] SKIPPED [ 29%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_missing_required_fields[None] SKIPPED [ 33%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_invalid_completion_window[None] SKIPPED [ 37%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_streaming_not_supported[None] SKIPPED [ 41%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_mixed_streaming_requests[None] SKIPPED [ 45%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_endpoint_mismatch[None] SKIPPED [ 50%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_missing_required_body_fields[None] SKIPPED [ 54%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_invalid_metadata_types[None] SKIPPED [ 58%]
tests/integration/batches/test_batches.py::TestBatchesIntegration::test_batch_e2e_embeddings[None] SKIPPED        [ 62%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_nonexistent_file_id PASSED [ 66%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_malformed_jsonl PASSED     [ 70%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_file_malformed_batch_file[empty] XFAIL [ 75%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_file_malformed_batch_file[malformed] XFAIL [ 79%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_retrieve_nonexistent PASSED [ 83%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_cancel_nonexistent PASSED  [ 87%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_error_handling_invalid_model PASSED [ 91%]
tests/integration/batches/test_batches_idempotency.py::TestBatchesIdempotencyIntegration::test_idempotent_batch_creation_successful PASSED [ 95%]
tests/integration/batches/test_batches_idempotency.py::TestBatchesIdempotencyIntegration::test_idempotency_conflict_with_different_params PASSED [100%]

================================================= slowest 10 durations ==================================================
1.01s call     tests/integration/batches/test_batches_idempotency.py::TestBatchesIdempotencyIntegration::test_idempotent_batch_creation_successful
0.21s call     tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_nonexistent_file_id
0.17s call     tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_malformed_jsonl
0.12s call     tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_error_handling_invalid_model
0.05s setup    tests/integration/batches/test_batches.py::TestBatchesIntegration::test_batch_creation_and_retrieval[None]
0.02s call     tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_file_malformed_batch_file[empty]
0.01s call     tests/integration/batches/test_batches_idempotency.py::TestBatchesIdempotencyIntegration::test_idempotency_conflict_with_different_params
0.01s call     tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_file_malformed_batch_file[malformed]
0.01s call     tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_retrieve_nonexistent
0.00s call     tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_cancel_nonexistent
======================================= 7 passed, 15 skipped, 2 xfailed in 1.78s ========================================
```

---------

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-12-03 12:25:54 +01:00 committed by GitHub
parent 4237eb4aaa
commit 7f43051a63
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 1095 additions and 248 deletions

View file

@ -10,8 +10,14 @@ from pydantic import BaseModel
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.external import load_external_apis
from llama_stack.core.server.fastapi_router_registry import (
_ROUTER_FACTORIES,
build_fastapi_router,
get_router_routes,
)
from llama_stack.core.server.routes import get_all_api_routes
from llama_stack_api import (
Api,
HealthInfo,
HealthStatus,
Inspect,
@ -43,6 +49,7 @@ class DistributionInspectImpl(Inspect):
run_config: StackRunConfig = self.config.run_config
# Helper function to determine if a route should be included based on api_filter
# TODO: remove this once we've migrated all APIs to FastAPI routers
def should_include_route(webmethod) -> bool:
if api_filter is None:
# Default: only non-deprecated APIs
@ -54,10 +61,62 @@ class DistributionInspectImpl(Inspect):
# Filter by API level (non-deprecated routes only)
return not webmethod.deprecated and webmethod.level == api_filter
# Helper function to get provider types for an API
def _get_provider_types(api: Api) -> list[str]:
if api.value in ["providers", "inspect"]:
return [] # These APIs don't have "real" providers they're internal to the stack
providers = run_config.providers.get(api.value, [])
return [p.provider_type for p in providers] if providers else []
# Helper function to determine if a router route should be included based on api_filter
def _should_include_router_route(route, router_prefix: str | None) -> bool:
"""Check if a router-based route should be included based on api_filter."""
# Check deprecated status
route_deprecated = getattr(route, "deprecated", False) or False
if api_filter is None:
# Default: only non-deprecated routes
return not route_deprecated
elif api_filter == "deprecated":
# Special filter: show deprecated routes regardless of their actual level
return route_deprecated
else:
# Filter by API level (non-deprecated routes only)
# Extract level from router prefix (e.g., "/v1" -> "v1")
if router_prefix:
prefix_level = router_prefix.lstrip("/")
return not route_deprecated and prefix_level == api_filter
return not route_deprecated
ret = []
external_apis = load_external_apis(run_config)
all_endpoints = get_all_api_routes(external_apis)
# Process routes from APIs with FastAPI routers
for api_name in _ROUTER_FACTORIES.keys():
api = Api(api_name)
router = build_fastapi_router(api, None) # we don't need the impl here, just the routes
if router:
router_routes = get_router_routes(router)
for route in router_routes:
if _should_include_router_route(route, router.prefix):
if route.methods is not None:
available_methods = [m for m in route.methods if m != "HEAD"]
if available_methods:
ret.append(
RouteInfo(
route=route.path,
method=available_methods[0],
provider_types=_get_provider_types(api),
)
)
# Process routes from legacy webmethod-based APIs
for api, endpoints in all_endpoints.items():
# Skip APIs that have routers (already processed above)
if api.value in _ROUTER_FACTORIES:
continue
# Always include provider and inspect APIs, filter others based on run config
if api.value in ["providers", "inspect"]:
ret.extend(

View file

@ -0,0 +1,97 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""Router utilities for FastAPI routers.
This module provides utilities to create FastAPI routers from API packages.
APIs with routers are explicitly listed here.
"""
from collections.abc import Callable
from typing import Any, cast
from fastapi import APIRouter
from fastapi.routing import APIRoute
from starlette.routing import Route
from llama_stack_api import batches
# Router factories for APIs that have FastAPI routers
# Add new APIs here as they are migrated to the router system
from llama_stack_api.datatypes import Api
_ROUTER_FACTORIES: dict[str, Callable[[Any], APIRouter]] = {
"batches": batches.fastapi_routes.create_router,
}
def has_router(api: "Api") -> bool:
"""Check if an API has a router factory.
Args:
api: The API enum value
Returns:
True if the API has a router factory, False otherwise
"""
return api.value in _ROUTER_FACTORIES
def build_fastapi_router(api: "Api", impl: Any) -> APIRouter | None:
"""Build a router for an API by combining its router factory with the implementation.
Args:
api: The API enum value
impl: The implementation instance for the API
Returns:
APIRouter if the API has a router factory, None otherwise
"""
router_factory = _ROUTER_FACTORIES.get(api.value)
if router_factory is None:
return None
# cast is safe here: all router factories in API packages are required to return APIRouter.
# If a router factory returns the wrong type, it will fail at runtime when
# app.include_router(router) is called
return cast(APIRouter, router_factory(impl))
def get_router_routes(router: APIRouter) -> list[Route]:
"""Extract routes from a FastAPI router.
Args:
router: The FastAPI router to extract routes from
Returns:
List of Route objects from the router
"""
routes = []
for route in router.routes:
# FastAPI routers use APIRoute objects, which have path and methods attributes
if isinstance(route, APIRoute):
# Combine router prefix with route path
routes.append(
Route(
path=route.path,
methods=route.methods,
name=route.name,
endpoint=route.endpoint,
)
)
elif isinstance(route, Route):
# Fallback for regular Starlette Route objects
routes.append(
Route(
path=route.path,
methods=route.methods,
name=route.name,
endpoint=route.endpoint,
)
)
return routes

View file

@ -26,6 +26,18 @@ RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod]
def get_all_api_routes(
external_apis: dict[Api, ExternalApiSpec] | None = None,
) -> dict[Api, list[tuple[Route, WebMethod]]]:
"""Get all API routes from webmethod-based protocols.
This function only returns routes from APIs that use the legacy @webmethod
decorator system. For APIs that have been migrated to FastAPI routers,
use the router registry (fastapi_router_registry.has_router() and fastapi_router_registry.build_fastapi_router()).
Args:
external_apis: Optional dictionary of external API specifications
Returns:
Dictionary mapping API to list of (Route, WebMethod) tuples
"""
apis = {}
protocols = api_protocol_map(external_apis)

View file

@ -44,6 +44,7 @@ from llama_stack.core.request_headers import (
request_provider_data_context,
user_from_scope,
)
from llama_stack.core.server.fastapi_router_registry import build_fastapi_router
from llama_stack.core.server.routes import get_all_api_routes
from llama_stack.core.stack import (
Stack,
@ -84,7 +85,7 @@ def create_sse_event(data: Any) -> str:
async def global_exception_handler(request: Request, exc: Exception):
traceback.print_exception(exc)
traceback.print_exception(type(exc), exc, exc.__traceback__)
http_exc = translate_exception(exc)
return JSONResponse(status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}})
@ -454,15 +455,22 @@ def create_app() -> StackApp:
apis_to_serve.add("providers")
apis_to_serve.add("prompts")
apis_to_serve.add("conversations")
for api_str in apis_to_serve:
api = Api(api_str)
routes = all_routes[api]
try:
impl = impls[api]
except KeyError as e:
raise ValueError(f"Could not find provider implementation for {api} API") from e
# Try to discover and use a router factory from the API package
impl = impls[api]
router = build_fastapi_router(api, impl)
if router:
app.include_router(router)
logger.debug(f"Registered FastAPIrouter for {api} API")
continue
# Fall back to old webmethod-based route discovery until the migration is complete
impl = impls[api]
routes = all_routes[api]
for route, _ in routes:
if not hasattr(impl, route.name):
# ideally this should be a typing violation already
@ -488,7 +496,15 @@ def create_app() -> StackApp:
logger.debug(f"serving APIs: {apis_to_serve}")
# Register specific exception handlers before the generic Exception handler
# This prevents the re-raising behavior that causes connection resets
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(ConflictError)(global_exception_handler)
app.exception_handler(ResourceNotFoundError)(global_exception_handler)
app.exception_handler(AuthenticationRequiredError)(global_exception_handler)
app.exception_handler(AccessDeniedError)(global_exception_handler)
app.exception_handler(BadRequestError)(global_exception_handler)
# Generic Exception handler should be last
app.exception_handler(Exception)(global_exception_handler)
return app

View file

@ -11,7 +11,7 @@ import json
import time
import uuid
from io import BytesIO
from typing import Any, Literal
from typing import Any
from openai.types.batch import BatchError, Errors
from pydantic import BaseModel
@ -38,6 +38,12 @@ from llama_stack_api import (
OpenAIUserMessageParam,
ResourceNotFoundError,
)
from llama_stack_api.batches.models import (
CancelBatchRequest,
CreateBatchRequest,
ListBatchesRequest,
RetrieveBatchRequest,
)
from .config import ReferenceBatchesImplConfig
@ -140,11 +146,7 @@ class ReferenceBatchesImpl(Batches):
# TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions
async def create_batch(
self,
input_file_id: str,
endpoint: str,
completion_window: Literal["24h"],
metadata: dict[str, str] | None = None,
idempotency_key: str | None = None,
request: CreateBatchRequest,
) -> BatchObject:
"""
Create a new batch for processing multiple API requests.
@ -185,14 +187,14 @@ class ReferenceBatchesImpl(Batches):
# TODO: set expiration time for garbage collection
if endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
if request.endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
raise ValueError(
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint",
f"Invalid endpoint: {request.endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint",
)
if completion_window != "24h":
if request.completion_window != "24h":
raise ValueError(
f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window",
f"Invalid completion_window: {request.completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window",
)
batch_id = f"batch_{uuid.uuid4().hex[:16]}"
@ -200,22 +202,22 @@ class ReferenceBatchesImpl(Batches):
# For idempotent requests, use the idempotency key for the batch ID
# This ensures the same key always maps to the same batch ID,
# allowing us to detect parameter conflicts
if idempotency_key is not None:
hash_input = idempotency_key.encode("utf-8")
if request.idempotency_key is not None:
hash_input = request.idempotency_key.encode("utf-8")
hash_digest = hashlib.sha256(hash_input).hexdigest()[:24]
batch_id = f"batch_{hash_digest}"
try:
existing_batch = await self.retrieve_batch(batch_id)
existing_batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=batch_id))
if (
existing_batch.input_file_id != input_file_id
or existing_batch.endpoint != endpoint
or existing_batch.completion_window != completion_window
or existing_batch.metadata != metadata
existing_batch.input_file_id != request.input_file_id
or existing_batch.endpoint != request.endpoint
or existing_batch.completion_window != request.completion_window
or existing_batch.metadata != request.metadata
):
raise ConflictError(
f"Idempotency key '{idempotency_key}' was previously used with different parameters. "
f"Idempotency key '{request.idempotency_key}' was previously used with different parameters. "
"Either use a new idempotency key or ensure all parameters match the original request."
)
@ -230,12 +232,12 @@ class ReferenceBatchesImpl(Batches):
batch = BatchObject(
id=batch_id,
object="batch",
endpoint=endpoint,
input_file_id=input_file_id,
completion_window=completion_window,
endpoint=request.endpoint,
input_file_id=request.input_file_id,
completion_window=request.completion_window,
status="validating",
created_at=current_time,
metadata=metadata,
metadata=request.metadata,
)
await self.kvstore.set(f"batch:{batch_id}", batch.to_json())
@ -247,28 +249,27 @@ class ReferenceBatchesImpl(Batches):
return batch
async def cancel_batch(self, batch_id: str) -> BatchObject:
async def cancel_batch(self, request: CancelBatchRequest) -> BatchObject:
"""Cancel a batch that is in progress."""
batch = await self.retrieve_batch(batch_id)
batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=request.batch_id))
if batch.status in ["cancelled", "cancelling"]:
return batch
if batch.status in ["completed", "failed", "expired"]:
raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'")
raise ConflictError(f"Cannot cancel batch '{request.batch_id}' with status '{batch.status}'")
await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time()))
await self._update_batch(request.batch_id, status="cancelling", cancelling_at=int(time.time()))
if batch_id in self._processing_tasks:
self._processing_tasks[batch_id].cancel()
if request.batch_id in self._processing_tasks:
self._processing_tasks[request.batch_id].cancel()
# note: task removal and status="cancelled" handled in finally block of _process_batch
return await self.retrieve_batch(batch_id)
return await self.retrieve_batch(RetrieveBatchRequest(batch_id=request.batch_id))
async def list_batches(
self,
after: str | None = None,
limit: int = 20,
request: ListBatchesRequest,
) -> ListBatchesResponse:
"""
List all batches, eventually only for the current user.
@ -285,14 +286,14 @@ class ReferenceBatchesImpl(Batches):
batches.sort(key=lambda b: b.created_at, reverse=True)
start_idx = 0
if after:
if request.after:
for i, batch in enumerate(batches):
if batch.id == after:
if batch.id == request.after:
start_idx = i + 1
break
page_batches = batches[start_idx : start_idx + limit]
has_more = (start_idx + limit) < len(batches)
page_batches = batches[start_idx : start_idx + request.limit]
has_more = (start_idx + request.limit) < len(batches)
first_id = page_batches[0].id if page_batches else None
last_id = page_batches[-1].id if page_batches else None
@ -304,11 +305,11 @@ class ReferenceBatchesImpl(Batches):
has_more=has_more,
)
async def retrieve_batch(self, batch_id: str) -> BatchObject:
async def retrieve_batch(self, request: RetrieveBatchRequest) -> BatchObject:
"""Retrieve information about a specific batch."""
batch_data = await self.kvstore.get(f"batch:{batch_id}")
batch_data = await self.kvstore.get(f"batch:{request.batch_id}")
if not batch_data:
raise ResourceNotFoundError(batch_id, "Batch", "batches.list()")
raise ResourceNotFoundError(request.batch_id, "Batch", "batches.list()")
return BatchObject.model_validate_json(batch_data)
@ -316,7 +317,7 @@ class ReferenceBatchesImpl(Batches):
"""Update batch fields in kvstore."""
async with self._update_batch_lock:
try:
batch = await self.retrieve_batch(batch_id)
batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=batch_id))
# batch processing is async. once cancelling, only allow "cancelled" status updates
if batch.status == "cancelling" and updates.get("status") != "cancelled":
@ -536,7 +537,7 @@ class ReferenceBatchesImpl(Batches):
async def _process_batch_impl(self, batch_id: str) -> None:
"""Implementation of batch processing logic."""
errors: list[BatchError] = []
batch = await self.retrieve_batch(batch_id)
batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=batch_id))
errors, requests = await self._validate_input(batch)
if errors: