diff --git a/client-sdks/stainless/openapi.yml b/client-sdks/stainless/openapi.yml index ff86e30e1..be941f652 100644 --- a/client-sdks/stainless/openapi.yml +++ b/client-sdks/stainless/openapi.yml @@ -37,7 +37,7 @@ paths: description: Default Response tags: - Batches - summary: List Batches + summary: List all batches for the current user. description: List all batches for the current user. operationId: list_batches_v1_batches_get parameters: @@ -76,9 +76,11 @@ paths: default: $ref: '#/components/responses/DefaultError' description: Default Response + '409': + description: 'Conflict: The idempotency key was previously used with different parameters.' tags: - Batches - summary: Create Batch + summary: Create a new batch for processing multiple API requests. description: Create a new batch for processing multiple API requests. operationId: create_batch_v1_batches_post requestBody: @@ -97,20 +99,20 @@ paths: schema: $ref: '#/components/schemas/Batch' '400': - description: Bad Request $ref: '#/components/responses/BadRequest400' + description: Bad Request '429': - description: Too Many Requests $ref: '#/components/responses/TooManyRequests429' + description: Too Many Requests '500': - description: Internal Server Error $ref: '#/components/responses/InternalServerError500' + description: Internal Server Error default: - description: Default Response $ref: '#/components/responses/DefaultError' + description: Default Response tags: - Batches - summary: Retrieve Batch + summary: Retrieve information about a specific batch. description: Retrieve information about a specific batch. operationId: retrieve_batch_v1_batches__batch_id__get parameters: @@ -119,7 +121,7 @@ paths: required: true schema: type: string - description: 'Path parameter: batch_id' + title: Batch Id /v1/batches/{batch_id}/cancel: post: responses: @@ -130,20 +132,20 @@ paths: schema: $ref: '#/components/schemas/Batch' '400': - description: Bad Request $ref: '#/components/responses/BadRequest400' + description: Bad Request '429': - description: Too Many Requests $ref: '#/components/responses/TooManyRequests429' + description: Too Many Requests '500': - description: Internal Server Error $ref: '#/components/responses/InternalServerError500' + description: Internal Server Error default: - description: Default Response $ref: '#/components/responses/DefaultError' + description: Default Response tags: - Batches - summary: Cancel Batch + summary: Cancel a batch that is in progress. description: Cancel a batch that is in progress. operationId: cancel_batch_v1_batches__batch_id__cancel_post parameters: @@ -152,7 +154,7 @@ paths: required: true schema: type: string - description: 'Path parameter: batch_id' + title: Batch Id /v1/chat/completions: get: responses: @@ -3950,29 +3952,35 @@ components: input_file_id: type: string title: Input File Id + description: The ID of an uploaded file containing requests for the batch. endpoint: type: string title: Endpoint + description: The endpoint to be used for all requests in the batch. completion_window: type: string const: 24h title: Completion Window + description: The time window within which the batch should be processed. metadata: anyOf: - additionalProperties: type: string type: object - type: 'null' + description: Optional metadata for the batch. idempotency_key: anyOf: - type: string - type: 'null' + description: Optional idempotency key. When provided, enables idempotent behavior. type: object required: - input_file_id - endpoint - completion_window title: CreateBatchRequest + description: Request model for creating a batch. Batch: properties: id: diff --git a/docs/static/deprecated-llama-stack-spec.yaml b/docs/static/deprecated-llama-stack-spec.yaml index 3bc06d7d7..94b1a69a7 100644 --- a/docs/static/deprecated-llama-stack-spec.yaml +++ b/docs/static/deprecated-llama-stack-spec.yaml @@ -793,29 +793,35 @@ components: input_file_id: type: string title: Input File Id + description: The ID of an uploaded file containing requests for the batch. endpoint: type: string title: Endpoint + description: The endpoint to be used for all requests in the batch. completion_window: type: string const: 24h title: Completion Window + description: The time window within which the batch should be processed. metadata: anyOf: - additionalProperties: type: string type: object - type: 'null' + description: Optional metadata for the batch. idempotency_key: anyOf: - type: string - type: 'null' + description: Optional idempotency key. When provided, enables idempotent behavior. type: object required: - input_file_id - endpoint - completion_window title: CreateBatchRequest + description: Request model for creating a batch. Batch: properties: id: diff --git a/docs/static/experimental-llama-stack-spec.yaml b/docs/static/experimental-llama-stack-spec.yaml index 2b36ebf47..dfd354544 100644 --- a/docs/static/experimental-llama-stack-spec.yaml +++ b/docs/static/experimental-llama-stack-spec.yaml @@ -688,6 +688,40 @@ components: - data title: ListBatchesResponse description: Response containing a list of batch objects. + CreateBatchRequest: + properties: + input_file_id: + type: string + title: Input File Id + description: The ID of an uploaded file containing requests for the batch. + endpoint: + type: string + title: Endpoint + description: The endpoint to be used for all requests in the batch. + completion_window: + type: string + const: 24h + title: Completion Window + description: The time window within which the batch should be processed. + metadata: + anyOf: + - additionalProperties: + type: string + type: object + - type: 'null' + description: Optional metadata for the batch. + idempotency_key: + anyOf: + - type: string + - type: 'null' + description: Optional idempotency key. When provided, enables idempotent behavior. + type: object + required: + - input_file_id + - endpoint + - completion_window + title: CreateBatchRequest + description: Request model for creating a batch. Batch: properties: id: diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml index a12ac342f..a736fc8f9 100644 --- a/docs/static/llama-stack-spec.yaml +++ b/docs/static/llama-stack-spec.yaml @@ -35,7 +35,7 @@ paths: description: Default Response tags: - Batches - summary: List Batches + summary: List all batches for the current user. description: List all batches for the current user. operationId: list_batches_v1_batches_get parameters: @@ -74,9 +74,11 @@ paths: default: $ref: '#/components/responses/DefaultError' description: Default Response + '409': + description: 'Conflict: The idempotency key was previously used with different parameters.' tags: - Batches - summary: Create Batch + summary: Create a new batch for processing multiple API requests. description: Create a new batch for processing multiple API requests. operationId: create_batch_v1_batches_post requestBody: @@ -95,20 +97,20 @@ paths: schema: $ref: '#/components/schemas/Batch' '400': - description: Bad Request $ref: '#/components/responses/BadRequest400' + description: Bad Request '429': - description: Too Many Requests $ref: '#/components/responses/TooManyRequests429' + description: Too Many Requests '500': - description: Internal Server Error $ref: '#/components/responses/InternalServerError500' + description: Internal Server Error default: - description: Default Response $ref: '#/components/responses/DefaultError' + description: Default Response tags: - Batches - summary: Retrieve Batch + summary: Retrieve information about a specific batch. description: Retrieve information about a specific batch. operationId: retrieve_batch_v1_batches__batch_id__get parameters: @@ -117,7 +119,7 @@ paths: required: true schema: type: string - description: 'Path parameter: batch_id' + title: Batch Id /v1/batches/{batch_id}/cancel: post: responses: @@ -128,20 +130,20 @@ paths: schema: $ref: '#/components/schemas/Batch' '400': - description: Bad Request $ref: '#/components/responses/BadRequest400' + description: Bad Request '429': - description: Too Many Requests $ref: '#/components/responses/TooManyRequests429' + description: Too Many Requests '500': - description: Internal Server Error $ref: '#/components/responses/InternalServerError500' + description: Internal Server Error default: - description: Default Response $ref: '#/components/responses/DefaultError' + description: Default Response tags: - Batches - summary: Cancel Batch + summary: Cancel a batch that is in progress. description: Cancel a batch that is in progress. operationId: cancel_batch_v1_batches__batch_id__cancel_post parameters: @@ -150,7 +152,7 @@ paths: required: true schema: type: string - description: 'Path parameter: batch_id' + title: Batch Id /v1/chat/completions: get: responses: @@ -2971,29 +2973,35 @@ components: input_file_id: type: string title: Input File Id + description: The ID of an uploaded file containing requests for the batch. endpoint: type: string title: Endpoint + description: The endpoint to be used for all requests in the batch. completion_window: type: string const: 24h title: Completion Window + description: The time window within which the batch should be processed. metadata: anyOf: - additionalProperties: type: string type: object - type: 'null' + description: Optional metadata for the batch. idempotency_key: anyOf: - type: string - type: 'null' + description: Optional idempotency key. When provided, enables idempotent behavior. type: object required: - input_file_id - endpoint - completion_window title: CreateBatchRequest + description: Request model for creating a batch. Batch: properties: id: diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index ff86e30e1..be941f652 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -37,7 +37,7 @@ paths: description: Default Response tags: - Batches - summary: List Batches + summary: List all batches for the current user. description: List all batches for the current user. operationId: list_batches_v1_batches_get parameters: @@ -76,9 +76,11 @@ paths: default: $ref: '#/components/responses/DefaultError' description: Default Response + '409': + description: 'Conflict: The idempotency key was previously used with different parameters.' tags: - Batches - summary: Create Batch + summary: Create a new batch for processing multiple API requests. description: Create a new batch for processing multiple API requests. operationId: create_batch_v1_batches_post requestBody: @@ -97,20 +99,20 @@ paths: schema: $ref: '#/components/schemas/Batch' '400': - description: Bad Request $ref: '#/components/responses/BadRequest400' + description: Bad Request '429': - description: Too Many Requests $ref: '#/components/responses/TooManyRequests429' + description: Too Many Requests '500': - description: Internal Server Error $ref: '#/components/responses/InternalServerError500' + description: Internal Server Error default: - description: Default Response $ref: '#/components/responses/DefaultError' + description: Default Response tags: - Batches - summary: Retrieve Batch + summary: Retrieve information about a specific batch. description: Retrieve information about a specific batch. operationId: retrieve_batch_v1_batches__batch_id__get parameters: @@ -119,7 +121,7 @@ paths: required: true schema: type: string - description: 'Path parameter: batch_id' + title: Batch Id /v1/batches/{batch_id}/cancel: post: responses: @@ -130,20 +132,20 @@ paths: schema: $ref: '#/components/schemas/Batch' '400': - description: Bad Request $ref: '#/components/responses/BadRequest400' + description: Bad Request '429': - description: Too Many Requests $ref: '#/components/responses/TooManyRequests429' + description: Too Many Requests '500': - description: Internal Server Error $ref: '#/components/responses/InternalServerError500' + description: Internal Server Error default: - description: Default Response $ref: '#/components/responses/DefaultError' + description: Default Response tags: - Batches - summary: Cancel Batch + summary: Cancel a batch that is in progress. description: Cancel a batch that is in progress. operationId: cancel_batch_v1_batches__batch_id__cancel_post parameters: @@ -152,7 +154,7 @@ paths: required: true schema: type: string - description: 'Path parameter: batch_id' + title: Batch Id /v1/chat/completions: get: responses: @@ -3950,29 +3952,35 @@ components: input_file_id: type: string title: Input File Id + description: The ID of an uploaded file containing requests for the batch. endpoint: type: string title: Endpoint + description: The endpoint to be used for all requests in the batch. completion_window: type: string const: 24h title: Completion Window + description: The time window within which the batch should be processed. metadata: anyOf: - additionalProperties: type: string type: object - type: 'null' + description: Optional metadata for the batch. idempotency_key: anyOf: - type: string - type: 'null' + description: Optional idempotency key. When provided, enables idempotent behavior. type: object required: - input_file_id - endpoint - completion_window title: CreateBatchRequest + description: Request model for creating a batch. Batch: properties: id: diff --git a/scripts/openapi_generator/app.py b/scripts/openapi_generator/app.py index d972889cd..48afc157d 100644 --- a/scripts/openapi_generator/app.py +++ b/scripts/openapi_generator/app.py @@ -64,7 +64,8 @@ def _get_protocol_method(api: Api, method_name: str) -> Any | None: def create_llama_stack_app() -> FastAPI: """ Create a FastAPI app that represents the Llama Stack API. - This uses the existing route discovery system to automatically find all routes. + This uses both router-based routes (for migrated APIs) and the existing + route discovery system for legacy webmethod-based routes. """ app = FastAPI( title="Llama Stack API", @@ -75,15 +76,42 @@ def create_llama_stack_app() -> FastAPI: ], ) - # Get all API routes + # Import batches router to trigger router registration + try: + from llama_stack.core.server.routers import batches # noqa: F401 + except ImportError: + pass + + # Include routers for APIs that have them registered + from llama_stack.core.server.router_registry import create_router, has_router + + def dummy_impl_getter(api: Api) -> Any: + """Dummy implementation getter for OpenAPI generation.""" + return None + + # Get all APIs that might have routers + from llama_stack.core.resolver import api_protocol_map + + protocols = api_protocol_map() + for api in protocols.keys(): + if has_router(api): + router = create_router(api, dummy_impl_getter) + if router: + app.include_router(router) + + # Get all API routes (for legacy webmethod-based routes) from llama_stack.core.server.routes import get_all_api_routes api_routes = get_all_api_routes() - # Create FastAPI routes from the discovered routes + # Create FastAPI routes from the discovered routes (skip APIs that have routers) from . import endpoints for api, routes in api_routes.items(): + # Skip APIs that have routers - they're already included above + if has_router(api): + continue + for route, webmethod in routes: # Convert the route to a FastAPI endpoint endpoints._create_fastapi_endpoint(app, route, webmethod, api) diff --git a/src/llama_stack/core/inspect.py b/src/llama_stack/core/inspect.py index 272c9d1bc..be5a26c14 100644 --- a/src/llama_stack/core/inspect.py +++ b/src/llama_stack/core/inspect.py @@ -10,8 +10,10 @@ from pydantic import BaseModel from llama_stack.core.datatypes import StackRunConfig from llama_stack.core.external import load_external_apis +from llama_stack.core.server.router_registry import create_router, has_router from llama_stack.core.server.routes import get_all_api_routes from llama_stack_api import ( + Api, HealthInfo, HealthStatus, Inspect, @@ -57,34 +59,91 @@ class DistributionInspectImpl(Inspect): ret = [] external_apis = load_external_apis(run_config) all_endpoints = get_all_api_routes(external_apis) - for api, endpoints in all_endpoints.items(): - # Always include provider and inspect APIs, filter others based on run config + + # Helper function to get provider types for an API + def get_provider_types(api: Api) -> list[str]: if api.value in ["providers", "inspect"]: + return [] # These APIs don't have "real" providers they're internal to the stack + providers = run_config.providers.get(api.value, []) + return [p.provider_type for p in providers] if providers else [] + + # Process webmethod-based routes (legacy) + for api, endpoints in all_endpoints.items(): + # Skip APIs that have routers - they'll be processed separately + if has_router(api): + continue + + provider_types = get_provider_types(api) + # Always include provider and inspect APIs, filter others based on run config + if api.value in ["providers", "inspect"] or provider_types: ret.extend( [ RouteInfo( route=e.path, method=next(iter([m for m in e.methods if m != "HEAD"])), - provider_types=[], # These APIs don't have "real" providers - they're internal to the stack + provider_types=provider_types, ) for e, webmethod in endpoints if e.methods is not None and should_include_route(webmethod) ] ) + + # Helper function to determine if a router route should be included based on api_filter + def should_include_router_route(route, router_prefix: str | None) -> bool: + """Check if a router-based route should be included based on api_filter.""" + # Check deprecated status + route_deprecated = getattr(route, "deprecated", False) or False + + if api_filter is None: + # Default: only non-deprecated routes + return not route_deprecated + elif api_filter == "deprecated": + # Special filter: show deprecated routes regardless of their actual level + return route_deprecated else: - providers = run_config.providers.get(api.value, []) - if providers: # Only process if there are providers for this API - ret.extend( - [ - RouteInfo( - route=e.path, - method=next(iter([m for m in e.methods if m != "HEAD"])), - provider_types=[p.provider_type for p in providers], + # Filter by API level (non-deprecated routes only) + # Extract level from router prefix (e.g., "/v1" -> "v1") + if router_prefix: + prefix_level = router_prefix.lstrip("/") + return not route_deprecated and prefix_level == api_filter + return not route_deprecated + + # Process router-based routes + def dummy_impl_getter(api: Api) -> None: + """Dummy implementation getter for route inspection.""" + return None + + from llama_stack.core.resolver import api_protocol_map + + protocols = api_protocol_map(external_apis) + for api in protocols.keys(): + if not has_router(api): + continue + + router = create_router(api, dummy_impl_getter) + if not router: + continue + + provider_types = get_provider_types(api) + # Only include if there are providers (or it's a special API) + if api.value in ["providers", "inspect"] or provider_types: + router_prefix = getattr(router, "prefix", None) + for route in router.routes: + # Extract HTTP methods from the route + # FastAPI routes have methods as a set + if hasattr(route, "methods") and route.methods: + methods = {m for m in route.methods if m != "HEAD"} + if methods and should_include_router_route(route, router_prefix): + # FastAPI already combines router prefix with route path + path = route.path + + ret.append( + RouteInfo( + route=path, + method=next(iter(methods)), + provider_types=provider_types, + ) ) - for e, webmethod in endpoints - if e.methods is not None and should_include_route(webmethod) - ] - ) return ListRoutesResponse(data=ret) diff --git a/src/llama_stack/core/server/router_registry.py b/src/llama_stack/core/server/router_registry.py new file mode 100644 index 000000000..e149d1346 --- /dev/null +++ b/src/llama_stack/core/server/router_registry.py @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Router registry for FastAPI routers. + +This module provides a way to register FastAPI routers for APIs that have been +migrated to use explicit FastAPI routers instead of Protocol-based route discovery. +""" + +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +from fastapi import APIRouter + +if TYPE_CHECKING: + from llama_stack_api.datatypes import Api + +# Registry of router factory functions +# Each factory function takes a callable that returns the implementation for a given API +# and returns an APIRouter +# Use string keys to avoid circular imports +_router_factories: dict[str, Callable[[Callable[["Api"], Any]], APIRouter]] = {} + + +def register_router(api: "Api", router_factory: Callable[[Callable[["Api"], Any]], APIRouter]) -> None: + """Register a router factory for an API. + + Args: + api: The API enum value + router_factory: A function that takes an impl_getter function and returns an APIRouter + """ + _router_factories[api.value] = router_factory # type: ignore[attr-defined] + + +def has_router(api: "Api") -> bool: + """Check if an API has a registered router. + + Args: + api: The API enum value + + Returns: + True if a router factory is registered for this API + """ + return api.value in _router_factories # type: ignore[attr-defined] + + +def create_router(api: "Api", impl_getter: Callable[["Api"], Any]) -> APIRouter | None: + """Create a router for an API if one is registered. + + Args: + api: The API enum value + impl_getter: Function that returns the implementation for a given API + + Returns: + APIRouter if registered, None otherwise + """ + api_value = api.value # type: ignore[attr-defined] + if api_value not in _router_factories: + return None + + return _router_factories[api_value](impl_getter) diff --git a/src/llama_stack/core/server/router_utils.py b/src/llama_stack/core/server/router_utils.py new file mode 100644 index 000000000..1c508af76 --- /dev/null +++ b/src/llama_stack/core/server/router_utils.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Utilities for creating FastAPI routers with standard error responses.""" + +standard_responses = { + 400: {"$ref": "#/components/responses/BadRequest400"}, + 429: {"$ref": "#/components/responses/TooManyRequests429"}, + 500: {"$ref": "#/components/responses/InternalServerError500"}, + "default": {"$ref": "#/components/responses/DefaultError"}, +} diff --git a/src/llama_stack/core/server/routers/__init__.py b/src/llama_stack/core/server/routers/__init__.py new file mode 100644 index 000000000..213cb75c8 --- /dev/null +++ b/src/llama_stack/core/server/routers/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""FastAPI router implementations for server endpoints. + +This package contains FastAPI router implementations that define the HTTP +endpoints for each API. The API contracts (protocols and models) are defined +in llama_stack_api, while the server routing implementation lives here. +""" diff --git a/src/llama_stack/core/server/routers/batches.py b/src/llama_stack/core/server/routers/batches.py new file mode 100644 index 000000000..fb7f8ebfa --- /dev/null +++ b/src/llama_stack/core/server/routers/batches.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""FastAPI router for the Batches API. + +This module defines the FastAPI router for the Batches API using standard +FastAPI route decorators instead of Protocol-based route discovery. +""" + +from collections.abc import Callable +from typing import Annotated + +from fastapi import APIRouter, Body, Depends + +from llama_stack.core.server.router_registry import register_router +from llama_stack.core.server.router_utils import standard_responses +from llama_stack_api.batches import Batches, BatchObject, ListBatchesResponse +from llama_stack_api.batches.models import CreateBatchRequest +from llama_stack_api.datatypes import Api +from llama_stack_api.version import LLAMA_STACK_API_V1 + + +def create_batches_router(impl_getter: Callable[[Api], Batches]) -> APIRouter: + """Create a FastAPI router for the Batches API. + + Args: + impl_getter: Function that returns the Batches implementation for the batches API + + Returns: + APIRouter configured for the Batches API + """ + router = APIRouter( + prefix=f"/{LLAMA_STACK_API_V1}", + tags=["Batches"], + responses=standard_responses, + ) + + def get_batch_service() -> Batches: + """Dependency function to get the batch service implementation.""" + return impl_getter(Api.batches) + + @router.post( + "/batches", + response_model=BatchObject, + summary="Create a new batch for processing multiple API requests.", + description="Create a new batch for processing multiple API requests.", + responses={ + 200: {"description": "The created batch object."}, + 409: {"description": "Conflict: The idempotency key was previously used with different parameters."}, + }, + ) + async def create_batch( + request: Annotated[CreateBatchRequest, Body(...)], + svc: Annotated[Batches, Depends(get_batch_service)], + ) -> BatchObject: + """Create a new batch.""" + return await svc.create_batch( + input_file_id=request.input_file_id, + endpoint=request.endpoint, + completion_window=request.completion_window, + metadata=request.metadata, + idempotency_key=request.idempotency_key, + ) + + @router.get( + "/batches/{batch_id}", + response_model=BatchObject, + summary="Retrieve information about a specific batch.", + description="Retrieve information about a specific batch.", + responses={ + 200: {"description": "The batch object."}, + }, + ) + async def retrieve_batch( + batch_id: str, + svc: Annotated[Batches, Depends(get_batch_service)], + ) -> BatchObject: + """Retrieve information about a specific batch.""" + return await svc.retrieve_batch(batch_id) + + @router.post( + "/batches/{batch_id}/cancel", + response_model=BatchObject, + summary="Cancel a batch that is in progress.", + description="Cancel a batch that is in progress.", + responses={ + 200: {"description": "The updated batch object."}, + }, + ) + async def cancel_batch( + batch_id: str, + svc: Annotated[Batches, Depends(get_batch_service)], + ) -> BatchObject: + """Cancel a batch that is in progress.""" + return await svc.cancel_batch(batch_id) + + @router.get( + "/batches", + response_model=ListBatchesResponse, + summary="List all batches for the current user.", + description="List all batches for the current user.", + responses={ + 200: {"description": "A list of batch objects."}, + }, + ) + async def list_batches( + svc: Annotated[Batches, Depends(get_batch_service)], + after: str | None = None, + limit: int = 20, + ) -> ListBatchesResponse: + """List all batches for the current user.""" + return await svc.list_batches(after=after, limit=limit) + + return router + + +# Register the router factory +register_router(Api.batches, create_batches_router) diff --git a/src/llama_stack/core/server/routes.py b/src/llama_stack/core/server/routes.py index af5002565..25027267f 100644 --- a/src/llama_stack/core/server/routes.py +++ b/src/llama_stack/core/server/routes.py @@ -26,6 +26,18 @@ RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod] def get_all_api_routes( external_apis: dict[Api, ExternalApiSpec] | None = None, ) -> dict[Api, list[tuple[Route, WebMethod]]]: + """Get all API routes from webmethod-based protocols. + + This function only returns routes from APIs that use the legacy @webmethod + decorator system. For APIs that have been migrated to FastAPI routers, + use the router registry (router_registry.has_router() and router_registry.create_router()). + + Args: + external_apis: Optional dictionary of external API specifications + + Returns: + Dictionary mapping API to list of (Route, WebMethod) tuples + """ apis = {} protocols = api_protocol_map(external_apis) diff --git a/src/llama_stack/core/server/server.py b/src/llama_stack/core/server/server.py index 0d3513980..76f283f3a 100644 --- a/src/llama_stack/core/server/server.py +++ b/src/llama_stack/core/server/server.py @@ -44,6 +44,7 @@ from llama_stack.core.request_headers import ( request_provider_data_context, user_from_scope, ) +from llama_stack.core.server.router_registry import create_router, has_router from llama_stack.core.server.routes import get_all_api_routes from llama_stack.core.stack import ( Stack, @@ -87,7 +88,7 @@ def create_sse_event(data: Any) -> str: async def global_exception_handler(request: Request, exc: Exception): - traceback.print_exception(exc) + traceback.print_exception(type(exc), exc, exc.__traceback__) http_exc = translate_exception(exc) return JSONResponse(status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}}) @@ -448,6 +449,14 @@ def create_app() -> StackApp: external_apis = load_external_apis(config) all_routes = get_all_api_routes(external_apis) + # Import batches router to trigger router registration + # This ensures the router is registered before we try to use it + # We will make this code better once the migration is complete + try: + from llama_stack.core.server.routers import batches # noqa: F401 + except ImportError: + pass + if config.apis: apis_to_serve = set(config.apis) else: @@ -463,41 +472,68 @@ def create_app() -> StackApp: apis_to_serve.add("providers") apis_to_serve.add("prompts") apis_to_serve.add("conversations") - for api_str in apis_to_serve: - api = Api(api_str) - routes = all_routes[api] + def impl_getter(api: Api) -> Any: + """Get the implementation for a given API.""" try: - impl = impls[api] + return impls[api] except KeyError as e: raise ValueError(f"Could not find provider implementation for {api} API") from e - for route, _ in routes: - if not hasattr(impl, route.name): - # ideally this should be a typing violation already - raise ValueError(f"Could not find method {route.name} on {impl}!") + for api_str in apis_to_serve: + api = Api(api_str) - impl_method = getattr(impl, route.name) - # Filter out HEAD method since it's automatically handled by FastAPI for GET routes - available_methods = [m for m in route.methods if m != "HEAD"] - if not available_methods: - raise ValueError(f"No methods found for {route.name} on {impl}") - method = available_methods[0] - logger.debug(f"{method} {route.path}") - - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields") - getattr(app, method.lower())(route.path, response_model=None)( - create_dynamic_typed_route( - impl_method, - method.lower(), - route.path, - ) + if has_router(api): + router = create_router(api, impl_getter) + if router: + app.include_router(router) + logger.debug(f"Registered router for {api} API") + else: + logger.warning( + f"API '{api.value}' has a registered router factory but it returned None. Skipping this API." ) + else: + # Fall back to old webmethod-based route discovery until the migration is complete + routes = all_routes[api] + try: + impl = impls[api] + except KeyError as e: + raise ValueError(f"Could not find provider implementation for {api} API") from e + + for route, _ in routes: + if not hasattr(impl, route.name): + # ideally this should be a typing violation already + raise ValueError(f"Could not find method {route.name} on {impl}!") + + impl_method = getattr(impl, route.name) + # Filter out HEAD method since it's automatically handled by FastAPI for GET routes + available_methods = [m for m in route.methods if m != "HEAD"] + if not available_methods: + raise ValueError(f"No methods found for {route.name} on {impl}") + method = available_methods[0] + logger.debug(f"{method} {route.path}") + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields") + getattr(app, method.lower())(route.path, response_model=None)( + create_dynamic_typed_route( + impl_method, + method.lower(), + route.path, + ) + ) logger.debug(f"serving APIs: {apis_to_serve}") + # Register specific exception handlers before the generic Exception handler + # This prevents the re-raising behavior that causes connection resets app.exception_handler(RequestValidationError)(global_exception_handler) + app.exception_handler(ConflictError)(global_exception_handler) + app.exception_handler(ResourceNotFoundError)(global_exception_handler) + app.exception_handler(AuthenticationRequiredError)(global_exception_handler) + app.exception_handler(AccessDeniedError)(global_exception_handler) + app.exception_handler(BadRequestError)(global_exception_handler) + # Generic Exception handler should be last app.exception_handler(Exception)(global_exception_handler) if config.telemetry.enabled: diff --git a/src/llama_stack/core/server/tracing.py b/src/llama_stack/core/server/tracing.py index c4901d9b1..352badcaa 100644 --- a/src/llama_stack/core/server/tracing.py +++ b/src/llama_stack/core/server/tracing.py @@ -6,9 +6,11 @@ from aiohttp import hdrs from llama_stack.core.external import ExternalApiSpec +from llama_stack.core.server.router_registry import has_router from llama_stack.core.server.routes import find_matching_route, initialize_route_impls from llama_stack.core.telemetry.tracing import end_trace, start_trace from llama_stack.log import get_logger +from llama_stack_api.datatypes import Api logger = get_logger(name=__name__, category="core::server") @@ -21,6 +23,25 @@ class TracingMiddleware: # FastAPI built-in paths that should bypass custom routing self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static") + def _is_router_based_route(self, path: str) -> bool: + """Check if a path belongs to a router-based API. + + Router-based APIs use FastAPI routers instead of the old webmethod system. + We need to check if the path matches any router-based API prefix. + """ + # Extract API name from path (e.g., /v1/batches -> batches) + # Paths are typically /v1/{api_name} or /v1/{api_name}/... + parts = path.strip("/").split("/") + if len(parts) >= 2 and parts[0].startswith("v"): + api_name = parts[1] + try: + api = Api(api_name) + return has_router(api) + except (ValueError, KeyError): + # Not a known API or not router-based + return False + return False + async def __call__(self, scope, receive, send): if scope.get("type") == "lifespan": return await self.app(scope, receive, send) @@ -33,6 +54,44 @@ class TracingMiddleware: logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}") return await self.app(scope, receive, send) + # Check if this is a router-based route - if so, pass through to FastAPI + # Router-based routes are handled by FastAPI directly, so we skip the old route lookup + # but still need to set up tracing + is_router_based = self._is_router_based_route(path) + if is_router_based: + logger.debug(f"Router-based route detected: {path}, setting up tracing") + # Set up tracing for router-based routes + trace_attributes = {"__location__": "server", "raw_path": path} + + # Extract W3C trace context headers and store as trace attributes + headers = dict(scope.get("headers", [])) + traceparent = headers.get(b"traceparent", b"").decode() + if traceparent: + trace_attributes["traceparent"] = traceparent + tracestate = headers.get(b"tracestate", b"").decode() + if tracestate: + trace_attributes["tracestate"] = tracestate + + trace_context = await start_trace(path, trace_attributes) + + async def send_with_trace_id(message): + if message["type"] == "http.response.start": + headers = message.get("headers", []) + headers.append([b"x-trace-id", str(trace_context.trace_id).encode()]) + message["headers"] = headers + await send(message) + + try: + return await self.app(scope, receive, send_with_trace_id) + finally: + # Always end trace, even if exception occurred + # FastAPI's exception handler will handle the exception and send the response + # The exception will continue to propagate for logging, which is normal + try: + await end_trace() + except Exception: + logger.exception("Error ending trace") + if not hasattr(self, "route_impls"): self.route_impls = initialize_route_impls(self.impls, self.external_apis) diff --git a/src/llama_stack_api/batches.py b/src/llama_stack_api/batches/__init__.py similarity index 52% rename from src/llama_stack_api/batches.py rename to src/llama_stack_api/batches/__init__.py index 00c47d39f..636dd0c52 100644 --- a/src/llama_stack_api/batches.py +++ b/src/llama_stack_api/batches/__init__.py @@ -4,12 +4,17 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +"""Batches API protocol and models. + +This module contains the Batches protocol definition and related models. +The router implementation is in llama_stack.core.server.routers.batches. +""" + from typing import Literal, Protocol, runtime_checkable from pydantic import BaseModel, Field -from llama_stack_api.schema_utils import json_schema_type, webmethod -from llama_stack_api.version import LLAMA_STACK_API_V1 +from llama_stack_api.schema_utils import json_schema_type try: from openai.types import Batch as BatchObject @@ -43,7 +48,6 @@ class Batches(Protocol): Note: This API is currently under active development and may undergo changes. """ - @webmethod(route="/batches", method="POST", level=LLAMA_STACK_API_V1) async def create_batch( self, input_file_id: str, @@ -51,46 +55,17 @@ class Batches(Protocol): completion_window: Literal["24h"], metadata: dict[str, str] | None = None, idempotency_key: str | None = None, - ) -> BatchObject: - """Create a new batch for processing multiple API requests. + ) -> BatchObject: ... - :param input_file_id: The ID of an uploaded file containing requests for the batch. - :param endpoint: The endpoint to be used for all requests in the batch. - :param completion_window: The time window within which the batch should be processed. - :param metadata: Optional metadata for the batch. - :param idempotency_key: Optional idempotency key. When provided, enables idempotent behavior. - :returns: The created batch object. - """ - ... + async def retrieve_batch(self, batch_id: str) -> BatchObject: ... - @webmethod(route="/batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1) - async def retrieve_batch(self, batch_id: str) -> BatchObject: - """Retrieve information about a specific batch. + async def cancel_batch(self, batch_id: str) -> BatchObject: ... - :param batch_id: The ID of the batch to retrieve. - :returns: The batch object. - """ - ... - - @webmethod(route="/batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1) - async def cancel_batch(self, batch_id: str) -> BatchObject: - """Cancel a batch that is in progress. - - :param batch_id: The ID of the batch to cancel. - :returns: The updated batch object. - """ - ... - - @webmethod(route="/batches", method="GET", level=LLAMA_STACK_API_V1) async def list_batches( self, after: str | None = None, limit: int = 20, - ) -> ListBatchesResponse: - """List all batches for the current user. + ) -> ListBatchesResponse: ... - :param after: A cursor for pagination; returns batches after this batch ID. - :param limit: Number of batches to return (default 20, max 100). - :returns: A list of batch objects. - """ - ... + +__all__ = ["Batches", "BatchObject", "ListBatchesResponse"] diff --git a/src/llama_stack_api/batches/models.py b/src/llama_stack_api/batches/models.py new file mode 100644 index 000000000..22e024be2 --- /dev/null +++ b/src/llama_stack_api/batches/models.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Pydantic models for Batches API requests and responses. + +This module defines the request and response models for the Batches API +using Pydantic with Field descriptions for OpenAPI schema generation. +""" + +from typing import Literal + +from pydantic import BaseModel, Field + +from llama_stack_api.batches import BatchObject, ListBatchesResponse +from llama_stack_api.schema_utils import json_schema_type + + +@json_schema_type +class CreateBatchRequest(BaseModel): + """Request model for creating a batch.""" + + input_file_id: str = Field(..., description="The ID of an uploaded file containing requests for the batch.") + endpoint: str = Field(..., description="The endpoint to be used for all requests in the batch.") + completion_window: Literal["24h"] = Field( + ..., description="The time window within which the batch should be processed." + ) + metadata: dict[str, str] | None = Field(default=None, description="Optional metadata for the batch.") + idempotency_key: str | None = Field( + default=None, description="Optional idempotency key. When provided, enables idempotent behavior." + ) + + +# Re-export response models for convenience +__all__ = ["CreateBatchRequest", "BatchObject", "ListBatchesResponse"]