mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-04 02:03:44 +00:00
chore: same as previous commit but for more fields
Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
30cab02083
commit
20030429e7
9 changed files with 210 additions and 25 deletions
|
|
@ -19,7 +19,13 @@ except ImportError as e:
|
|||
raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e
|
||||
|
||||
# Import models for re-export
|
||||
from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest, ListBatchesResponse
|
||||
from llama_stack_api.batches.models import (
|
||||
CancelBatchRequest,
|
||||
CreateBatchRequest,
|
||||
ListBatchesRequest,
|
||||
ListBatchesResponse,
|
||||
RetrieveBatchRequest,
|
||||
)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
|
|
@ -42,9 +48,15 @@ class Batches(Protocol):
|
|||
request: CreateBatchRequest,
|
||||
) -> BatchObject: ...
|
||||
|
||||
async def retrieve_batch(self, batch_id: str) -> BatchObject: ...
|
||||
async def retrieve_batch(
|
||||
self,
|
||||
request: RetrieveBatchRequest,
|
||||
) -> BatchObject: ...
|
||||
|
||||
async def cancel_batch(self, batch_id: str) -> BatchObject: ...
|
||||
async def cancel_batch(
|
||||
self,
|
||||
request: CancelBatchRequest,
|
||||
) -> BatchObject: ...
|
||||
|
||||
async def list_batches(
|
||||
self,
|
||||
|
|
@ -52,4 +64,12 @@ class Batches(Protocol):
|
|||
) -> ListBatchesResponse: ...
|
||||
|
||||
|
||||
__all__ = ["Batches", "BatchObject", "CreateBatchRequest", "ListBatchesRequest", "ListBatchesResponse"]
|
||||
__all__ = [
|
||||
"Batches",
|
||||
"BatchObject",
|
||||
"CreateBatchRequest",
|
||||
"ListBatchesRequest",
|
||||
"RetrieveBatchRequest",
|
||||
"CancelBatchRequest",
|
||||
"ListBatchesResponse",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -47,6 +47,20 @@ class ListBatchesRequest(BaseModel):
|
|||
limit: int = Field(default=20, description="Maximum number of batches to return. Defaults to 20.")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RetrieveBatchRequest(BaseModel):
|
||||
"""Request model for retrieving a batch."""
|
||||
|
||||
batch_id: str = Field(..., description="The ID of the batch to retrieve.")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CancelBatchRequest(BaseModel):
|
||||
"""Request model for canceling a batch."""
|
||||
|
||||
batch_id: str = Field(..., description="The ID of the batch to cancel.")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListBatchesResponse(BaseModel):
|
||||
"""Response containing a list of batch objects."""
|
||||
|
|
@ -58,4 +72,11 @@ class ListBatchesResponse(BaseModel):
|
|||
has_more: bool = Field(default=False, description="Whether there are more batches available")
|
||||
|
||||
|
||||
__all__ = ["CreateBatchRequest", "ListBatchesRequest", "ListBatchesResponse", "BatchObject"]
|
||||
__all__ = [
|
||||
"CreateBatchRequest",
|
||||
"ListBatchesRequest",
|
||||
"RetrieveBatchRequest",
|
||||
"CancelBatchRequest",
|
||||
"ListBatchesResponse",
|
||||
"BatchObject",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -14,10 +14,15 @@ all API-related code together.
|
|||
from collections.abc import Callable
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Query
|
||||
from fastapi import APIRouter, Body, Depends, Path, Query
|
||||
|
||||
from llama_stack_api.batches import Batches, BatchObject, ListBatchesResponse
|
||||
from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest
|
||||
from llama_stack_api.batches.models import (
|
||||
CancelBatchRequest,
|
||||
CreateBatchRequest,
|
||||
ListBatchesRequest,
|
||||
RetrieveBatchRequest,
|
||||
)
|
||||
from llama_stack_api.datatypes import Api
|
||||
from llama_stack_api.router_utils import standard_responses
|
||||
from llama_stack_api.version import LLAMA_STACK_API_V1
|
||||
|
|
@ -58,6 +63,12 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
|
|||
) -> BatchObject:
|
||||
return await svc.create_batch(request)
|
||||
|
||||
def get_retrieve_batch_request(
|
||||
batch_id: Annotated[str, Path(description="The ID of the batch to retrieve.")],
|
||||
) -> RetrieveBatchRequest:
|
||||
"""Dependency function to create RetrieveBatchRequest from path parameter."""
|
||||
return RetrieveBatchRequest(batch_id=batch_id)
|
||||
|
||||
@router.get(
|
||||
"/batches/{batch_id}",
|
||||
response_model=BatchObject,
|
||||
|
|
@ -68,10 +79,16 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
|
|||
},
|
||||
)
|
||||
async def retrieve_batch(
|
||||
batch_id: str,
|
||||
request: Annotated[RetrieveBatchRequest, Depends(get_retrieve_batch_request)],
|
||||
svc: Annotated[Batches, Depends(get_batch_service)],
|
||||
) -> BatchObject:
|
||||
return await svc.retrieve_batch(batch_id)
|
||||
return await svc.retrieve_batch(request)
|
||||
|
||||
def get_cancel_batch_request(
|
||||
batch_id: Annotated[str, Path(description="The ID of the batch to cancel.")],
|
||||
) -> CancelBatchRequest:
|
||||
"""Dependency function to create CancelBatchRequest from path parameter."""
|
||||
return CancelBatchRequest(batch_id=batch_id)
|
||||
|
||||
@router.post(
|
||||
"/batches/{batch_id}/cancel",
|
||||
|
|
@ -83,10 +100,10 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
|
|||
},
|
||||
)
|
||||
async def cancel_batch(
|
||||
batch_id: str,
|
||||
request: Annotated[CancelBatchRequest, Depends(get_cancel_batch_request)],
|
||||
svc: Annotated[Batches, Depends(get_batch_service)],
|
||||
) -> BatchObject:
|
||||
return await svc.cancel_batch(batch_id)
|
||||
return await svc.cancel_batch(request)
|
||||
|
||||
def get_list_batches_request(
|
||||
after: Annotated[
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue