chore: same as previous commit but for more fields

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-11-20 16:12:52 +01:00
parent 30cab02083
commit 20030429e7
No known key found for this signature in database
9 changed files with 210 additions and 25 deletions

View file

@@ -125,7 +125,9 @@ paths:
required: true
schema:
type: string
description: The ID of the batch to retrieve.
title: Batch Id
description: The ID of the batch to retrieve.
/v1/batches/{batch_id}/cancel:
post:
responses:
@@ -158,7 +160,9 @@ paths:
required: true
schema:
type: string
description: The ID of the batch to cancel.
title: Batch Id
description: The ID of the batch to cancel.
/v1/chat/completions:
get:
responses:
@@ -12710,6 +12714,28 @@ components:
type: integer
title: ListBatchesRequest
type: object
RetrieveBatchRequest:
description: Request model for retrieving a batch.
properties:
batch_id:
description: The ID of the batch to retrieve.
title: Batch Id
type: string
required:
- batch_id
title: RetrieveBatchRequest
type: object
CancelBatchRequest:
description: Request model for canceling a batch.
properties:
batch_id:
description: The ID of the batch to cancel.
title: Batch Id
type: string
required:
- batch_id
title: CancelBatchRequest
type: object
DialogType:
description: Parameter type for dialog data with semantic output labels.
properties:

View file

@@ -9547,6 +9547,28 @@ components:
type: integer
title: ListBatchesRequest
type: object
RetrieveBatchRequest:
description: Request model for retrieving a batch.
properties:
batch_id:
description: The ID of the batch to retrieve.
title: Batch Id
type: string
required:
- batch_id
title: RetrieveBatchRequest
type: object
CancelBatchRequest:
description: Request model for canceling a batch.
properties:
batch_id:
description: The ID of the batch to cancel.
title: Batch Id
type: string
required:
- batch_id
title: CancelBatchRequest
type: object
DialogType:
description: Parameter type for dialog data with semantic output labels.
properties:

View file

@@ -8560,6 +8560,28 @@ components:
type: integer
title: ListBatchesRequest
type: object
RetrieveBatchRequest:
description: Request model for retrieving a batch.
properties:
batch_id:
description: The ID of the batch to retrieve.
title: Batch Id
type: string
required:
- batch_id
title: RetrieveBatchRequest
type: object
CancelBatchRequest:
description: Request model for canceling a batch.
properties:
batch_id:
description: The ID of the batch to cancel.
title: Batch Id
type: string
required:
- batch_id
title: CancelBatchRequest
type: object
DialogType:
description: Parameter type for dialog data with semantic output labels.
properties:

View file

@@ -123,7 +123,9 @@ paths:
required: true
schema:
type: string
description: The ID of the batch to retrieve.
title: Batch Id
description: The ID of the batch to retrieve.
/v1/batches/{batch_id}/cancel:
post:
responses:
@@ -156,7 +158,9 @@ paths:
required: true
schema:
type: string
description: The ID of the batch to cancel.
title: Batch Id
description: The ID of the batch to cancel.
/v1/chat/completions:
get:
responses:
@@ -11439,6 +11443,28 @@ components:
type: integer
title: ListBatchesRequest
type: object
RetrieveBatchRequest:
description: Request model for retrieving a batch.
properties:
batch_id:
description: The ID of the batch to retrieve.
title: Batch Id
type: string
required:
- batch_id
title: RetrieveBatchRequest
type: object
CancelBatchRequest:
description: Request model for canceling a batch.
properties:
batch_id:
description: The ID of the batch to cancel.
title: Batch Id
type: string
required:
- batch_id
title: CancelBatchRequest
type: object
DialogType:
description: Parameter type for dialog data with semantic output labels.
properties:

View file

@@ -125,7 +125,9 @@ paths:
required: true
schema:
type: string
description: The ID of the batch to retrieve.
title: Batch Id
description: The ID of the batch to retrieve.
/v1/batches/{batch_id}/cancel:
post:
responses:
@@ -158,7 +160,9 @@ paths:
required: true
schema:
type: string
description: The ID of the batch to cancel.
title: Batch Id
description: The ID of the batch to cancel.
/v1/chat/completions:
get:
responses:
@@ -12710,6 +12714,28 @@ components:
type: integer
title: ListBatchesRequest
type: object
RetrieveBatchRequest:
description: Request model for retrieving a batch.
properties:
batch_id:
description: The ID of the batch to retrieve.
title: Batch Id
type: string
required:
- batch_id
title: RetrieveBatchRequest
type: object
CancelBatchRequest:
description: Request model for canceling a batch.
properties:
batch_id:
description: The ID of the batch to cancel.
title: Batch Id
type: string
required:
- batch_id
title: CancelBatchRequest
type: object
DialogType:
description: Parameter type for dialog data with semantic output labels.
properties:

View file

@@ -38,7 +38,12 @@ from llama_stack_api import (
OpenAIUserMessageParam,
ResourceNotFoundError,
)
from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest
from llama_stack_api.batches.models import (
CancelBatchRequest,
CreateBatchRequest,
ListBatchesRequest,
RetrieveBatchRequest,
)
from .config import ReferenceBatchesImplConfig
@@ -203,7 +208,7 @@ class ReferenceBatchesImpl(Batches):
batch_id = f"batch_{hash_digest}"
try:
existing_batch = await self.retrieve_batch(batch_id)
existing_batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=batch_id))
if (
existing_batch.input_file_id != request.input_file_id
@@ -244,23 +249,23 @@
return batch
async def cancel_batch(self, batch_id: str) -> BatchObject:
async def cancel_batch(self, request: CancelBatchRequest) -> BatchObject:
"""Cancel a batch that is in progress."""
batch = await self.retrieve_batch(batch_id)
batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=request.batch_id))
if batch.status in ["cancelled", "cancelling"]:
return batch
if batch.status in ["completed", "failed", "expired"]:
raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'")
raise ConflictError(f"Cannot cancel batch '{request.batch_id}' with status '{batch.status}'")
await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time()))
await self._update_batch(request.batch_id, status="cancelling", cancelling_at=int(time.time()))
if batch_id in self._processing_tasks:
self._processing_tasks[batch_id].cancel()
if request.batch_id in self._processing_tasks:
self._processing_tasks[request.batch_id].cancel()
# note: task removal and status="cancelled" handled in finally block of _process_batch
return await self.retrieve_batch(batch_id)
return await self.retrieve_batch(RetrieveBatchRequest(batch_id=request.batch_id))
async def list_batches(
self,
@@ -300,11 +305,11 @@
has_more=has_more,
)
async def retrieve_batch(self, batch_id: str) -> BatchObject:
async def retrieve_batch(self, request: RetrieveBatchRequest) -> BatchObject:
"""Retrieve information about a specific batch."""
batch_data = await self.kvstore.get(f"batch:{batch_id}")
batch_data = await self.kvstore.get(f"batch:{request.batch_id}")
if not batch_data:
raise ResourceNotFoundError(batch_id, "Batch", "batches.list()")
raise ResourceNotFoundError(request.batch_id, "Batch", "batches.list()")
return BatchObject.model_validate_json(batch_data)
@@ -312,7 +317,7 @@
"""Update batch fields in kvstore."""
async with self._update_batch_lock:
try:
batch = await self.retrieve_batch(batch_id)
batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=batch_id))
# batch processing is async. once cancelling, only allow "cancelled" status updates
if batch.status == "cancelling" and updates.get("status") != "cancelled":
@@ -532,7 +537,7 @@
async def _process_batch_impl(self, batch_id: str) -> None:
"""Implementation of batch processing logic."""
errors: list[BatchError] = []
batch = await self.retrieve_batch(batch_id)
batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=batch_id))
errors, requests = await self._validate_input(batch)
if errors:

View file

