mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
chore: same as previous commit but for more fields
Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
30cab02083
commit
20030429e7
9 changed files with 210 additions and 25 deletions
|
|
@ -125,7 +125,9 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The ID of the batch to retrieve.
|
||||
title: Batch Id
|
||||
description: The ID of the batch to retrieve.
|
||||
/v1/batches/{batch_id}/cancel:
|
||||
post:
|
||||
responses:
|
||||
|
|
@ -158,7 +160,9 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The ID of the batch to cancel.
|
||||
title: Batch Id
|
||||
description: The ID of the batch to cancel.
|
||||
/v1/chat/completions:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -12710,6 +12714,28 @@ components:
|
|||
type: integer
|
||||
title: ListBatchesRequest
|
||||
type: object
|
||||
RetrieveBatchRequest:
|
||||
description: Request model for retrieving a batch.
|
||||
properties:
|
||||
batch_id:
|
||||
description: The ID of the batch to retrieve.
|
||||
title: Batch Id
|
||||
type: string
|
||||
required:
|
||||
- batch_id
|
||||
title: RetrieveBatchRequest
|
||||
type: object
|
||||
CancelBatchRequest:
|
||||
description: Request model for canceling a batch.
|
||||
properties:
|
||||
batch_id:
|
||||
description: The ID of the batch to cancel.
|
||||
title: Batch Id
|
||||
type: string
|
||||
required:
|
||||
- batch_id
|
||||
title: CancelBatchRequest
|
||||
type: object
|
||||
DialogType:
|
||||
description: Parameter type for dialog data with semantic output labels.
|
||||
properties:
|
||||
|
|
|
|||
22
docs/static/deprecated-llama-stack-spec.yaml
vendored
22
docs/static/deprecated-llama-stack-spec.yaml
vendored
|
|
@ -9547,6 +9547,28 @@ components:
|
|||
type: integer
|
||||
title: ListBatchesRequest
|
||||
type: object
|
||||
RetrieveBatchRequest:
|
||||
description: Request model for retrieving a batch.
|
||||
properties:
|
||||
batch_id:
|
||||
description: The ID of the batch to retrieve.
|
||||
title: Batch Id
|
||||
type: string
|
||||
required:
|
||||
- batch_id
|
||||
title: RetrieveBatchRequest
|
||||
type: object
|
||||
CancelBatchRequest:
|
||||
description: Request model for canceling a batch.
|
||||
properties:
|
||||
batch_id:
|
||||
description: The ID of the batch to cancel.
|
||||
title: Batch Id
|
||||
type: string
|
||||
required:
|
||||
- batch_id
|
||||
title: CancelBatchRequest
|
||||
type: object
|
||||
DialogType:
|
||||
description: Parameter type for dialog data with semantic output labels.
|
||||
properties:
|
||||
|
|
|
|||
22
docs/static/experimental-llama-stack-spec.yaml
vendored
22
docs/static/experimental-llama-stack-spec.yaml
vendored
|
|
@ -8560,6 +8560,28 @@ components:
|
|||
type: integer
|
||||
title: ListBatchesRequest
|
||||
type: object
|
||||
RetrieveBatchRequest:
|
||||
description: Request model for retrieving a batch.
|
||||
properties:
|
||||
batch_id:
|
||||
description: The ID of the batch to retrieve.
|
||||
title: Batch Id
|
||||
type: string
|
||||
required:
|
||||
- batch_id
|
||||
title: RetrieveBatchRequest
|
||||
type: object
|
||||
CancelBatchRequest:
|
||||
description: Request model for canceling a batch.
|
||||
properties:
|
||||
batch_id:
|
||||
description: The ID of the batch to cancel.
|
||||
title: Batch Id
|
||||
type: string
|
||||
required:
|
||||
- batch_id
|
||||
title: CancelBatchRequest
|
||||
type: object
|
||||
DialogType:
|
||||
description: Parameter type for dialog data with semantic output labels.
|
||||
properties:
|
||||
|
|
|
|||
26
docs/static/llama-stack-spec.yaml
vendored
26
docs/static/llama-stack-spec.yaml
vendored
|
|
@ -123,7 +123,9 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The ID of the batch to retrieve.
|
||||
title: Batch Id
|
||||
description: The ID of the batch to retrieve.
|
||||
/v1/batches/{batch_id}/cancel:
|
||||
post:
|
||||
responses:
|
||||
|
|
@ -156,7 +158,9 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The ID of the batch to cancel.
|
||||
title: Batch Id
|
||||
description: The ID of the batch to cancel.
|
||||
/v1/chat/completions:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -11439,6 +11443,28 @@ components:
|
|||
type: integer
|
||||
title: ListBatchesRequest
|
||||
type: object
|
||||
RetrieveBatchRequest:
|
||||
description: Request model for retrieving a batch.
|
||||
properties:
|
||||
batch_id:
|
||||
description: The ID of the batch to retrieve.
|
||||
title: Batch Id
|
||||
type: string
|
||||
required:
|
||||
- batch_id
|
||||
title: RetrieveBatchRequest
|
||||
type: object
|
||||
CancelBatchRequest:
|
||||
description: Request model for canceling a batch.
|
||||
properties:
|
||||
batch_id:
|
||||
description: The ID of the batch to cancel.
|
||||
title: Batch Id
|
||||
type: string
|
||||
required:
|
||||
- batch_id
|
||||
title: CancelBatchRequest
|
||||
type: object
|
||||
DialogType:
|
||||
description: Parameter type for dialog data with semantic output labels.
|
||||
properties:
|
||||
|
|
|
|||
26
docs/static/stainless-llama-stack-spec.yaml
vendored
26
docs/static/stainless-llama-stack-spec.yaml
vendored
|
|
@ -125,7 +125,9 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The ID of the batch to retrieve.
|
||||
title: Batch Id
|
||||
description: The ID of the batch to retrieve.
|
||||
/v1/batches/{batch_id}/cancel:
|
||||
post:
|
||||
responses:
|
||||
|
|
@ -158,7 +160,9 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The ID of the batch to cancel.
|
||||
title: Batch Id
|
||||
description: The ID of the batch to cancel.
|
||||
/v1/chat/completions:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -12710,6 +12714,28 @@ components:
|
|||
type: integer
|
||||
title: ListBatchesRequest
|
||||
type: object
|
||||
RetrieveBatchRequest:
|
||||
description: Request model for retrieving a batch.
|
||||
properties:
|
||||
batch_id:
|
||||
description: The ID of the batch to retrieve.
|
||||
title: Batch Id
|
||||
type: string
|
||||
required:
|
||||
- batch_id
|
||||
title: RetrieveBatchRequest
|
||||
type: object
|
||||
CancelBatchRequest:
|
||||
description: Request model for canceling a batch.
|
||||
properties:
|
||||
batch_id:
|
||||
description: The ID of the batch to cancel.
|
||||
title: Batch Id
|
||||
type: string
|
||||
required:
|
||||
- batch_id
|
||||
title: CancelBatchRequest
|
||||
type: object
|
||||
DialogType:
|
||||
description: Parameter type for dialog data with semantic output labels.
|
||||
properties:
|
||||
|
|
|
|||
|
|
@ -38,7 +38,12 @@ from llama_stack_api import (
|
|||
OpenAIUserMessageParam,
|
||||
ResourceNotFoundError,
|
||||
)
|
||||
from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest
|
||||
from llama_stack_api.batches.models import (
|
||||
CancelBatchRequest,
|
||||
CreateBatchRequest,
|
||||
ListBatchesRequest,
|
||||
RetrieveBatchRequest,
|
||||
)
|
||||
|
||||
from .config import ReferenceBatchesImplConfig
|
||||
|
||||
|
|
@ -203,7 +208,7 @@ class ReferenceBatchesImpl(Batches):
|
|||
batch_id = f"batch_{hash_digest}"
|
||||
|
||||
try:
|
||||
existing_batch = await self.retrieve_batch(batch_id)
|
||||
existing_batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=batch_id))
|
||||
|
||||
if (
|
||||
existing_batch.input_file_id != request.input_file_id
|
||||
|
|
@ -244,23 +249,23 @@ class ReferenceBatchesImpl(Batches):
|
|||
|
||||
return batch
|
||||
|
||||
async def cancel_batch(self, batch_id: str) -> BatchObject:
|
||||
async def cancel_batch(self, request: CancelBatchRequest) -> BatchObject:
|
||||
"""Cancel a batch that is in progress."""
|
||||
batch = await self.retrieve_batch(batch_id)
|
||||
batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=request.batch_id))
|
||||
|
||||
if batch.status in ["cancelled", "cancelling"]:
|
||||
return batch
|
||||
|
||||
if batch.status in ["completed", "failed", "expired"]:
|
||||
raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'")
|
||||
raise ConflictError(f"Cannot cancel batch '{request.batch_id}' with status '{batch.status}'")
|
||||
|
||||
await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time()))
|
||||
await self._update_batch(request.batch_id, status="cancelling", cancelling_at=int(time.time()))
|
||||
|
||||
if batch_id in self._processing_tasks:
|
||||
self._processing_tasks[batch_id].cancel()
|
||||
if request.batch_id in self._processing_tasks:
|
||||
self._processing_tasks[request.batch_id].cancel()
|
||||
# note: task removal and status="cancelled" handled in finally block of _process_batch
|
||||
|
||||
return await self.retrieve_batch(batch_id)
|
||||
return await self.retrieve_batch(RetrieveBatchRequest(batch_id=request.batch_id))
|
||||
|
||||
async def list_batches(
|
||||
self,
|
||||
|
|
@ -300,11 +305,11 @@ class ReferenceBatchesImpl(Batches):
|
|||
has_more=has_more,
|
||||
)
|
||||
|
||||
async def retrieve_batch(self, batch_id: str) -> BatchObject:
|
||||
async def retrieve_batch(self, request: RetrieveBatchRequest) -> BatchObject:
|
||||
"""Retrieve information about a specific batch."""
|
||||
batch_data = await self.kvstore.get(f"batch:{batch_id}")
|
||||
batch_data = await self.kvstore.get(f"batch:{request.batch_id}")
|
||||
if not batch_data:
|
||||
raise ResourceNotFoundError(batch_id, "Batch", "batches.list()")
|
||||
raise ResourceNotFoundError(request.batch_id, "Batch", "batches.list()")
|
||||
|
||||
return BatchObject.model_validate_json(batch_data)
|
||||
|
||||
|
|
@ -312,7 +317,7 @@ class ReferenceBatchesImpl(Batches):
|
|||
"""Update batch fields in kvstore."""
|
||||
async with self._update_batch_lock:
|
||||
try:
|
||||
batch = await self.retrieve_batch(batch_id)
|
||||
batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=batch_id))
|
||||
|
||||
# batch processing is async. once cancelling, only allow "cancelled" status updates
|
||||
if batch.status == "cancelling" and updates.get("status") != "cancelled":
|
||||
|
|
@ -532,7 +537,7 @@ class ReferenceBatchesImpl(Batches):
|
|||
async def _process_batch_impl(self, batch_id: str) -> None:
|
||||
"""Implementation of batch processing logic."""
|
||||
errors: list[BatchError] = []
|
||||
batch = await self.retrieve_batch(batch_id)
|
||||
batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=batch_id))
|
||||
|
||||
errors, requests = await self._validate_input(batch)
|
||||
if errors:
|
||||
|
|
|
|||
|
|
@ -19,7 +19,13 @@ except ImportError as e:
|
|||
raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e
|
||||
|
||||
# Import models for re-export
|
||||
from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest, ListBatchesResponse
|
||||
from llama_stack_api.batches.models import (
|
||||
CancelBatchRequest,
|
||||
CreateBatchRequest,
|
||||
ListBatchesRequest,
|
||||
ListBatchesResponse,
|
||||
RetrieveBatchRequest,
|
||||
)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
|
|
@ -42,9 +48,15 @@ class Batches(Protocol):
|
|||
request: CreateBatchRequest,
|
||||
) -> BatchObject: ...
|
||||
|
||||
async def retrieve_batch(self, batch_id: str) -> BatchObject: ...
|
||||
async def retrieve_batch(
|
||||
self,
|
||||
request: RetrieveBatchRequest,
|
||||
) -> BatchObject: ...
|
||||
|
||||
async def cancel_batch(self, batch_id: str) -> BatchObject: ...
|
||||
async def cancel_batch(
|
||||
self,
|
||||
request: CancelBatchRequest,
|
||||
) -> BatchObject: ...
|
||||
|
||||
async def list_batches(
|
||||
self,
|
||||
|
|
@ -52,4 +64,12 @@ class Batches(Protocol):
|
|||
) -> ListBatchesResponse: ...
|
||||
|
||||
|
||||
__all__ = ["Batches", "BatchObject", "CreateBatchRequest", "ListBatchesRequest", "ListBatchesResponse"]
|
||||
__all__ = [
|
||||
"Batches",
|
||||
"BatchObject",
|
||||
"CreateBatchRequest",
|
||||
"ListBatchesRequest",
|
||||
"RetrieveBatchRequest",
|
||||
"CancelBatchRequest",
|
||||
"ListBatchesResponse",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -47,6 +47,20 @@ class ListBatchesRequest(BaseModel):
|
|||
limit: int = Field(default=20, description="Maximum number of batches to return. Defaults to 20.")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RetrieveBatchRequest(BaseModel):
|
||||
"""Request model for retrieving a batch."""
|
||||
|
||||
batch_id: str = Field(..., description="The ID of the batch to retrieve.")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CancelBatchRequest(BaseModel):
|
||||
"""Request model for canceling a batch."""
|
||||
|
||||
batch_id: str = Field(..., description="The ID of the batch to cancel.")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListBatchesResponse(BaseModel):
|
||||
"""Response containing a list of batch objects."""
|
||||
|
|
@ -58,4 +72,11 @@ class ListBatchesResponse(BaseModel):
|
|||
has_more: bool = Field(default=False, description="Whether there are more batches available")
|
||||
|
||||
|
||||
__all__ = ["CreateBatchRequest", "ListBatchesRequest", "ListBatchesResponse", "BatchObject"]
|
||||
__all__ = [
|
||||
"CreateBatchRequest",
|
||||
"ListBatchesRequest",
|
||||
"RetrieveBatchRequest",
|
||||
"CancelBatchRequest",
|
||||
"ListBatchesResponse",
|
||||
"BatchObject",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -14,10 +14,15 @@ all API-related code together.
|
|||
from collections.abc import Callable
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Query
|
||||
from fastapi import APIRouter, Body, Depends, Path, Query
|
||||
|
||||
from llama_stack_api.batches import Batches, BatchObject, ListBatchesResponse
|
||||
from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest
|
||||
from llama_stack_api.batches.models import (
|
||||
CancelBatchRequest,
|
||||
CreateBatchRequest,
|
||||
ListBatchesRequest,
|
||||
RetrieveBatchRequest,
|
||||
)
|
||||
from llama_stack_api.datatypes import Api
|
||||
from llama_stack_api.router_utils import standard_responses
|
||||
from llama_stack_api.version import LLAMA_STACK_API_V1
|
||||
|
|
@ -58,6 +63,12 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
|
|||
) -> BatchObject:
|
||||
return await svc.create_batch(request)
|
||||
|
||||
def get_retrieve_batch_request(
|
||||
batch_id: Annotated[str, Path(description="The ID of the batch to retrieve.")],
|
||||
) -> RetrieveBatchRequest:
|
||||
"""Dependency function to create RetrieveBatchRequest from path parameter."""
|
||||
return RetrieveBatchRequest(batch_id=batch_id)
|
||||
|
||||
@router.get(
|
||||
"/batches/{batch_id}",
|
||||
response_model=BatchObject,
|
||||
|
|
@ -68,10 +79,16 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
|
|||
},
|
||||
)
|
||||
async def retrieve_batch(
|
||||
batch_id: str,
|
||||
request: Annotated[RetrieveBatchRequest, Depends(get_retrieve_batch_request)],
|
||||
svc: Annotated[Batches, Depends(get_batch_service)],
|
||||
) -> BatchObject:
|
||||
return await svc.retrieve_batch(batch_id)
|
||||
return await svc.retrieve_batch(request)
|
||||
|
||||
def get_cancel_batch_request(
|
||||
batch_id: Annotated[str, Path(description="The ID of the batch to cancel.")],
|
||||
) -> CancelBatchRequest:
|
||||
"""Dependency function to create CancelBatchRequest from path parameter."""
|
||||
return CancelBatchRequest(batch_id=batch_id)
|
||||
|
||||
@router.post(
|
||||
"/batches/{batch_id}/cancel",
|
||||
|
|
@ -83,10 +100,10 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
|
|||
},
|
||||
)
|
||||
async def cancel_batch(
|
||||
batch_id: str,
|
||||
request: Annotated[CancelBatchRequest, Depends(get_cancel_batch_request)],
|
||||
svc: Annotated[Batches, Depends(get_batch_service)],
|
||||
) -> BatchObject:
|
||||
return await svc.cancel_batch(batch_id)
|
||||
return await svc.cancel_batch(request)
|
||||
|
||||
def get_list_batches_request(
|
||||
after: Annotated[
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue