chore: same as previous commit but for more fields (retrieve_batch and cancel_batch now take RetrieveBatchRequest / CancelBatchRequest models)

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-11-20 16:12:52 +01:00
parent 30cab02083
commit 20030429e7
No known key found for this signature in database
9 changed files with 210 additions and 25 deletions

View file

@ -125,7 +125,9 @@ paths:
required: true required: true
schema: schema:
type: string type: string
description: The ID of the batch to retrieve.
title: Batch Id title: Batch Id
description: The ID of the batch to retrieve.
/v1/batches/{batch_id}/cancel: /v1/batches/{batch_id}/cancel:
post: post:
responses: responses:
@ -158,7 +160,9 @@ paths:
required: true required: true
schema: schema:
type: string type: string
description: The ID of the batch to cancel.
title: Batch Id title: Batch Id
description: The ID of the batch to cancel.
/v1/chat/completions: /v1/chat/completions:
get: get:
responses: responses:
@ -12710,6 +12714,28 @@ components:
type: integer type: integer
title: ListBatchesRequest title: ListBatchesRequest
type: object type: object
RetrieveBatchRequest:
description: Request model for retrieving a batch.
properties:
batch_id:
description: The ID of the batch to retrieve.
title: Batch Id
type: string
required:
- batch_id
title: RetrieveBatchRequest
type: object
CancelBatchRequest:
description: Request model for canceling a batch.
properties:
batch_id:
description: The ID of the batch to cancel.
title: Batch Id
type: string
required:
- batch_id
title: CancelBatchRequest
type: object
DialogType: DialogType:
description: Parameter type for dialog data with semantic output labels. description: Parameter type for dialog data with semantic output labels.
properties: properties:

View file

@ -9547,6 +9547,28 @@ components:
type: integer type: integer
title: ListBatchesRequest title: ListBatchesRequest
type: object type: object
RetrieveBatchRequest:
description: Request model for retrieving a batch.
properties:
batch_id:
description: The ID of the batch to retrieve.
title: Batch Id
type: string
required:
- batch_id
title: RetrieveBatchRequest
type: object
CancelBatchRequest:
description: Request model for canceling a batch.
properties:
batch_id:
description: The ID of the batch to cancel.
title: Batch Id
type: string
required:
- batch_id
title: CancelBatchRequest
type: object
DialogType: DialogType:
description: Parameter type for dialog data with semantic output labels. description: Parameter type for dialog data with semantic output labels.
properties: properties:

View file

@ -8560,6 +8560,28 @@ components:
type: integer type: integer
title: ListBatchesRequest title: ListBatchesRequest
type: object type: object
RetrieveBatchRequest:
description: Request model for retrieving a batch.
properties:
batch_id:
description: The ID of the batch to retrieve.
title: Batch Id
type: string
required:
- batch_id
title: RetrieveBatchRequest
type: object
CancelBatchRequest:
description: Request model for canceling a batch.
properties:
batch_id:
description: The ID of the batch to cancel.
title: Batch Id
type: string
required:
- batch_id
title: CancelBatchRequest
type: object
DialogType: DialogType:
description: Parameter type for dialog data with semantic output labels. description: Parameter type for dialog data with semantic output labels.
properties: properties:

View file

@ -123,7 +123,9 @@ paths:
required: true required: true
schema: schema:
type: string type: string
description: The ID of the batch to retrieve.
title: Batch Id title: Batch Id
description: The ID of the batch to retrieve.
/v1/batches/{batch_id}/cancel: /v1/batches/{batch_id}/cancel:
post: post:
responses: responses:
@ -156,7 +158,9 @@ paths:
required: true required: true
schema: schema:
type: string type: string
description: The ID of the batch to cancel.
title: Batch Id title: Batch Id
description: The ID of the batch to cancel.
/v1/chat/completions: /v1/chat/completions:
get: get:
responses: responses:
@ -11439,6 +11443,28 @@ components:
type: integer type: integer
title: ListBatchesRequest title: ListBatchesRequest
type: object type: object
RetrieveBatchRequest:
description: Request model for retrieving a batch.
properties:
batch_id:
description: The ID of the batch to retrieve.
title: Batch Id
type: string
required:
- batch_id
title: RetrieveBatchRequest
type: object
CancelBatchRequest:
description: Request model for canceling a batch.
properties:
batch_id:
description: The ID of the batch to cancel.
title: Batch Id
type: string
required:
- batch_id
title: CancelBatchRequest
type: object
DialogType: DialogType:
description: Parameter type for dialog data with semantic output labels. description: Parameter type for dialog data with semantic output labels.
properties: properties:

View file

@ -125,7 +125,9 @@ paths:
required: true required: true
schema: schema:
type: string type: string
description: The ID of the batch to retrieve.
title: Batch Id title: Batch Id
description: The ID of the batch to retrieve.
/v1/batches/{batch_id}/cancel: /v1/batches/{batch_id}/cancel:
post: post:
responses: responses:
@ -158,7 +160,9 @@ paths:
required: true required: true
schema: schema:
type: string type: string
description: The ID of the batch to cancel.
title: Batch Id title: Batch Id
description: The ID of the batch to cancel.
/v1/chat/completions: /v1/chat/completions:
get: get:
responses: responses:
@ -12710,6 +12714,28 @@ components:
type: integer type: integer
title: ListBatchesRequest title: ListBatchesRequest
type: object type: object
RetrieveBatchRequest:
description: Request model for retrieving a batch.
properties:
batch_id:
description: The ID of the batch to retrieve.
title: Batch Id
type: string
required:
- batch_id
title: RetrieveBatchRequest
type: object
CancelBatchRequest:
description: Request model for canceling a batch.
properties:
batch_id:
description: The ID of the batch to cancel.
title: Batch Id
type: string
required:
- batch_id
title: CancelBatchRequest
type: object
DialogType: DialogType:
description: Parameter type for dialog data with semantic output labels. description: Parameter type for dialog data with semantic output labels.
properties: properties:

View file

@ -38,7 +38,12 @@ from llama_stack_api import (
OpenAIUserMessageParam, OpenAIUserMessageParam,
ResourceNotFoundError, ResourceNotFoundError,
) )
from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest from llama_stack_api.batches.models import (
CancelBatchRequest,
CreateBatchRequest,
ListBatchesRequest,
RetrieveBatchRequest,
)
from .config import ReferenceBatchesImplConfig from .config import ReferenceBatchesImplConfig
@ -203,7 +208,7 @@ class ReferenceBatchesImpl(Batches):
batch_id = f"batch_{hash_digest}" batch_id = f"batch_{hash_digest}"
try: try:
existing_batch = await self.retrieve_batch(batch_id) existing_batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=batch_id))
if ( if (
existing_batch.input_file_id != request.input_file_id existing_batch.input_file_id != request.input_file_id
@ -244,23 +249,23 @@ class ReferenceBatchesImpl(Batches):
return batch return batch
async def cancel_batch(self, batch_id: str) -> BatchObject: async def cancel_batch(self, request: CancelBatchRequest) -> BatchObject:
"""Cancel a batch that is in progress.""" """Cancel a batch that is in progress."""
batch = await self.retrieve_batch(batch_id) batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=request.batch_id))
if batch.status in ["cancelled", "cancelling"]: if batch.status in ["cancelled", "cancelling"]:
return batch return batch
if batch.status in ["completed", "failed", "expired"]: if batch.status in ["completed", "failed", "expired"]:
raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'") raise ConflictError(f"Cannot cancel batch '{request.batch_id}' with status '{batch.status}'")
await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time())) await self._update_batch(request.batch_id, status="cancelling", cancelling_at=int(time.time()))
if batch_id in self._processing_tasks: if request.batch_id in self._processing_tasks:
self._processing_tasks[batch_id].cancel() self._processing_tasks[request.batch_id].cancel()
# note: task removal and status="cancelled" handled in finally block of _process_batch # note: task removal and status="cancelled" handled in finally block of _process_batch
return await self.retrieve_batch(batch_id) return await self.retrieve_batch(RetrieveBatchRequest(batch_id=request.batch_id))
async def list_batches( async def list_batches(
self, self,
@ -300,11 +305,11 @@ class ReferenceBatchesImpl(Batches):
has_more=has_more, has_more=has_more,
) )
async def retrieve_batch(self, batch_id: str) -> BatchObject: async def retrieve_batch(self, request: RetrieveBatchRequest) -> BatchObject:
"""Retrieve information about a specific batch.""" """Retrieve information about a specific batch."""
batch_data = await self.kvstore.get(f"batch:{batch_id}") batch_data = await self.kvstore.get(f"batch:{request.batch_id}")
if not batch_data: if not batch_data:
raise ResourceNotFoundError(batch_id, "Batch", "batches.list()") raise ResourceNotFoundError(request.batch_id, "Batch", "batches.list()")
return BatchObject.model_validate_json(batch_data) return BatchObject.model_validate_json(batch_data)
@ -312,7 +317,7 @@ class ReferenceBatchesImpl(Batches):
"""Update batch fields in kvstore.""" """Update batch fields in kvstore."""
async with self._update_batch_lock: async with self._update_batch_lock:
try: try:
batch = await self.retrieve_batch(batch_id) batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=batch_id))
# batch processing is async. once cancelling, only allow "cancelled" status updates # batch processing is async. once cancelling, only allow "cancelled" status updates
if batch.status == "cancelling" and updates.get("status") != "cancelled": if batch.status == "cancelling" and updates.get("status") != "cancelled":
@ -532,7 +537,7 @@ class ReferenceBatchesImpl(Batches):
async def _process_batch_impl(self, batch_id: str) -> None: async def _process_batch_impl(self, batch_id: str) -> None:
"""Implementation of batch processing logic.""" """Implementation of batch processing logic."""
errors: list[BatchError] = [] errors: list[BatchError] = []
batch = await self.retrieve_batch(batch_id) batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=batch_id))
errors, requests = await self._validate_input(batch) errors, requests = await self._validate_input(batch)
if errors: if errors:

View file

@ -19,7 +19,13 @@ except ImportError as e:
raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e
# Import models for re-export # Import models for re-export
from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest, ListBatchesResponse from llama_stack_api.batches.models import (
CancelBatchRequest,
CreateBatchRequest,
ListBatchesRequest,
ListBatchesResponse,
RetrieveBatchRequest,
)
@runtime_checkable @runtime_checkable
@ -42,9 +48,15 @@ class Batches(Protocol):
request: CreateBatchRequest, request: CreateBatchRequest,
) -> BatchObject: ... ) -> BatchObject: ...
async def retrieve_batch(self, batch_id: str) -> BatchObject: ... async def retrieve_batch(
self,
request: RetrieveBatchRequest,
) -> BatchObject: ...
async def cancel_batch(self, batch_id: str) -> BatchObject: ... async def cancel_batch(
self,
request: CancelBatchRequest,
) -> BatchObject: ...
async def list_batches( async def list_batches(
self, self,
@ -52,4 +64,12 @@ class Batches(Protocol):
) -> ListBatchesResponse: ... ) -> ListBatchesResponse: ...
__all__ = ["Batches", "BatchObject", "CreateBatchRequest", "ListBatchesRequest", "ListBatchesResponse"] __all__ = [
"Batches",
"BatchObject",
"CreateBatchRequest",
"ListBatchesRequest",
"RetrieveBatchRequest",
"CancelBatchRequest",
"ListBatchesResponse",
]

View file

@ -47,6 +47,20 @@ class ListBatchesRequest(BaseModel):
limit: int = Field(default=20, description="Maximum number of batches to return. Defaults to 20.") limit: int = Field(default=20, description="Maximum number of batches to return. Defaults to 20.")
@json_schema_type
class RetrieveBatchRequest(BaseModel):
    """Request model for retrieving a batch."""

    # Required field: `Field(...)` supplies no default, so a batch_id must be given.
    # NOTE(review): the class docstring and this description string appear verbatim
    # in the OpenAPI spec files touched by the same commit -- presumably generated
    # from this model; confirm and regenerate the specs if either text changes.
    batch_id: str = Field(..., description="The ID of the batch to retrieve.")
@json_schema_type
class CancelBatchRequest(BaseModel):
    """Request model for canceling a batch."""

    # Required field: `Field(...)` supplies no default, so a batch_id must be given.
    # NOTE(review): the class docstring and this description string appear verbatim
    # in the OpenAPI spec files touched by the same commit -- presumably generated
    # from this model; confirm and regenerate the specs if either text changes.
    batch_id: str = Field(..., description="The ID of the batch to cancel.")
