Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-06 10:37:22 +00:00)
feat: Implement FastAPI router system (#4191)
# What does this PR do?

This PR introduces a new FastAPI router-based system for defining API endpoints, enabling a migration path away from the legacy `@webmethod` decorator system. The implementation includes the router infrastructure, migration of the Batches API as the first example, and updates to the server, OpenAPI generation, and inspection systems to support both routing approaches.

The router infrastructure consists of a router registry that lets APIs register FastAPI router factories, which are then automatically discovered and included in the server application. Standard error responses are centralized in `router_utils` to ensure consistent OpenAPI specification generation with proper `$ref` references to component responses.

The Batches API has been migrated to demonstrate the new pattern. The protocol definition and models remain in `llama_stack_api/batches`, maintaining a clear separation between API contracts and server implementation. The FastAPI router implementation lives in `llama_stack/core/server/routers/batches`, following the established pattern where API contracts are defined in `llama_stack_api` and server routing logic lives in `llama_stack/core/server`.

The server now checks for registered routers before falling back to the legacy webmethod-based route discovery, ensuring backward compatibility during the migration period. The OpenAPI generator has been updated to handle both router-based and webmethod-based routes, correctly extracting metadata from FastAPI route decorators and Pydantic `Field` descriptions. The inspect endpoint now includes routes from both systems, with proper filtering for deprecated routes and API levels.

Response descriptions are now explicitly defined in router decorators, ensuring the generated OpenAPI specification matches the previous format. Error responses use `$ref` references to component responses (`BadRequest400`, `TooManyRequests429`, etc.) as required by the specification. This is neat and will allow us to remove a lot of boilerplate code from our generator once the migration is done.

This implementation provides a foundation for incrementally migrating other APIs to the router system while maintaining full backward compatibility with existing webmethod-based APIs.

Closes: https://github.com/llamastack/llama-stack/issues/4188
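For illustration, the registry pattern described above boils down to something like the following sketch. The names (`_ROUTER_FACTORIES`, `register_router`, `include_registered_routers`) are hypothetical, not the exact identifiers introduced by this PR:

```python
# Minimal sketch of a router-factory registry (illustrative names only).
from collections.abc import Callable

from fastapi import APIRouter, FastAPI

# Maps an API name to a factory that builds its router from an implementation.
_ROUTER_FACTORIES: dict[str, Callable[[object], APIRouter]] = {}


def register_router(api_name: str, factory: Callable[[object], APIRouter]) -> None:
    """Register a FastAPI router factory for an API."""
    _ROUTER_FACTORIES[api_name] = factory


def include_registered_routers(app: FastAPI, impls: dict[str, object]) -> None:
    """Build and mount a router for every API that registered a factory.

    APIs without a registered factory would fall back to the legacy
    @webmethod-based route discovery.
    """
    for api_name, factory in _ROUTER_FACTORIES.items():
        if api_name in impls:
            app.include_router(factory(impls[api_name]))
```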
## Test Plan

CI; the server should start and the same routes should be visible.

```
curl http://localhost:8321/v1/inspect/routes | jq '.data[] | select(.route | contains("batches"))'
```

Also:

```
uv run pytest tests/integration/batches/ -vv --stack-config=http://localhost:8321
================================================== test session starts ==================================================
platform darwin -- Python 3.12.8, pytest-8.4.2, pluggy-1.6.0 -- /Users/leseb/Documents/AI/llama-stack/.venv/bin/python3
cachedir: .pytest_cache
metadata: {'Python': '3.12.8', 'Platform': 'macOS-26.0.1-arm64-arm-64bit', 'Packages': {'pytest': '8.4.2', 'pluggy': '1.6.0'}, 'Plugins': {'anyio': '4.9.0', 'html': '4.1.1', 'socket': '0.7.0', 'asyncio': '1.1.0', 'json-report': '1.5.0', 'timeout': '2.4.0', 'metadata': '3.1.1', 'cov': '6.2.1', 'nbval': '0.11.0'}}
rootdir: /Users/leseb/Documents/AI/llama-stack
configfile: pyproject.toml
plugins: anyio-4.9.0, html-4.1.1, socket-0.7.0, asyncio-1.1.0, json-report-1.5.0, timeout-2.4.0, metadata-3.1.1, cov-6.2.1, nbval-0.11.0
asyncio: mode=Mode.AUTO, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collected 24 items

tests/integration/batches/test_batches.py::TestBatchesIntegration::test_batch_creation_and_retrieval[None] SKIPPED [  4%]
tests/integration/batches/test_batches.py::TestBatchesIntegration::test_batch_listing[None] SKIPPED [  8%]
tests/integration/batches/test_batches.py::TestBatchesIntegration::test_batch_immediate_cancellation[None] SKIPPED [ 12%]
tests/integration/batches/test_batches.py::TestBatchesIntegration::test_batch_e2e_chat_completions[None] SKIPPED [ 16%]
tests/integration/batches/test_batches.py::TestBatchesIntegration::test_batch_e2e_completions[None] SKIPPED [ 20%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_invalid_endpoint[None] SKIPPED [ 25%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_cancel_completed[None] SKIPPED [ 29%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_missing_required_fields[None] SKIPPED [ 33%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_invalid_completion_window[None] SKIPPED [ 37%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_streaming_not_supported[None] SKIPPED [ 41%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_mixed_streaming_requests[None] SKIPPED [ 45%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_endpoint_mismatch[None] SKIPPED [ 50%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_missing_required_body_fields[None] SKIPPED [ 54%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_invalid_metadata_types[None] SKIPPED [ 58%]
tests/integration/batches/test_batches.py::TestBatchesIntegration::test_batch_e2e_embeddings[None] SKIPPED [ 62%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_nonexistent_file_id PASSED [ 66%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_malformed_jsonl PASSED [ 70%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_file_malformed_batch_file[empty] XFAIL [ 75%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_file_malformed_batch_file[malformed] XFAIL [ 79%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_retrieve_nonexistent PASSED [ 83%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_cancel_nonexistent PASSED [ 87%]
tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_error_handling_invalid_model PASSED [ 91%]
tests/integration/batches/test_batches_idempotency.py::TestBatchesIdempotencyIntegration::test_idempotent_batch_creation_successful PASSED [ 95%]
tests/integration/batches/test_batches_idempotency.py::TestBatchesIdempotencyIntegration::test_idempotency_conflict_with_different_params PASSED [100%]

================================================= slowest 10 durations ==================================================
1.01s call tests/integration/batches/test_batches_idempotency.py::TestBatchesIdempotencyIntegration::test_idempotent_batch_creation_successful
0.21s call tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_nonexistent_file_id
0.17s call tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_malformed_jsonl
0.12s call tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_error_handling_invalid_model
0.05s setup tests/integration/batches/test_batches.py::TestBatchesIntegration::test_batch_creation_and_retrieval[None]
0.02s call tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_file_malformed_batch_file[empty]
0.01s call tests/integration/batches/test_batches_idempotency.py::TestBatchesIdempotencyIntegration::test_idempotency_conflict_with_different_params
0.01s call tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_file_malformed_batch_file[malformed]
0.01s call tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_retrieve_nonexistent
0.00s call tests/integration/batches/test_batches_errors.py::TestBatchesErrorHandling::test_batch_cancel_nonexistent
======================================= 7 passed, 15 skipped, 2 xfailed in 1.78s ========================================
```

---------

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent 4237eb4aaa
commit 7f43051a63

22 changed files with 1095 additions and 248 deletions
src/llama_stack_api/__init__.py

```diff
@@ -26,7 +26,15 @@ from . import common  # noqa: F401
 
 # Import all public API symbols
 from .agents import Agents, ResponseGuardrail, ResponseGuardrailSpec
-from .batches import Batches, BatchObject, ListBatchesResponse
+from .batches import (
+    Batches,
+    BatchObject,
+    CancelBatchRequest,
+    CreateBatchRequest,
+    ListBatchesRequest,
+    ListBatchesResponse,
+    RetrieveBatchRequest,
+)
 from .benchmarks import (
     Benchmark,
     BenchmarkInput,
@@ -462,6 +470,9 @@ __all__ = [
     "BasicScoringFnParams",
     "Batches",
     "BatchObject",
+    "CancelBatchRequest",
+    "CreateBatchRequest",
+    "ListBatchesRequest",
     "Benchmark",
     "BenchmarkConfig",
     "BenchmarkInput",
@@ -555,6 +566,7 @@ __all__ = [
     "LLMAsJudgeScoringFnParams",
     "LLMRAGQueryGeneratorConfig",
     "ListBatchesResponse",
+    "RetrieveBatchRequest",
     "ListBenchmarksResponse",
     "ListDatasetsResponse",
     "ListModelsResponse",
```
Deleted: the legacy webmethod-based batches module (96 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Literal, Protocol, runtime_checkable

from pydantic import BaseModel, Field

from llama_stack_api.schema_utils import json_schema_type, webmethod
from llama_stack_api.version import LLAMA_STACK_API_V1

try:
    from openai.types import Batch as BatchObject
except ImportError as e:
    raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e


@json_schema_type
class ListBatchesResponse(BaseModel):
    """Response containing a list of batch objects."""

    object: Literal["list"] = "list"
    data: list[BatchObject] = Field(..., description="List of batch objects")
    first_id: str | None = Field(default=None, description="ID of the first batch in the list")
    last_id: str | None = Field(default=None, description="ID of the last batch in the list")
    has_more: bool = Field(default=False, description="Whether there are more batches available")


@runtime_checkable
class Batches(Protocol):
    """
    The Batches API enables efficient processing of multiple requests in a single operation,
    particularly useful for processing large datasets, batch evaluation workflows, and
    cost-effective inference at scale.

    The API is designed to allow use of openai client libraries for seamless integration.

    This API provides the following extensions:
     - idempotent batch creation

    Note: This API is currently under active development and may undergo changes.
    """

    @webmethod(route="/batches", method="POST", level=LLAMA_STACK_API_V1)
    async def create_batch(
        self,
        input_file_id: str,
        endpoint: str,
        completion_window: Literal["24h"],
        metadata: dict[str, str] | None = None,
        idempotency_key: str | None = None,
    ) -> BatchObject:
        """Create a new batch for processing multiple API requests.

        :param input_file_id: The ID of an uploaded file containing requests for the batch.
        :param endpoint: The endpoint to be used for all requests in the batch.
        :param completion_window: The time window within which the batch should be processed.
        :param metadata: Optional metadata for the batch.
        :param idempotency_key: Optional idempotency key. When provided, enables idempotent behavior.
        :returns: The created batch object.
        """
        ...

    @webmethod(route="/batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1)
    async def retrieve_batch(self, batch_id: str) -> BatchObject:
        """Retrieve information about a specific batch.

        :param batch_id: The ID of the batch to retrieve.
        :returns: The batch object.
        """
        ...

    @webmethod(route="/batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1)
    async def cancel_batch(self, batch_id: str) -> BatchObject:
        """Cancel a batch that is in progress.

        :param batch_id: The ID of the batch to cancel.
        :returns: The updated batch object.
        """
        ...

    @webmethod(route="/batches", method="GET", level=LLAMA_STACK_API_V1)
    async def list_batches(
        self,
        after: str | None = None,
        limit: int = 20,
    ) -> ListBatchesResponse:
        """List all batches for the current user.

        :param after: A cursor for pagination; returns batches after this batch ID.
        :param limit: Number of batches to return (default 20, max 100).
        :returns: A list of batch objects.
        """
        ...
```
src/llama_stack_api/batches/__init__.py (new file, 40 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""Batches API protocol and models.

This module contains the Batches protocol definition.
Pydantic models are defined in llama_stack_api.batches.models.
The FastAPI router is defined in llama_stack_api.batches.fastapi_routes.
"""

from openai.types import Batch as BatchObject

# Import fastapi_routes for router factory access
from . import fastapi_routes

# Import protocol for re-export
from .api import Batches

# Import models for re-export
from .models import (
    CancelBatchRequest,
    CreateBatchRequest,
    ListBatchesRequest,
    ListBatchesResponse,
    RetrieveBatchRequest,
)

__all__ = [
    "Batches",
    "BatchObject",
    "CancelBatchRequest",
    "CreateBatchRequest",
    "ListBatchesRequest",
    "ListBatchesResponse",
    "RetrieveBatchRequest",
    "fastapi_routes",
]
```
src/llama_stack_api/batches/api.py (new file, 53 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Protocol, runtime_checkable

from openai.types import Batch as BatchObject

from .models import (
    CancelBatchRequest,
    CreateBatchRequest,
    ListBatchesRequest,
    ListBatchesResponse,
    RetrieveBatchRequest,
)


@runtime_checkable
class Batches(Protocol):
    """
    The Batches API enables efficient processing of multiple requests in a single operation,
    particularly useful for processing large datasets, batch evaluation workflows, and
    cost-effective inference at scale.

    The API is designed to allow use of openai client libraries for seamless integration.

    This API provides the following extensions:
     - idempotent batch creation

    Note: This API is currently under active development and may undergo changes.
    """

    async def create_batch(
        self,
        request: CreateBatchRequest,
    ) -> BatchObject: ...

    async def retrieve_batch(
        self,
        request: RetrieveBatchRequest,
    ) -> BatchObject: ...

    async def cancel_batch(
        self,
        request: CancelBatchRequest,
    ) -> BatchObject: ...

    async def list_batches(
        self,
        request: ListBatchesRequest,
    ) -> ListBatchesResponse: ...
```
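Since `Batches` is `runtime_checkable`, an implementation satisfies it structurally, with no inheritance required. A minimal sketch (the `StubBatches` class below is hypothetical, not the in-tree provider):

```python
# Hypothetical stub: any class with matching async methods structurally
# satisfies the runtime-checkable Batches protocol.
from llama_stack_api.batches import (
    Batches,
    BatchObject,
    CancelBatchRequest,
    CreateBatchRequest,
    ListBatchesRequest,
    ListBatchesResponse,
    RetrieveBatchRequest,
)


class StubBatches:
    async def create_batch(self, request: CreateBatchRequest) -> BatchObject: ...
    async def retrieve_batch(self, request: RetrieveBatchRequest) -> BatchObject: ...
    async def cancel_batch(self, request: CancelBatchRequest) -> BatchObject: ...
    async def list_batches(self, request: ListBatchesRequest) -> ListBatchesResponse: ...


# isinstance() on a runtime-checkable protocol checks method presence only.
assert isinstance(StubBatches(), Batches)
```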
src/llama_stack_api/batches/fastapi_routes.py (new file, 113 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""FastAPI router for the Batches API.

This module defines the FastAPI router for the Batches API using standard
FastAPI route decorators. The router is defined in the API package to keep
all API-related code together.
"""

from typing import Annotated

from fastapi import APIRouter, Body, Depends

from llama_stack_api.batches.models import (
    CancelBatchRequest,
    CreateBatchRequest,
    ListBatchesRequest,
    RetrieveBatchRequest,
)
from llama_stack_api.router_utils import create_path_dependency, create_query_dependency, standard_responses
from llama_stack_api.version import LLAMA_STACK_API_V1

from .api import Batches
from .models import BatchObject, ListBatchesResponse

# Automatically generate dependency functions from Pydantic models
# This ensures the models are the single source of truth for descriptions
get_retrieve_batch_request = create_path_dependency(RetrieveBatchRequest)
get_cancel_batch_request = create_path_dependency(CancelBatchRequest)


# Automatically generate dependency function from Pydantic model
# This ensures the model is the single source of truth for descriptions and defaults
get_list_batches_request = create_query_dependency(ListBatchesRequest)


def create_router(impl: Batches) -> APIRouter:
    """Create a FastAPI router for the Batches API.

    Args:
        impl: The Batches implementation instance

    Returns:
        APIRouter configured for the Batches API
    """
    router = APIRouter(
        prefix=f"/{LLAMA_STACK_API_V1}",
        tags=["Batches"],
        responses=standard_responses,
    )

    @router.post(
        "/batches",
        response_model=BatchObject,
        summary="Create a new batch for processing multiple API requests.",
        description="Create a new batch for processing multiple API requests.",
        responses={
            200: {"description": "The created batch object."},
            409: {"description": "Conflict: The idempotency key was previously used with different parameters."},
        },
    )
    async def create_batch(
        request: Annotated[CreateBatchRequest, Body(...)],
    ) -> BatchObject:
        return await impl.create_batch(request)

    @router.get(
        "/batches/{batch_id}",
        response_model=BatchObject,
        summary="Retrieve information about a specific batch.",
        description="Retrieve information about a specific batch.",
        responses={
            200: {"description": "The batch object."},
        },
    )
    async def retrieve_batch(
        request: Annotated[RetrieveBatchRequest, Depends(get_retrieve_batch_request)],
    ) -> BatchObject:
        return await impl.retrieve_batch(request)

    @router.post(
        "/batches/{batch_id}/cancel",
        response_model=BatchObject,
        summary="Cancel a batch that is in progress.",
        description="Cancel a batch that is in progress.",
        responses={
            200: {"description": "The updated batch object."},
        },
    )
    async def cancel_batch(
        request: Annotated[CancelBatchRequest, Depends(get_cancel_batch_request)],
    ) -> BatchObject:
        return await impl.cancel_batch(request)

    @router.get(
        "/batches",
        response_model=ListBatchesResponse,
        summary="List all batches for the current user.",
        description="List all batches for the current user.",
        responses={
            200: {"description": "A list of batch objects."},
        },
    )
    async def list_batches(
        request: Annotated[ListBatchesRequest, Depends(get_list_batches_request)],
    ) -> ListBatchesResponse:
        return await impl.list_batches(request)

    return router
```
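Standalone usage of the factory looks roughly like this, reusing the hypothetical `StubBatches` from the sketch above (the real server wires this through the registry rather than mounting it by hand):

```python
# Illustrative: mount the Batches router on a bare FastAPI app.
from fastapi import FastAPI

from llama_stack_api.batches import fastapi_routes

app = FastAPI()
app.include_router(fastapi_routes.create_router(StubBatches()))
# Routes are now served under /v1/batches, /v1/batches/{batch_id},
# /v1/batches/{batch_id}/cancel, with standard_responses in the OpenAPI spec.
```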
src/llama_stack_api/batches/models.py (new file, 78 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""Pydantic models for Batches API requests and responses.

This module defines the request and response models for the Batches API
using Pydantic with Field descriptions for OpenAPI schema generation.
"""

from typing import Literal

from openai.types import Batch as BatchObject
from pydantic import BaseModel, Field

from llama_stack_api.schema_utils import json_schema_type


@json_schema_type
class CreateBatchRequest(BaseModel):
    """Request model for creating a batch."""

    input_file_id: str = Field(..., description="The ID of an uploaded file containing requests for the batch.")
    endpoint: str = Field(..., description="The endpoint to be used for all requests in the batch.")
    completion_window: Literal["24h"] = Field(
        ..., description="The time window within which the batch should be processed."
    )
    metadata: dict[str, str] | None = Field(default=None, description="Optional metadata for the batch.")
    idempotency_key: str | None = Field(
        default=None, description="Optional idempotency key. When provided, enables idempotent behavior."
    )


@json_schema_type
class ListBatchesRequest(BaseModel):
    """Request model for listing batches."""

    after: str | None = Field(
        default=None, description="Optional cursor for pagination. Returns batches after this ID."
    )
    limit: int = Field(default=20, description="Maximum number of batches to return. Defaults to 20.")


@json_schema_type
class RetrieveBatchRequest(BaseModel):
    """Request model for retrieving a batch."""

    batch_id: str = Field(..., description="The ID of the batch to retrieve.")


@json_schema_type
class CancelBatchRequest(BaseModel):
    """Request model for canceling a batch."""

    batch_id: str = Field(..., description="The ID of the batch to cancel.")


@json_schema_type
class ListBatchesResponse(BaseModel):
    """Response containing a list of batch objects."""

    object: Literal["list"] = "list"
    data: list[BatchObject] = Field(..., description="List of batch objects")
    first_id: str | None = Field(default=None, description="ID of the first batch in the list")
    last_id: str | None = Field(default=None, description="ID of the last batch in the list")
    has_more: bool = Field(default=False, description="Whether there are more batches available")


__all__ = [
    "CreateBatchRequest",
    "ListBatchesRequest",
    "RetrieveBatchRequest",
    "CancelBatchRequest",
    "ListBatchesResponse",
    "BatchObject",
]
```
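Because descriptions and defaults live on the models, request validation can be exercised directly, before any handler or transport is involved. A small sketch with illustrative values:

```python
# Illustrative: the request models validate inputs on construction.
from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest

req = CreateBatchRequest(
    input_file_id="file-abc123",      # hypothetical file ID
    endpoint="/v1/chat/completions",
    completion_window="24h",          # only "24h" is accepted (Literal type)
)
defaults = ListBatchesRequest()       # after=None, limit=20 from Field defaults
```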
pyproject.toml

```diff
@@ -24,6 +24,7 @@ classifiers = [
     "Topic :: Scientific/Engineering :: Information Analysis",
 ]
 dependencies = [
+    "fastapi>=0.115.0,<1.0",
     "pydantic>=2.11.9",
     "jsonschema",
     "opentelemetry-sdk>=1.30.0",
```
src/llama_stack_api/router_utils.py (new file, 155 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""Utilities for creating FastAPI routers with standard error responses.

This module provides standard error response definitions for FastAPI routers.
These responses use OpenAPI $ref references to component responses defined
in the OpenAPI specification.
"""

import inspect
from collections.abc import Callable
from typing import Annotated, Any, TypeVar

from fastapi import Path, Query
from pydantic import BaseModel

standard_responses: dict[int | str, dict[str, Any]] = {
    400: {"$ref": "#/components/responses/BadRequest400"},
    429: {"$ref": "#/components/responses/TooManyRequests429"},
    500: {"$ref": "#/components/responses/InternalServerError500"},
    "default": {"$ref": "#/components/responses/DefaultError"},
}

T = TypeVar("T", bound=BaseModel)


def create_query_dependency[T: BaseModel](model_class: type[T]) -> Callable[..., T]:
    """Create a FastAPI dependency function from a Pydantic model for query parameters.

    FastAPI does not natively support using Pydantic models as query parameters
    without a dependency function. Using a dependency function typically leads to
    duplication: field types, default values, and descriptions must be repeated in
    `Query(...)` annotations even though they already exist in the Pydantic model.

    This function automatically generates a dependency function that extracts query parameters
    from the request and constructs an instance of the Pydantic model. The descriptions and
    defaults are automatically extracted from the model's Field definitions, making the model
    the single source of truth.

    Args:
        model_class: The Pydantic model class to create a dependency for

    Returns:
        A dependency function that can be used with FastAPI's Depends()
    """
    # Build function signature dynamically from model fields
    annotations: dict[str, Any] = {}
    defaults: dict[str, Any] = {}

    for field_name, field_info in model_class.model_fields.items():
        # Extract description from Field
        description = field_info.description

        # Create Query annotation with description from model
        query_annotation = Query(description=description) if description else Query()

        # Create Annotated type with Query
        field_type = field_info.annotation
        annotations[field_name] = Annotated[field_type, query_annotation]

        # Set default value from model
        if field_info.default is not inspect.Parameter.empty:
            defaults[field_name] = field_info.default

    # Create the dependency function dynamically
    def dependency_func(**kwargs: Any) -> T:
        return model_class(**kwargs)

    # Set function signature
    sig_params = []
    for field_name, field_type in annotations.items():
        default = defaults.get(field_name, inspect.Parameter.empty)
        param = inspect.Parameter(
            field_name,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            default=default,
            annotation=field_type,
        )
        sig_params.append(param)

    # These attributes are set dynamically at runtime. While mypy can't verify them statically,
    # they are standard Python function attributes that exist on all callable objects at runtime.
    # Setting them allows FastAPI to properly introspect the function signature for dependency injection.
    dependency_func.__signature__ = inspect.Signature(sig_params)  # type: ignore[attr-defined]
    dependency_func.__annotations__ = annotations  # type: ignore[attr-defined]
    dependency_func.__name__ = f"get_{model_class.__name__.lower()}_request"  # type: ignore[attr-defined]

    return dependency_func


def create_path_dependency[T: BaseModel](model_class: type[T]) -> Callable[..., T]:
    """Create a FastAPI dependency function from a Pydantic model for path parameters.

    FastAPI requires path parameters to be explicitly annotated with `Path()`. When using
    a Pydantic model that contains path parameters, you typically need a dependency function
    that extracts the path parameter and constructs the model. This leads to duplication:
    the parameter name, type, and description must be repeated in `Path(...)` annotations
    even though they already exist in the Pydantic model.

    This function automatically generates a dependency function that extracts path parameters
    from the request and constructs an instance of the Pydantic model. The descriptions are
    automatically extracted from the model's Field definitions, making the model the single
    source of truth.

    Args:
        model_class: The Pydantic model class to create a dependency for. The model should
            have exactly one field that represents the path parameter.

    Returns:
        A dependency function that can be used with FastAPI's Depends()
    """
    # Get the single field from the model (path parameter models typically have one field)
    if len(model_class.model_fields) != 1:
        raise ValueError(
            f"Path parameter model {model_class.__name__} must have exactly one field, "
            f"but has {len(model_class.model_fields)} fields"
        )

    field_name, field_info = next(iter(model_class.model_fields.items()))

    # Extract description from Field
    description = field_info.description

    # Create Path annotation with description from model
    path_annotation = Path(description=description) if description else Path()

    # Create Annotated type with Path
    field_type = field_info.annotation
    annotations: dict[str, Any] = {field_name: Annotated[field_type, path_annotation]}

    # Create the dependency function dynamically
    def dependency_func(**kwargs: Any) -> T:
        return model_class(**kwargs)

    # Set function signature
    param = inspect.Parameter(
        field_name,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        annotation=annotations[field_name],
    )

    # These attributes are set dynamically at runtime. While mypy can't verify them statically,
    # they are standard Python function attributes that exist on all callable objects at runtime.
    # Setting them allows FastAPI to properly introspect the function signature for dependency injection.
    dependency_func.__signature__ = inspect.Signature([param])  # type: ignore[attr-defined]
    dependency_func.__annotations__ = annotations  # type: ignore[attr-defined]
    dependency_func.__name__ = f"get_{model_class.__name__.lower()}_request"  # type: ignore[attr-defined]

    return dependency_func
```
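To see what the generated dependency looks like, one can inspect its synthesized signature directly; a small sketch assuming `ListBatchesRequest` as the model:

```python
# Illustrative: the generated dependency exposes the model's fields as
# individual parameters, which is what FastAPI introspects for query binding.
import inspect

from llama_stack_api.batches.models import ListBatchesRequest
from llama_stack_api.router_utils import create_query_dependency

dep = create_query_dependency(ListBatchesRequest)
print(inspect.signature(dep))  # `after` and `limit` appear as Annotated query parameters
print(dep(limit=5))            # constructs a validated ListBatchesRequest(after=None, limit=5)
```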