diff --git a/client-sdks/stainless/openapi.yml b/client-sdks/stainless/openapi.yml index be941f652..c477ae32c 100644 --- a/client-sdks/stainless/openapi.yml +++ b/client-sdks/stainless/openapi.yml @@ -48,14 +48,18 @@ paths: anyOf: - type: string - type: 'null' + description: Optional cursor for pagination. Returns batches after this ID. title: After + description: Optional cursor for pagination. Returns batches after this ID. - name: limit in: query required: false schema: type: integer + description: Maximum number of batches to return. Defaults to 20. default: 20 title: Limit + description: Maximum number of batches to return. Defaults to 20. post: responses: '200': @@ -12690,6 +12694,22 @@ components: - query title: VectorStoreSearchRequest type: object + ListBatchesRequest: + description: Request model for listing batches. + properties: + after: + anyOf: + - type: string + - type: 'null' + description: Optional cursor for pagination. Returns batches after this ID. + nullable: true + limit: + default: 20 + description: Maximum number of batches to return. Defaults to 20. + title: Limit + type: integer + title: ListBatchesRequest + type: object DialogType: description: Parameter type for dialog data with semantic output labels. properties: diff --git a/docs/static/deprecated-llama-stack-spec.yaml b/docs/static/deprecated-llama-stack-spec.yaml index 94b1a69a7..2e517da70 100644 --- a/docs/static/deprecated-llama-stack-spec.yaml +++ b/docs/static/deprecated-llama-stack-spec.yaml @@ -9531,6 +9531,22 @@ components: - query title: VectorStoreSearchRequest type: object + ListBatchesRequest: + description: Request model for listing batches. + properties: + after: + anyOf: + - type: string + - type: 'null' + description: Optional cursor for pagination. Returns batches after this ID. + nullable: true + limit: + default: 20 + description: Maximum number of batches to return. Defaults to 20. + title: Limit + type: integer + title: ListBatchesRequest + type: object DialogType: description: Parameter type for dialog data with semantic output labels. properties: diff --git a/docs/static/experimental-llama-stack-spec.yaml b/docs/static/experimental-llama-stack-spec.yaml index dfd354544..21112924c 100644 --- a/docs/static/experimental-llama-stack-spec.yaml +++ b/docs/static/experimental-llama-stack-spec.yaml @@ -8544,6 +8544,22 @@ components: - query title: VectorStoreSearchRequest type: object + ListBatchesRequest: + description: Request model for listing batches. + properties: + after: + anyOf: + - type: string + - type: 'null' + description: Optional cursor for pagination. Returns batches after this ID. + nullable: true + limit: + default: 20 + description: Maximum number of batches to return. Defaults to 20. + title: Limit + type: integer + title: ListBatchesRequest + type: object DialogType: description: Parameter type for dialog data with semantic output labels. properties: diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml index a736fc8f9..5335e69b8 100644 --- a/docs/static/llama-stack-spec.yaml +++ b/docs/static/llama-stack-spec.yaml @@ -46,14 +46,18 @@ paths: anyOf: - type: string - type: 'null' + description: Optional cursor for pagination. Returns batches after this ID. title: After + description: Optional cursor for pagination. Returns batches after this ID. - name: limit in: query required: false schema: type: integer + description: Maximum number of batches to return. Defaults to 20. default: 20 title: Limit + description: Maximum number of batches to return. Defaults to 20. post: responses: '200': @@ -11419,6 +11423,22 @@ components: - query title: VectorStoreSearchRequest type: object + ListBatchesRequest: + description: Request model for listing batches. + properties: + after: + anyOf: + - type: string + - type: 'null' + description: Optional cursor for pagination. Returns batches after this ID. + nullable: true + limit: + default: 20 + description: Maximum number of batches to return. Defaults to 20. + title: Limit + type: integer + title: ListBatchesRequest + type: object DialogType: description: Parameter type for dialog data with semantic output labels. properties: diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index be941f652..c477ae32c 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -48,14 +48,18 @@ paths: anyOf: - type: string - type: 'null' + description: Optional cursor for pagination. Returns batches after this ID. title: After + description: Optional cursor for pagination. Returns batches after this ID. - name: limit in: query required: false schema: type: integer + description: Maximum number of batches to return. Defaults to 20. default: 20 title: Limit + description: Maximum number of batches to return. Defaults to 20. post: responses: '200': @@ -12690,6 +12694,22 @@ components: - query title: VectorStoreSearchRequest type: object + ListBatchesRequest: + description: Request model for listing batches. + properties: + after: + anyOf: + - type: string + - type: 'null' + description: Optional cursor for pagination. Returns batches after this ID. + nullable: true + limit: + default: 20 + description: Maximum number of batches to return. Defaults to 20. + title: Limit + type: integer + title: ListBatchesRequest + type: object DialogType: description: Parameter type for dialog data with semantic output labels. properties: diff --git a/src/llama_stack/providers/inline/batches/reference/batches.py b/src/llama_stack/providers/inline/batches/reference/batches.py index 73727799d..aaa105f28 100644 --- a/src/llama_stack/providers/inline/batches/reference/batches.py +++ b/src/llama_stack/providers/inline/batches/reference/batches.py @@ -11,7 +11,7 @@ import json import time import uuid from io import BytesIO -from typing import Any, Literal +from typing import Any from openai.types.batch import BatchError, Errors from pydantic import BaseModel @@ -38,6 +38,7 @@ from llama_stack_api import ( OpenAIUserMessageParam, ResourceNotFoundError, ) +from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest from .config import ReferenceBatchesImplConfig @@ -140,11 +141,7 @@ class ReferenceBatchesImpl(Batches): # TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions async def create_batch( self, - input_file_id: str, - endpoint: str, - completion_window: Literal["24h"], - metadata: dict[str, str] | None = None, - idempotency_key: str | None = None, + request: CreateBatchRequest, ) -> BatchObject: """ Create a new batch for processing multiple API requests. @@ -185,14 +182,14 @@ class ReferenceBatchesImpl(Batches): # TODO: set expiration time for garbage collection - if endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]: + if request.endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]: raise ValueError( - f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint", + f"Invalid endpoint: {request.endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint", ) - if completion_window != "24h": + if request.completion_window != "24h": raise ValueError( - f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window", + f"Invalid completion_window: {request.completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window", ) batch_id = f"batch_{uuid.uuid4().hex[:16]}" @@ -200,8 +197,8 @@ class ReferenceBatchesImpl(Batches): # For idempotent requests, use the idempotency key for the batch ID # This ensures the same key always maps to the same batch ID, # allowing us to detect parameter conflicts - if idempotency_key is not None: - hash_input = idempotency_key.encode("utf-8") + if request.idempotency_key is not None: + hash_input = request.idempotency_key.encode("utf-8") hash_digest = hashlib.sha256(hash_input).hexdigest()[:24] batch_id = f"batch_{hash_digest}" @@ -209,13 +206,13 @@ class ReferenceBatchesImpl(Batches): existing_batch = await self.retrieve_batch(batch_id) if ( - existing_batch.input_file_id != input_file_id - or existing_batch.endpoint != endpoint - or existing_batch.completion_window != completion_window - or existing_batch.metadata != metadata + existing_batch.input_file_id != request.input_file_id + or existing_batch.endpoint != request.endpoint + or existing_batch.completion_window != request.completion_window + or existing_batch.metadata != request.metadata ): raise ConflictError( - f"Idempotency key '{idempotency_key}' was previously used with different parameters. " + f"Idempotency key '{request.idempotency_key}' was previously used with different parameters. " "Either use a new idempotency key or ensure all parameters match the original request." ) @@ -230,12 +227,12 @@ class ReferenceBatchesImpl(Batches): batch = BatchObject( id=batch_id, object="batch", - endpoint=endpoint, - input_file_id=input_file_id, - completion_window=completion_window, + endpoint=request.endpoint, + input_file_id=request.input_file_id, + completion_window=request.completion_window, status="validating", created_at=current_time, - metadata=metadata, + metadata=request.metadata, ) await self.kvstore.set(f"batch:{batch_id}", batch.to_json()) @@ -267,8 +264,7 @@ class ReferenceBatchesImpl(Batches): async def list_batches( self, - after: str | None = None, - limit: int = 20, + request: ListBatchesRequest, ) -> ListBatchesResponse: """ List all batches, eventually only for the current user. @@ -285,14 +281,14 @@ class ReferenceBatchesImpl(Batches): batches.sort(key=lambda b: b.created_at, reverse=True) start_idx = 0 - if after: + if request.after: for i, batch in enumerate(batches): - if batch.id == after: + if batch.id == request.after: start_idx = i + 1 break - page_batches = batches[start_idx : start_idx + limit] - has_more = (start_idx + limit) < len(batches) + page_batches = batches[start_idx : start_idx + request.limit] + has_more = (start_idx + request.limit) < len(batches) first_id = page_batches[0].id if page_batches else None last_id = page_batches[-1].id if page_batches else None diff --git a/src/llama_stack_api/batches/__init__.py b/src/llama_stack_api/batches/__init__.py index 6f778de8e..3d301598c 100644 --- a/src/llama_stack_api/batches/__init__.py +++ b/src/llama_stack_api/batches/__init__.py @@ -11,7 +11,7 @@ Pydantic models are defined in llama_stack_api.batches.models. The FastAPI router is defined in llama_stack_api.batches.routes. """ -from typing import Literal, Protocol, runtime_checkable +from typing import Protocol, runtime_checkable try: from openai.types import Batch as BatchObject @@ -19,7 +19,7 @@ except ImportError as e: raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e # Import models for re-export -from llama_stack_api.batches.models import ListBatchesResponse +from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest, ListBatchesResponse @runtime_checkable @@ -39,11 +39,7 @@ class Batches(Protocol): async def create_batch( self, - input_file_id: str, - endpoint: str, - completion_window: Literal["24h"], - metadata: dict[str, str] | None = None, - idempotency_key: str | None = None, + request: CreateBatchRequest, ) -> BatchObject: ... async def retrieve_batch(self, batch_id: str) -> BatchObject: ... @@ -52,9 +48,8 @@ class Batches(Protocol): async def list_batches( self, - after: str | None = None, - limit: int = 20, + request: ListBatchesRequest, ) -> ListBatchesResponse: ... -__all__ = ["Batches", "BatchObject", "ListBatchesResponse"] +__all__ = ["Batches", "BatchObject", "CreateBatchRequest", "ListBatchesRequest", "ListBatchesResponse"] diff --git a/src/llama_stack_api/batches/models.py b/src/llama_stack_api/batches/models.py index fe449280d..bb6d7e3d0 100644 --- a/src/llama_stack_api/batches/models.py +++ b/src/llama_stack_api/batches/models.py @@ -37,6 +37,16 @@ class CreateBatchRequest(BaseModel): ) +@json_schema_type +class ListBatchesRequest(BaseModel): + """Request model for listing batches.""" + + after: str | None = Field( + default=None, description="Optional cursor for pagination. Returns batches after this ID." + ) + limit: int = Field(default=20, description="Maximum number of batches to return. Defaults to 20.") + + @json_schema_type class ListBatchesResponse(BaseModel): """Response containing a list of batch objects.""" @@ -48,4 +58,4 @@ class ListBatchesResponse(BaseModel): has_more: bool = Field(default=False, description="Whether there are more batches available") -__all__ = ["CreateBatchRequest", "ListBatchesResponse", "BatchObject"] +__all__ = ["CreateBatchRequest", "ListBatchesRequest", "ListBatchesResponse", "BatchObject"] diff --git a/src/llama_stack_api/batches/routes.py b/src/llama_stack_api/batches/routes.py index e8b6aaf41..adbc10be5 100644 --- a/src/llama_stack_api/batches/routes.py +++ b/src/llama_stack_api/batches/routes.py @@ -14,10 +14,10 @@ all API-related code together. from collections.abc import Callable from typing import Annotated -from fastapi import APIRouter, Body, Depends +from fastapi import APIRouter, Body, Depends, Query from llama_stack_api.batches import Batches, BatchObject, ListBatchesResponse -from llama_stack_api.batches.models import CreateBatchRequest +from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest from llama_stack_api.datatypes import Api from llama_stack_api.router_utils import standard_responses from llama_stack_api.version import LLAMA_STACK_API_V1 @@ -56,13 +56,7 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter: request: Annotated[CreateBatchRequest, Body(...)], svc: Annotated[Batches, Depends(get_batch_service)], ) -> BatchObject: - return await svc.create_batch( - input_file_id=request.input_file_id, - endpoint=request.endpoint, - completion_window=request.completion_window, - metadata=request.metadata, - idempotency_key=request.idempotency_key, - ) + return await svc.create_batch(request) @router.get( "/batches/{batch_id}", @@ -94,6 +88,15 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter: ) -> BatchObject: return await svc.cancel_batch(batch_id) + def get_list_batches_request( + after: Annotated[ + str | None, Query(description="Optional cursor for pagination. Returns batches after this ID.") + ] = None, + limit: Annotated[int, Query(description="Maximum number of batches to return. Defaults to 20.")] = 20, + ) -> ListBatchesRequest: + """Dependency function to create ListBatchesRequest from query parameters.""" + return ListBatchesRequest(after=after, limit=limit) + @router.get( "/batches", response_model=ListBatchesResponse, @@ -104,10 +107,9 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter: }, ) async def list_batches( + request: Annotated[ListBatchesRequest, Depends(get_list_batches_request)], svc: Annotated[Batches, Depends(get_batch_service)], - after: str | None = None, - limit: int = 20, ) -> ListBatchesResponse: - return await svc.list_batches(after=after, limit=limit) + return await svc.list_batches(request) return router