@json_schema_type @json_schema_type
class ListBatchesResponse(BaseModel): class ListBatchesResponse(BaseModel):
"""Response containing a list of batch objects.""" """Response containing a list of batch objects."""
@ -58,4 +72,11 @@ class ListBatchesResponse(BaseModel):
has_more: bool = Field(default=False, description="Whether there are more batches available") has_more: bool = Field(default=False, description="Whether there are more batches available")
__all__ = ["CreateBatchRequest", "ListBatchesRequest", "ListBatchesResponse", "BatchObject"] __all__ = [
"CreateBatchRequest",
"ListBatchesRequest",
"RetrieveBatchRequest",
"CancelBatchRequest",
"ListBatchesResponse",
"BatchObject",
]

View file

@ -14,10 +14,15 @@ all API-related code together.
from collections.abc import Callable from collections.abc import Callable
from typing import Annotated from typing import Annotated
from fastapi import APIRouter, Body, Depends, Query from fastapi import APIRouter, Body, Depends, Path, Query
from llama_stack_api.batches import Batches, BatchObject, ListBatchesResponse from llama_stack_api.batches import Batches, BatchObject, ListBatchesResponse
from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest from llama_stack_api.batches.models import (
CancelBatchRequest,
CreateBatchRequest,
ListBatchesRequest,
RetrieveBatchRequest,
)
from llama_stack_api.datatypes import Api from llama_stack_api.datatypes import Api
from llama_stack_api.router_utils import standard_responses from llama_stack_api.router_utils import standard_responses
from llama_stack_api.version import LLAMA_STACK_API_V1 from llama_stack_api.version import LLAMA_STACK_API_V1
@ -58,6 +63,12 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
) -> BatchObject: ) -> BatchObject:
return await svc.create_batch(request) return await svc.create_batch(request)
def get_retrieve_batch_request(
    batch_id: Annotated[str, Path(description="The ID of the batch to retrieve.")],
) -> RetrieveBatchRequest:
    """Wrap the ``batch_id`` path parameter in a ``RetrieveBatchRequest``.

    Intended to be used through ``Depends(...)`` so the route handler receives
    a ready-made request model instead of a bare string.
    """
    retrieve_request = RetrieveBatchRequest(batch_id=batch_id)
    return retrieve_request
@router.get( @router.get(
"/batches/{batch_id}", "/batches/{batch_id}",
response_model=BatchObject, response_model=BatchObject,
@ -68,10 +79,16 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
}, },
) )
async def retrieve_batch( async def retrieve_batch(
batch_id: str, request: Annotated[RetrieveBatchRequest, Depends(get_retrieve_batch_request)],
svc: Annotated[Batches, Depends(get_batch_service)], svc: Annotated[Batches, Depends(get_batch_service)],
) -> BatchObject: ) -> BatchObject:
return await svc.retrieve_batch(batch_id) return await svc.retrieve_batch(request)
def get_cancel_batch_request(
    batch_id: Annotated[str, Path(description="The ID of the batch to cancel.")],
) -> CancelBatchRequest:
    """Wrap the ``batch_id`` path parameter in a ``CancelBatchRequest``.

    Intended to be used through ``Depends(...)`` so the route handler receives
    a ready-made request model instead of a bare string.
    """
    cancel_request = CancelBatchRequest(batch_id=batch_id)
    return cancel_request
@router.post( @router.post(
"/batches/{batch_id}/cancel", "/batches/{batch_id}/cancel",
@ -83,10 +100,10 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
}, },
) )
async def cancel_batch( async def cancel_batch(
batch_id: str, request: Annotated[CancelBatchRequest, Depends(get_cancel_batch_request)],
svc: Annotated[Batches, Depends(get_batch_service)], svc: Annotated[Batches, Depends(get_batch_service)],
) -> BatchObject: ) -> BatchObject:
return await svc.cancel_batch(batch_id) return await svc.cancel_batch(request)
def get_list_batches_request( def get_list_batches_request(
after: Annotated[ after: Annotated[