Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-04 02:03:44 +00:00)

Merge 7b93964a16 into 4237eb4aaa
Commit cf949d7fac
22 changed files with 1086 additions and 248 deletions

@@ -10,8 +10,14 @@ from pydantic import BaseModel
 from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.core.external import load_external_apis
+from llama_stack.core.server.fastapi_router_registry import (
+    _ROUTER_FACTORIES,
+    build_fastapi_router,
+    get_router_routes,
+)
 from llama_stack.core.server.routes import get_all_api_routes
 from llama_stack_api import (
     Api,
     HealthInfo,
     HealthStatus,
     Inspect,

@@ -43,6 +49,7 @@ class DistributionInspectImpl(Inspect):
         run_config: StackRunConfig = self.config.run_config

         # Helper function to determine if a route should be included based on api_filter
+        # TODO: remove this once we've migrated all APIs to FastAPI routers
         def should_include_route(webmethod) -> bool:
             if api_filter is None:
                 # Default: only non-deprecated APIs

@@ -54,10 +61,62 @@ class DistributionInspectImpl(Inspect):
                # Filter by API level (non-deprecated routes only)
                return not webmethod.deprecated and webmethod.level == api_filter

        # Helper function to get provider types for an API
        def _get_provider_types(api: Api) -> list[str]:
            if api.value in ["providers", "inspect"]:
                return []  # These APIs don't have "real" providers; they're internal to the stack
            providers = run_config.providers.get(api.value, [])
            return [p.provider_type for p in providers] if providers else []

        # Helper function to determine if a router route should be included based on api_filter
        def _should_include_router_route(route, router_prefix: str | None) -> bool:
            """Check if a router-based route should be included based on api_filter."""
            # Check deprecated status
            route_deprecated = getattr(route, "deprecated", False) or False

            if api_filter is None:
                # Default: only non-deprecated routes
                return not route_deprecated
            elif api_filter == "deprecated":
                # Special filter: show deprecated routes regardless of their actual level
                return route_deprecated
            else:
                # Filter by API level (non-deprecated routes only)
                # Extract level from router prefix (e.g., "/v1" -> "v1")
                if router_prefix:
                    prefix_level = router_prefix.lstrip("/")
                    return not route_deprecated and prefix_level == api_filter
                return not route_deprecated

        ret = []
        external_apis = load_external_apis(run_config)
        all_endpoints = get_all_api_routes(external_apis)

        # Process routes from APIs with FastAPI routers
        for api_name in _ROUTER_FACTORIES.keys():
            api = Api(api_name)
            router = build_fastapi_router(api, None)  # we don't need the impl here, just the routes
            if router:
                router_routes = get_router_routes(router)
                for route in router_routes:
                    if _should_include_router_route(route, router.prefix):
                        if route.methods is not None:
                            available_methods = [m for m in route.methods if m != "HEAD"]
                            if available_methods:
                                ret.append(
                                    RouteInfo(
                                        route=route.path,
                                        method=available_methods[0],
                                        provider_types=_get_provider_types(api),
                                    )
                                )

        # Process routes from legacy webmethod-based APIs
        for api, endpoints in all_endpoints.items():
            # Skip APIs that have routers (already processed above)
            if api.value in _ROUTER_FACTORIES:
                continue

            # Always include provider and inspect APIs, filter others based on run config
            if api.value in ["providers", "inspect"]:
                ret.extend(

src/llama_stack/core/server/fastapi_router_registry.py (new file, 84 lines)
@@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""Router utilities for FastAPI routers.

This module provides utilities to create FastAPI routers from API packages.
APIs with routers are explicitly listed here.
"""

from collections.abc import Callable
from typing import Any, cast

from fastapi import APIRouter
from fastapi.routing import APIRoute
from starlette.routing import Route

# Router factories for APIs that have FastAPI routers
# Add new APIs here as they are migrated to the router system
from llama_stack_api.batches.fastapi_routes import create_router as create_batches_router
from llama_stack_api.datatypes import Api

_ROUTER_FACTORIES: dict[str, Callable[[Any], APIRouter]] = {
    "batches": create_batches_router,
}


def build_fastapi_router(api: "Api", impl: Any) -> APIRouter | None:
    """Build a router for an API by combining its router factory with the implementation.

    Args:
        api: The API enum value
        impl: The implementation instance for the API

    Returns:
        APIRouter if the API has a router factory, None otherwise
    """
    router_factory = _ROUTER_FACTORIES.get(api.value)
    if router_factory is None:
        return None

    # cast is safe here: all router factories in API packages are required to return APIRouter.
    # If a router factory returns the wrong type, it will fail at runtime when
    # app.include_router(router) is called
    return cast(APIRouter, router_factory(impl))


def get_router_routes(router: APIRouter) -> list[Route]:
    """Extract routes from a FastAPI router.

    Args:
        router: The FastAPI router to extract routes from

    Returns:
        List of Route objects from the router
    """
    routes = []

    for route in router.routes:
        # FastAPI routers use APIRoute objects, which have path and methods attributes
        if isinstance(route, APIRoute):
            # Combine router prefix with route path
            routes.append(
                Route(
                    path=route.path,
                    methods=route.methods,
                    name=route.name,
                    endpoint=route.endpoint,
                )
            )
        elif isinstance(route, Route):
            # Fallback for regular Starlette Route objects
            routes.append(
                Route(
                    path=route.path,
                    methods=route.methods,
                    name=route.name,
                    endpoint=route.endpoint,
                )
            )

    return routes
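
To show how the registry is meant to be consumed (this sketch is not part of the diff; it assumes llama-stack is installed), the batches router can be built without an implementation just to enumerate its routes, exactly as DistributionInspectImpl does above:

from llama_stack.core.server.fastapi_router_registry import (
    build_fastapi_router,
    get_router_routes,
)
from llama_stack_api.datatypes import Api

# impl=None is fine for route listing: the handlers close over impl but are never called
router = build_fastapi_router(Api("batches"), None)
if router is not None:
    for route in get_router_routes(router):
        methods = sorted(m for m in (route.methods or set()) if m != "HEAD")
        print(methods, router.prefix + route.path)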

@@ -26,6 +26,18 @@ RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod]
 def get_all_api_routes(
     external_apis: dict[Api, ExternalApiSpec] | None = None,
 ) -> dict[Api, list[tuple[Route, WebMethod]]]:
+    """Get all API routes from webmethod-based protocols.
+
+    This function only returns routes from APIs that use the legacy @webmethod
+    decorator system. For APIs that have been migrated to FastAPI routers,
+    use the router registry (fastapi_router_registry.has_router() and fastapi_router_registry.build_fastapi_router()).
+
+    Args:
+        external_apis: Optional dictionary of external API specifications
+
+    Returns:
+        Dictionary mapping API to list of (Route, WebMethod) tuples
+    """
     apis = {}

     protocols = api_protocol_map(external_apis)

@@ -44,6 +44,7 @@ from llama_stack.core.request_headers import (
     request_provider_data_context,
     user_from_scope,
 )
+from llama_stack.core.server.fastapi_router_registry import build_fastapi_router
 from llama_stack.core.server.routes import get_all_api_routes
 from llama_stack.core.stack import (
     Stack,

@@ -84,7 +85,7 @@ def create_sse_event(data: Any) -> str:

 async def global_exception_handler(request: Request, exc: Exception):
-    traceback.print_exception(exc)
+    traceback.print_exception(type(exc), exc, exc.__traceback__)
     http_exc = translate_exception(exc)

     return JSONResponse(status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}})

@@ -454,15 +455,22 @@ def create_app() -> StackApp:
        apis_to_serve.add("providers")
        apis_to_serve.add("prompts")
        apis_to_serve.add("conversations")

    for api_str in apis_to_serve:
        api = Api(api_str)

-        routes = all_routes[api]
-        try:
-            impl = impls[api]
-        except KeyError as e:
-            raise ValueError(f"Could not find provider implementation for {api} API") from e
+        # Try to discover and use a router factory from the API package
+        impl = impls[api]
+        router = build_fastapi_router(api, impl)
+        if router:
+            app.include_router(router)
+            logger.debug(f"Registered FastAPI router for {api} API")
+            continue
+
+        # Fall back to old webmethod-based route discovery until the migration is complete
+        impl = impls[api]
+
+        routes = all_routes[api]
        for route, _ in routes:
            if not hasattr(impl, route.name):
                # ideally this should be a typing violation already

@@ -488,7 +496,15 @@ def create_app() -> StackApp:

    logger.debug(f"serving APIs: {apis_to_serve}")

+    # Register specific exception handlers before the generic Exception handler
+    # This prevents the re-raising behavior that causes connection resets
+    app.exception_handler(RequestValidationError)(global_exception_handler)
+    app.exception_handler(ConflictError)(global_exception_handler)
+    app.exception_handler(ResourceNotFoundError)(global_exception_handler)
+    app.exception_handler(AuthenticationRequiredError)(global_exception_handler)
+    app.exception_handler(AccessDeniedError)(global_exception_handler)
+    app.exception_handler(BadRequestError)(global_exception_handler)
    # Generic Exception handler should be last
    app.exception_handler(Exception)(global_exception_handler)

    return app
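
The handler wiring above uses FastAPI's decorator factory called directly: app.exception_handler(Exc) returns a decorator, so app.exception_handler(Exc)(handler) registers handler for Exc. A self-contained sketch of the same pattern (ConflictError here is a local stand-in, not the llama-stack class):

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse


class ConflictError(Exception):
    pass


app = FastAPI()


async def demo_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    # Mirrors global_exception_handler: wrap the error detail in a JSON body
    return JSONResponse(status_code=409, content={"error": {"detail": str(exc)}})


app.exception_handler(ConflictError)(demo_exception_handler)
# Generic handler registered last, as in create_app
app.exception_handler(Exception)(demo_exception_handler)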

@@ -11,7 +11,7 @@ import json
 import time
 import uuid
 from io import BytesIO
-from typing import Any, Literal
+from typing import Any

 from openai.types.batch import BatchError, Errors
 from pydantic import BaseModel

@@ -38,6 +38,12 @@ from llama_stack_api import (
     OpenAIUserMessageParam,
     ResourceNotFoundError,
 )
+from llama_stack_api.batches.models import (
+    CancelBatchRequest,
+    CreateBatchRequest,
+    ListBatchesRequest,
+    RetrieveBatchRequest,
+)

 from .config import ReferenceBatchesImplConfig

@@ -140,11 +146,7 @@ class ReferenceBatchesImpl(Batches):
    # TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions
    async def create_batch(
        self,
-        input_file_id: str,
-        endpoint: str,
-        completion_window: Literal["24h"],
-        metadata: dict[str, str] | None = None,
-        idempotency_key: str | None = None,
+        request: CreateBatchRequest,
    ) -> BatchObject:
        """
        Create a new batch for processing multiple API requests.

@@ -185,14 +187,14 @@ class ReferenceBatchesImpl(Batches):

        # TODO: set expiration time for garbage collection

-        if endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
+        if request.endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
            raise ValueError(
-                f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint",
+                f"Invalid endpoint: {request.endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint",
            )

-        if completion_window != "24h":
+        if request.completion_window != "24h":
            raise ValueError(
-                f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window",
+                f"Invalid completion_window: {request.completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window",
            )

        batch_id = f"batch_{uuid.uuid4().hex[:16]}"

@@ -200,22 +202,22 @@ class ReferenceBatchesImpl(Batches):
        # For idempotent requests, use the idempotency key for the batch ID
        # This ensures the same key always maps to the same batch ID,
        # allowing us to detect parameter conflicts
-        if idempotency_key is not None:
-            hash_input = idempotency_key.encode("utf-8")
+        if request.idempotency_key is not None:
+            hash_input = request.idempotency_key.encode("utf-8")
            hash_digest = hashlib.sha256(hash_input).hexdigest()[:24]
            batch_id = f"batch_{hash_digest}"

            try:
-                existing_batch = await self.retrieve_batch(batch_id)
+                existing_batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=batch_id))

                if (
-                    existing_batch.input_file_id != input_file_id
-                    or existing_batch.endpoint != endpoint
-                    or existing_batch.completion_window != completion_window
-                    or existing_batch.metadata != metadata
+                    existing_batch.input_file_id != request.input_file_id
+                    or existing_batch.endpoint != request.endpoint
+                    or existing_batch.completion_window != request.completion_window
+                    or existing_batch.metadata != request.metadata
                ):
                    raise ConflictError(
-                        f"Idempotency key '{idempotency_key}' was previously used with different parameters. "
+                        f"Idempotency key '{request.idempotency_key}' was previously used with different parameters. "
                        "Either use a new idempotency key or ensure all parameters match the original request."
                    )
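
The idempotency scheme reduces to one property: the same key always hashes to the same batch ID, so a retry with identical parameters finds the existing batch instead of creating a duplicate, and a retry with different parameters collides and raises ConflictError. A self-contained sketch of the ID derivation:

import hashlib


def batch_id_for(idempotency_key: str) -> str:
    # Same derivation as above: first 24 hex chars of the SHA-256 digest
    hash_digest = hashlib.sha256(idempotency_key.encode("utf-8")).hexdigest()[:24]
    return f"batch_{hash_digest}"


assert batch_id_for("my-key") == batch_id_for("my-key")  # deterministic
assert batch_id_for("my-key") != batch_id_for("other-key")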

@@ -230,12 +232,12 @@ class ReferenceBatchesImpl(Batches):
        batch = BatchObject(
            id=batch_id,
            object="batch",
-            endpoint=endpoint,
-            input_file_id=input_file_id,
-            completion_window=completion_window,
+            endpoint=request.endpoint,
+            input_file_id=request.input_file_id,
+            completion_window=request.completion_window,
            status="validating",
            created_at=current_time,
-            metadata=metadata,
+            metadata=request.metadata,
        )

        await self.kvstore.set(f"batch:{batch_id}", batch.to_json())

@@ -247,28 +249,27 @@ class ReferenceBatchesImpl(Batches):

        return batch

-    async def cancel_batch(self, batch_id: str) -> BatchObject:
+    async def cancel_batch(self, request: CancelBatchRequest) -> BatchObject:
        """Cancel a batch that is in progress."""
-        batch = await self.retrieve_batch(batch_id)
+        batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=request.batch_id))

        if batch.status in ["cancelled", "cancelling"]:
            return batch

        if batch.status in ["completed", "failed", "expired"]:
-            raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'")
+            raise ConflictError(f"Cannot cancel batch '{request.batch_id}' with status '{batch.status}'")

-        await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time()))
+        await self._update_batch(request.batch_id, status="cancelling", cancelling_at=int(time.time()))

-        if batch_id in self._processing_tasks:
-            self._processing_tasks[batch_id].cancel()
+        if request.batch_id in self._processing_tasks:
+            self._processing_tasks[request.batch_id].cancel()
            # note: task removal and status="cancelled" handled in finally block of _process_batch

-        return await self.retrieve_batch(batch_id)
+        return await self.retrieve_batch(RetrieveBatchRequest(batch_id=request.batch_id))

    async def list_batches(
        self,
-        after: str | None = None,
-        limit: int = 20,
+        request: ListBatchesRequest,
    ) -> ListBatchesResponse:
        """
        List all batches, eventually only for the current user.
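
The cancellation flow relies on a standard asyncio pattern: cancel the processing task and let the task's own finally block perform the terminal status update. A self-contained sketch (a dict stands in for the kvstore; not part of the diff):

import asyncio


async def worker(state: dict) -> None:
    try:
        await asyncio.sleep(3600)  # stand-in for batch processing
    finally:
        state["status"] = "cancelled"  # terminal update owned by the task


async def main() -> None:
    state = {"status": "in_progress"}
    task = asyncio.create_task(worker(state))
    await asyncio.sleep(0)  # let the worker start
    task.cancel()  # mirrors self._processing_tasks[batch_id].cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
    print(state["status"])  # cancelled


asyncio.run(main())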

@@ -285,14 +286,14 @@ class ReferenceBatchesImpl(Batches):
        batches.sort(key=lambda b: b.created_at, reverse=True)

        start_idx = 0
-        if after:
+        if request.after:
            for i, batch in enumerate(batches):
-                if batch.id == after:
+                if batch.id == request.after:
                    start_idx = i + 1
                    break

-        page_batches = batches[start_idx : start_idx + limit]
-        has_more = (start_idx + limit) < len(batches)
+        page_batches = batches[start_idx : start_idx + request.limit]
+        has_more = (start_idx + request.limit) < len(batches)

        first_id = page_batches[0].id if page_batches else None
        last_id = page_batches[-1].id if page_batches else None
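
The cursor logic above is easiest to see in isolation; a runnable sketch over plain strings, where `after` is the ID of the last item the client already saw:

def paginate(ids: list[str], after: str | None, limit: int) -> tuple[list[str], bool]:
    start_idx = 0
    if after:
        for i, item in enumerate(ids):
            if item == after:
                start_idx = i + 1
                break
    page = ids[start_idx : start_idx + limit]
    has_more = (start_idx + limit) < len(ids)
    return page, has_more


assert paginate(["a", "b", "c"], after="a", limit=1) == (["b"], True)
assert paginate(["a", "b", "c"], after="c", limit=1) == ([], False)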

@@ -304,11 +305,11 @@ class ReferenceBatchesImpl(Batches):
            has_more=has_more,
        )

-    async def retrieve_batch(self, batch_id: str) -> BatchObject:
+    async def retrieve_batch(self, request: RetrieveBatchRequest) -> BatchObject:
        """Retrieve information about a specific batch."""
-        batch_data = await self.kvstore.get(f"batch:{batch_id}")
+        batch_data = await self.kvstore.get(f"batch:{request.batch_id}")
        if not batch_data:
-            raise ResourceNotFoundError(batch_id, "Batch", "batches.list()")
+            raise ResourceNotFoundError(request.batch_id, "Batch", "batches.list()")

        return BatchObject.model_validate_json(batch_data)

@@ -316,7 +317,7 @@ class ReferenceBatchesImpl(Batches):
        """Update batch fields in kvstore."""
        async with self._update_batch_lock:
            try:
-                batch = await self.retrieve_batch(batch_id)
+                batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=batch_id))

                # batch processing is async. once cancelling, only allow "cancelled" status updates
                if batch.status == "cancelling" and updates.get("status") != "cancelled":

@@ -536,7 +537,7 @@ class ReferenceBatchesImpl(Batches):
    async def _process_batch_impl(self, batch_id: str) -> None:
        """Implementation of batch processing logic."""
        errors: list[BatchError] = []
-        batch = await self.retrieve_batch(batch_id)
+        batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=batch_id))

        errors, requests = await self._validate_input(batch)
        if errors:

@@ -26,7 +26,15 @@ from . import common  # noqa: F401

 # Import all public API symbols
 from .agents import Agents, ResponseGuardrail, ResponseGuardrailSpec
-from .batches import Batches, BatchObject, ListBatchesResponse
+from .batches import (
+    Batches,
+    BatchObject,
+    CancelBatchRequest,
+    CreateBatchRequest,
+    ListBatchesRequest,
+    ListBatchesResponse,
+    RetrieveBatchRequest,
+)
 from .benchmarks import (
     Benchmark,
     BenchmarkInput,

@@ -462,6 +470,9 @@ __all__ = [
     "BasicScoringFnParams",
     "Batches",
     "BatchObject",
+    "CancelBatchRequest",
+    "CreateBatchRequest",
+    "ListBatchesRequest",
     "Benchmark",
     "BenchmarkConfig",
     "BenchmarkInput",

@@ -555,6 +566,7 @@ __all__ = [
     "LLMAsJudgeScoringFnParams",
     "LLMRAGQueryGeneratorConfig",
     "ListBatchesResponse",
+    "RetrieveBatchRequest",
     "ListBenchmarksResponse",
     "ListDatasetsResponse",
     "ListModelsResponse",

Deleted file (96 lines): the previous single-module, webmethod-based batches API definition.
@@ -1,96 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Literal, Protocol, runtime_checkable

from pydantic import BaseModel, Field

from llama_stack_api.schema_utils import json_schema_type, webmethod
from llama_stack_api.version import LLAMA_STACK_API_V1

try:
    from openai.types import Batch as BatchObject
except ImportError as e:
    raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e


@json_schema_type
class ListBatchesResponse(BaseModel):
    """Response containing a list of batch objects."""

    object: Literal["list"] = "list"
    data: list[BatchObject] = Field(..., description="List of batch objects")
    first_id: str | None = Field(default=None, description="ID of the first batch in the list")
    last_id: str | None = Field(default=None, description="ID of the last batch in the list")
    has_more: bool = Field(default=False, description="Whether there are more batches available")


@runtime_checkable
class Batches(Protocol):
    """
    The Batches API enables efficient processing of multiple requests in a single operation,
    particularly useful for processing large datasets, batch evaluation workflows, and
    cost-effective inference at scale.

    The API is designed to allow use of openai client libraries for seamless integration.

    This API provides the following extensions:
     - idempotent batch creation

    Note: This API is currently under active development and may undergo changes.
    """

    @webmethod(route="/batches", method="POST", level=LLAMA_STACK_API_V1)
    async def create_batch(
        self,
        input_file_id: str,
        endpoint: str,
        completion_window: Literal["24h"],
        metadata: dict[str, str] | None = None,
        idempotency_key: str | None = None,
    ) -> BatchObject:
        """Create a new batch for processing multiple API requests.

        :param input_file_id: The ID of an uploaded file containing requests for the batch.
        :param endpoint: The endpoint to be used for all requests in the batch.
        :param completion_window: The time window within which the batch should be processed.
        :param metadata: Optional metadata for the batch.
        :param idempotency_key: Optional idempotency key. When provided, enables idempotent behavior.
        :returns: The created batch object.
        """
        ...

    @webmethod(route="/batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1)
    async def retrieve_batch(self, batch_id: str) -> BatchObject:
        """Retrieve information about a specific batch.

        :param batch_id: The ID of the batch to retrieve.
        :returns: The batch object.
        """
        ...

    @webmethod(route="/batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1)
    async def cancel_batch(self, batch_id: str) -> BatchObject:
        """Cancel a batch that is in progress.

        :param batch_id: The ID of the batch to cancel.
        :returns: The updated batch object.
        """
        ...

    @webmethod(route="/batches", method="GET", level=LLAMA_STACK_API_V1)
    async def list_batches(
        self,
        after: str | None = None,
        limit: int = 20,
    ) -> ListBatchesResponse:
        """List all batches for the current user.

        :param after: A cursor for pagination; returns batches after this batch ID.
        :param limit: Number of batches to return (default 20, max 100).
        :returns: A list of batch objects.
        """
        ...

src/llama_stack_api/batches/__init__.py (new file, 39 lines)
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""Batches API protocol and models.

This module contains the Batches protocol definition.
Pydantic models are defined in llama_stack_api.batches.models.
The FastAPI router is defined in llama_stack_api.batches.fastapi_routes.
"""

try:
    from openai.types import Batch as BatchObject
except ImportError as e:
    raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e

# Import protocol for re-export
from llama_stack_api.batches.api import Batches

# Import models for re-export
from llama_stack_api.batches.models import (
    CancelBatchRequest,
    CreateBatchRequest,
    ListBatchesRequest,
    ListBatchesResponse,
    RetrieveBatchRequest,
)

__all__ = [
    "Batches",
    "BatchObject",
    "CreateBatchRequest",
    "ListBatchesRequest",
    "RetrieveBatchRequest",
    "CancelBatchRequest",
    "ListBatchesResponse",
]

src/llama_stack_api/batches/api.py (new file, 56 lines)
@@ -0,0 +1,56 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Protocol, runtime_checkable

try:
    from openai.types import Batch as BatchObject
except ImportError as e:
    raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e

from llama_stack_api.batches.models import (
    CancelBatchRequest,
    CreateBatchRequest,
    ListBatchesRequest,
    ListBatchesResponse,
    RetrieveBatchRequest,
)


@runtime_checkable
class Batches(Protocol):
    """
    The Batches API enables efficient processing of multiple requests in a single operation,
    particularly useful for processing large datasets, batch evaluation workflows, and
    cost-effective inference at scale.

    The API is designed to allow use of openai client libraries for seamless integration.

    This API provides the following extensions:
     - idempotent batch creation

    Note: This API is currently under active development and may undergo changes.
    """

    async def create_batch(
        self,
        request: CreateBatchRequest,
    ) -> BatchObject: ...

    async def retrieve_batch(
        self,
        request: RetrieveBatchRequest,
    ) -> BatchObject: ...

    async def cancel_batch(
        self,
        request: CancelBatchRequest,
    ) -> BatchObject: ...

    async def list_batches(
        self,
        request: ListBatchesRequest,
    ) -> ListBatchesResponse: ...

src/llama_stack_api/batches/fastapi_routes.py (new file, 111 lines)
@@ -0,0 +1,111 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""FastAPI router for the Batches API.

This module defines the FastAPI router for the Batches API using standard
FastAPI route decorators. The router is defined in the API package to keep
all API-related code together.
"""

from typing import Annotated

from fastapi import APIRouter, Body, Depends

from llama_stack_api.batches import Batches, BatchObject, ListBatchesResponse
from llama_stack_api.batches.models import (
    CancelBatchRequest,
    CreateBatchRequest,
    ListBatchesRequest,
    RetrieveBatchRequest,
)
from llama_stack_api.router_utils import create_path_dependency, create_query_dependency, standard_responses
from llama_stack_api.version import LLAMA_STACK_API_V1

# Automatically generate dependency functions from Pydantic models
# This ensures the models are the single source of truth for descriptions
get_retrieve_batch_request = create_path_dependency(RetrieveBatchRequest)
get_cancel_batch_request = create_path_dependency(CancelBatchRequest)


# Automatically generate dependency function from Pydantic model
# This ensures the model is the single source of truth for descriptions and defaults
get_list_batches_request = create_query_dependency(ListBatchesRequest)


def create_router(impl: Batches) -> APIRouter:
    """Create a FastAPI router for the Batches API.

    Args:
        impl: The Batches implementation instance

    Returns:
        APIRouter configured for the Batches API
    """
    router = APIRouter(
        prefix=f"/{LLAMA_STACK_API_V1}",
        tags=["Batches"],
        responses=standard_responses,
    )

    @router.post(
        "/batches",
        response_model=BatchObject,
        summary="Create a new batch for processing multiple API requests.",
        description="Create a new batch for processing multiple API requests.",
        responses={
            200: {"description": "The created batch object."},
            409: {"description": "Conflict: The idempotency key was previously used with different parameters."},
        },
    )
    async def create_batch(
        request: Annotated[CreateBatchRequest, Body(...)],
    ) -> BatchObject:
        return await impl.create_batch(request)

    @router.get(
        "/batches/{batch_id}",
        response_model=BatchObject,
        summary="Retrieve information about a specific batch.",
        description="Retrieve information about a specific batch.",
        responses={
            200: {"description": "The batch object."},
        },
    )
    async def retrieve_batch(
        request: Annotated[RetrieveBatchRequest, Depends(get_retrieve_batch_request)],
    ) -> BatchObject:
        return await impl.retrieve_batch(request)

    @router.post(
        "/batches/{batch_id}/cancel",
        response_model=BatchObject,
        summary="Cancel a batch that is in progress.",
        description="Cancel a batch that is in progress.",
        responses={
            200: {"description": "The updated batch object."},
        },
    )
    async def cancel_batch(
        request: Annotated[CancelBatchRequest, Depends(get_cancel_batch_request)],
    ) -> BatchObject:
        return await impl.cancel_batch(request)

    @router.get(
        "/batches",
        response_model=ListBatchesResponse,
        summary="List all batches for the current user.",
        description="List all batches for the current user.",
        responses={
            200: {"description": "A list of batch objects."},
        },
    )
    async def list_batches(
        request: Annotated[ListBatchesRequest, Depends(get_list_batches_request)],
    ) -> ListBatchesResponse:
        return await impl.list_batches(request)

    return router
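
A hedged end-to-end sketch of the generated router (not from the diff; assumes fastapi and httpx are installed, and that LLAMA_STACK_API_V1 is "v1"). StubBatches is a hypothetical partial implementation that only supports the endpoint exercised:

from fastapi import FastAPI
from fastapi.testclient import TestClient

from llama_stack_api.batches.fastapi_routes import create_router
from llama_stack_api.batches.models import ListBatchesRequest, ListBatchesResponse


class StubBatches:
    # Hypothetical stand-in: only list_batches is implemented; the other
    # routes are registered but would fail if called
    async def list_batches(self, request: ListBatchesRequest) -> ListBatchesResponse:
        return ListBatchesResponse(data=[])


app = FastAPI()
app.include_router(create_router(StubBatches()))

client = TestClient(app)
print(client.get("/v1/batches").json())  # {'object': 'list', 'data': [], ...}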

src/llama_stack_api/batches/models.py (new file, 82 lines)
@@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""Pydantic models for Batches API requests and responses.

This module defines the request and response models for the Batches API
using Pydantic with Field descriptions for OpenAPI schema generation.
"""

from typing import Literal

from pydantic import BaseModel, Field

from llama_stack_api.schema_utils import json_schema_type

try:
    from openai.types import Batch as BatchObject
except ImportError as e:
    raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e


@json_schema_type
class CreateBatchRequest(BaseModel):
    """Request model for creating a batch."""

    input_file_id: str = Field(..., description="The ID of an uploaded file containing requests for the batch.")
    endpoint: str = Field(..., description="The endpoint to be used for all requests in the batch.")
    completion_window: Literal["24h"] = Field(
        ..., description="The time window within which the batch should be processed."
    )
    metadata: dict[str, str] | None = Field(default=None, description="Optional metadata for the batch.")
    idempotency_key: str | None = Field(
        default=None, description="Optional idempotency key. When provided, enables idempotent behavior."
    )


@json_schema_type
class ListBatchesRequest(BaseModel):
    """Request model for listing batches."""

    after: str | None = Field(
        default=None, description="Optional cursor for pagination. Returns batches after this ID."
    )
    limit: int = Field(default=20, description="Maximum number of batches to return. Defaults to 20.")


@json_schema_type
class RetrieveBatchRequest(BaseModel):
    """Request model for retrieving a batch."""

    batch_id: str = Field(..., description="The ID of the batch to retrieve.")


@json_schema_type
class CancelBatchRequest(BaseModel):
    """Request model for canceling a batch."""

    batch_id: str = Field(..., description="The ID of the batch to cancel.")


@json_schema_type
class ListBatchesResponse(BaseModel):
    """Response containing a list of batch objects."""

    object: Literal["list"] = "list"
    data: list[BatchObject] = Field(..., description="List of batch objects")
    first_id: str | None = Field(default=None, description="ID of the first batch in the list")
    last_id: str | None = Field(default=None, description="ID of the last batch in the list")
    has_more: bool = Field(default=False, description="Whether there are more batches available")


__all__ = [
    "CreateBatchRequest",
    "ListBatchesRequest",
    "RetrieveBatchRequest",
    "CancelBatchRequest",
    "ListBatchesResponse",
    "BatchObject",
]
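
Because the request models carry types, defaults, and descriptions, they act as the single source of truth at construction time too; a small sketch (the file ID is a made-up example value):

from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest

req = CreateBatchRequest(
    input_file_id="file-abc123",  # example value, not a real file ID
    endpoint="/v1/chat/completions",
    completion_window="24h",  # Literal["24h"]: any other value fails validation
)
print(req.metadata, req.idempotency_key)  # None None (defaults)
print(ListBatchesRequest().limit)  # 20, taken from the Field default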

@@ -24,6 +24,7 @@ classifiers = [
     "Topic :: Scientific/Engineering :: Information Analysis",
 ]
 dependencies = [
+    "fastapi>=0.115.0,<1.0",
     "pydantic>=2.11.9",
     "jsonschema",
     "opentelemetry-sdk>=1.30.0",

src/llama_stack_api/router_utils.py (new file, 155 lines)
@@ -0,0 +1,155 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""Utilities for creating FastAPI routers with standard error responses.

This module provides standard error response definitions for FastAPI routers.
These responses use OpenAPI $ref references to component responses defined
in the OpenAPI specification.
"""

import inspect
from collections.abc import Callable
from typing import Annotated, Any, TypeVar

from fastapi import Path, Query
from pydantic import BaseModel

standard_responses: dict[int | str, dict[str, Any]] = {
    400: {"$ref": "#/components/responses/BadRequest400"},
    429: {"$ref": "#/components/responses/TooManyRequests429"},
    500: {"$ref": "#/components/responses/InternalServerError500"},
    "default": {"$ref": "#/components/responses/DefaultError"},
}

T = TypeVar("T", bound=BaseModel)


def create_query_dependency[T: BaseModel](model_class: type[T]) -> Callable[..., T]:
    """Create a FastAPI dependency function from a Pydantic model for query parameters.

    FastAPI does not natively support using Pydantic models as query parameters
    without a dependency function. Using a dependency function typically leads to
    duplication: field types, default values, and descriptions must be repeated in
    `Query(...)` annotations even though they already exist in the Pydantic model.

    This function automatically generates a dependency function that extracts query parameters
    from the request and constructs an instance of the Pydantic model. The descriptions and
    defaults are automatically extracted from the model's Field definitions, making the model
    the single source of truth.

    Args:
        model_class: The Pydantic model class to create a dependency for

    Returns:
        A dependency function that can be used with FastAPI's Depends()
    """
    # Build function signature dynamically from model fields
    annotations: dict[str, Any] = {}
    defaults: dict[str, Any] = {}

    for field_name, field_info in model_class.model_fields.items():
        # Extract description from Field
        description = field_info.description

        # Create Query annotation with description from model
        query_annotation = Query(description=description) if description else Query()

        # Create Annotated type with Query
        field_type = field_info.annotation
        annotations[field_name] = Annotated[field_type, query_annotation]

        # Set default value from model
        if field_info.default is not inspect.Parameter.empty:
            defaults[field_name] = field_info.default

    # Create the dependency function dynamically
    def dependency_func(**kwargs: Any) -> T:
        return model_class(**kwargs)

    # Set function signature
    sig_params = []
    for field_name, field_type in annotations.items():
        default = defaults.get(field_name, inspect.Parameter.empty)
        param = inspect.Parameter(
            field_name,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            default=default,
            annotation=field_type,
        )
        sig_params.append(param)

    # These attributes are set dynamically at runtime. While mypy can't verify them statically,
    # they are standard Python function attributes that exist on all callable objects at runtime.
    # Setting them allows FastAPI to properly introspect the function signature for dependency injection.
    dependency_func.__signature__ = inspect.Signature(sig_params)  # type: ignore[attr-defined]
    dependency_func.__annotations__ = annotations  # type: ignore[attr-defined]
    dependency_func.__name__ = f"get_{model_class.__name__.lower()}_request"  # type: ignore[attr-defined]

    return dependency_func


def create_path_dependency[T: BaseModel](model_class: type[T]) -> Callable[..., T]:
    """Create a FastAPI dependency function from a Pydantic model for path parameters.

    FastAPI requires path parameters to be explicitly annotated with `Path()`. When using
    a Pydantic model that contains path parameters, you typically need a dependency function
    that extracts the path parameter and constructs the model. This leads to duplication:
    the parameter name, type, and description must be repeated in `Path(...)` annotations
    even though they already exist in the Pydantic model.

    This function automatically generates a dependency function that extracts path parameters
    from the request and constructs an instance of the Pydantic model. The descriptions are
    automatically extracted from the model's Field definitions, making the model the single
    source of truth.

    Args:
        model_class: The Pydantic model class to create a dependency for. The model should
            have exactly one field that represents the path parameter.

    Returns:
        A dependency function that can be used with FastAPI's Depends()
    """
    # Get the single field from the model (path parameter models typically have one field)
    if len(model_class.model_fields) != 1:
        raise ValueError(
            f"Path parameter model {model_class.__name__} must have exactly one field, "
            f"but has {len(model_class.model_fields)} fields"
        )

    field_name, field_info = next(iter(model_class.model_fields.items()))

    # Extract description from Field
    description = field_info.description

    # Create Path annotation with description from model
    path_annotation = Path(description=description) if description else Path()

    # Create Annotated type with Path
    field_type = field_info.annotation
    annotations: dict[str, Any] = {field_name: Annotated[field_type, path_annotation]}

    # Create the dependency function dynamically
    def dependency_func(**kwargs: Any) -> T:
        return model_class(**kwargs)

    # Set function signature
    param = inspect.Parameter(
        field_name,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        annotation=annotations[field_name],
    )

    # These attributes are set dynamically at runtime. While mypy can't verify them statically,
    # they are standard Python function attributes that exist on all callable objects at runtime.
    # Setting them allows FastAPI to properly introspect the function signature for dependency injection.
    dependency_func.__signature__ = inspect.Signature([param])  # type: ignore[attr-defined]
    dependency_func.__annotations__ = annotations  # type: ignore[attr-defined]
    dependency_func.__name__ = f"get_{model_class.__name__.lower()}_request"  # type: ignore[attr-defined]

    return dependency_func
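
A hedged usage sketch for the generated dependencies (the /demo path and handler are illustrative, not part of the diff): create_query_dependency turns the model into something FastAPI can inject from ?after=...&limit=...:

from typing import Annotated

from fastapi import APIRouter, Depends

from llama_stack_api.batches.models import ListBatchesRequest
from llama_stack_api.router_utils import create_query_dependency

get_list_request = create_query_dependency(ListBatchesRequest)

router = APIRouter()


@router.get("/demo/batches")
async def demo_list(
    request: Annotated[ListBatchesRequest, Depends(get_list_request)],
) -> dict:
    # FastAPI parses the query string into the model via the generated dependency
    return {"after": request.after, "limit": request.limit}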
Loading…
Add table
Add a link
Reference in a new issue