@@ -19,7 +19,13 @@ except ImportError as e:
raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e
# Import models for re-export
from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest, ListBatchesResponse
from llama_stack_api.batches.models import (
CancelBatchRequest,
CreateBatchRequest,
ListBatchesRequest,
ListBatchesResponse,
RetrieveBatchRequest,
)
@runtime_checkable
@@ -42,9 +48,15 @@ class Batches(Protocol):
request: CreateBatchRequest,
) -> BatchObject: ...
async def retrieve_batch(self, batch_id: str) -> BatchObject: ...
async def retrieve_batch(
self,
request: RetrieveBatchRequest,
) -> BatchObject: ...
async def cancel_batch(self, batch_id: str) -> BatchObject: ...
async def cancel_batch(
self,
request: CancelBatchRequest,
) -> BatchObject: ...
async def list_batches(
self,
@@ -52,4 +64,12 @@
) -> ListBatchesResponse: ...
__all__ = ["Batches", "BatchObject", "CreateBatchRequest", "ListBatchesRequest", "ListBatchesResponse"]
__all__ = [
"Batches",
"BatchObject",
"CreateBatchRequest",
"ListBatchesRequest",
"RetrieveBatchRequest",
"CancelBatchRequest",
"ListBatchesResponse",
]

View file

@@ -47,6 +47,20 @@ class ListBatchesRequest(BaseModel):
limit: int = Field(default=20, description="Maximum number of batches to return. Defaults to 20.")
@json_schema_type
class RetrieveBatchRequest(BaseModel):
"""Request model for retrieving a batch."""
batch_id: str = Field(..., description="The ID of the batch to retrieve.")
@json_schema_type
class CancelBatchRequest(BaseModel):
"""Request model for canceling a batch."""
batch_id: str = Field(..., description="The ID of the batch to cancel.")
@json_schema_type
class ListBatchesResponse(BaseModel):
"""Response containing a list of batch objects."""
@@ -58,4 +72,11 @@ class ListBatchesResponse(BaseModel):
has_more: bool = Field(default=False, description="Whether there are more batches available")
__all__ = ["CreateBatchRequest", "ListBatchesRequest", "ListBatchesResponse", "BatchObject"]
__all__ = [
"CreateBatchRequest",
"ListBatchesRequest",
"RetrieveBatchRequest",
"CancelBatchRequest",
"ListBatchesResponse",
"BatchObject",
]

View file

@@ -14,10 +14,15 @@ all API-related code together.
from collections.abc import Callable
from typing import Annotated
from fastapi import APIRouter, Body, Depends, Query
from fastapi import APIRouter, Body, Depends, Path, Query
from llama_stack_api.batches import Batches, BatchObject, ListBatchesResponse
from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest
from llama_stack_api.batches.models import (
CancelBatchRequest,
CreateBatchRequest,
ListBatchesRequest,
RetrieveBatchRequest,
)
from llama_stack_api.datatypes import Api
from llama_stack_api.router_utils import standard_responses
from llama_stack_api.version import LLAMA_STACK_API_V1
@@ -58,6 +63,12 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
) -> BatchObject:
return await svc.create_batch(request)
def get_retrieve_batch_request(
batch_id: Annotated[str, Path(description="The ID of the batch to retrieve.")],
) -> RetrieveBatchRequest:
"""Dependency function to create RetrieveBatchRequest from path parameter."""
return RetrieveBatchRequest(batch_id=batch_id)
@router.get(
"/batches/{batch_id}",
response_model=BatchObject,
@@ -68,10 +79,16 @@
},
)
async def retrieve_batch(
batch_id: str,
request: Annotated[RetrieveBatchRequest, Depends(get_retrieve_batch_request)],
svc: Annotated[Batches, Depends(get_batch_service)],
) -> BatchObject:
return await svc.retrieve_batch(batch_id)
return await svc.retrieve_batch(request)
def get_cancel_batch_request(
batch_id: Annotated[str, Path(description="The ID of the batch to cancel.")],
) -> CancelBatchRequest:
"""Dependency function to create CancelBatchRequest from path parameter."""
return CancelBatchRequest(batch_id=batch_id)
@router.post(
"/batches/{batch_id}/cancel",
@@ -83,10 +100,10 @@
},
)
async def cancel_batch(
batch_id: str,
request: Annotated[CancelBatchRequest, Depends(get_cancel_batch_request)],
svc: Annotated[Batches, Depends(get_batch_service)],
) -> BatchObject:
return await svc.cancel_batch(batch_id)
return await svc.cancel_batch(request)
def get_list_batches_request(
after: Annotated[