Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 09:53:45 +00:00)
chore: refactor Batches protocol to use request models
This commit refactors the Batches protocol to use Pydantic request models for both the create_batch and list_batches methods, improving consistency, readability, and maintainability.

- create_batch now accepts a single CreateBatchRequest parameter instead of individual arguments. This aligns the protocol with FastAPI's request-model pattern, allowing the router to pass the request object directly without unpacking parameters. Provider implementations now access fields via request.input_file_id, request.endpoint, etc.

- list_batches now accepts a single ListBatchesRequest parameter, replacing individual query parameters. The model includes after and limit fields with proper OpenAPI descriptions. FastAPI automatically parses query parameters into the model for GET requests, keeping the router code clean. Provider implementations access fields via request.after and request.limit.

Signed-off-by: Sébastien Han <seb@redhat.com>
parent 00e7ea6c3b
commit 30cab02083

9 changed files with 145 additions and 50 deletions
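Before the diff, a minimal sketch of what the two request models could look like. The field names, types, and defaults are taken from the removed create_batch and list_batches signatures and the commit message above; the exact Field descriptions are assumptions, and the real definitions live in llama_stack_api.batches.models.

from typing import Literal

from pydantic import BaseModel, Field


class CreateBatchRequest(BaseModel):
    # Body model for create_batch; fields mirror the old positional arguments.
    input_file_id: str = Field(..., description="ID of the uploaded file containing batch input")
    endpoint: str = Field(..., description="Endpoint the batch will run against")
    completion_window: Literal["24h"] = Field(..., description="Time window for batch completion")
    metadata: dict[str, str] | None = Field(default=None, description="Optional key/value metadata")
    idempotency_key: str | None = Field(default=None, description="Optional key for deduplicating retries")


class ListBatchesRequest(BaseModel):
    # Query-parameter model for list_batches, with OpenAPI descriptions.
    after: str | None = Field(default=None, description="Return batches after this batch ID")
    limit: int = Field(default=20, description="Maximum number of batches to return")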
Only one of the nine changed files is shown here: the Batches protocol definition.

@@ -11,7 +11,7 @@ Pydantic models are defined in llama_stack_api.batches.models.
 The FastAPI router is defined in llama_stack_api.batches.routes.
 """
 
-from typing import Literal, Protocol, runtime_checkable
+from typing import Protocol, runtime_checkable
 
 try:
     from openai.types import Batch as BatchObject
@@ -19,7 +19,7 @@ except ImportError as e:
     raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e
 
 # Import models for re-export
-from llama_stack_api.batches.models import ListBatchesResponse
+from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest, ListBatchesResponse
 
 
 @runtime_checkable
@@ -39,11 +39,7 @@ class Batches(Protocol):
 
     async def create_batch(
         self,
-        input_file_id: str,
-        endpoint: str,
-        completion_window: Literal["24h"],
-        metadata: dict[str, str] | None = None,
-        idempotency_key: str | None = None,
+        request: CreateBatchRequest,
     ) -> BatchObject: ...
 
     async def retrieve_batch(self, batch_id: str) -> BatchObject: ...
@@ -52,9 +48,8 @@ class Batches(Protocol):
 
     async def list_batches(
         self,
-        after: str | None = None,
-        limit: int = 20,
+        request: ListBatchesRequest,
     ) -> ListBatchesResponse: ...
 
 
-__all__ = ["Batches", "BatchObject", "ListBatchesResponse"]
+__all__ = ["Batches", "BatchObject", "CreateBatchRequest", "ListBatchesRequest", "ListBatchesResponse"]
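To show how the router side stays free of parameter unpacking, here is a hedged sketch of FastAPI routes bound to these models. The paths, function names, and the get_batches_impl dependency are illustrative assumptions, not the actual llama_stack_api.batches.routes code; the Depends() binding of ListBatchesRequest is the standard FastAPI pattern that reads a Pydantic model's fields from the query string on GET requests.

from typing import Annotated

from fastapi import APIRouter, Depends

from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest

router = APIRouter(prefix="/batches", tags=["batches"])


def get_batches_impl():
    # Hypothetical dependency that resolves the active Batches provider;
    # the real wiring may differ.
    raise NotImplementedError


@router.post("")
async def create_batch(request: CreateBatchRequest, impl=Depends(get_batches_impl)):
    # FastAPI validates the JSON body straight into CreateBatchRequest and the
    # router passes the object through unchanged.
    return await impl.create_batch(request)


@router.get("")
async def list_batches(
    request: Annotated[ListBatchesRequest, Depends()],
    impl=Depends(get_batches_impl),
):
    # Depends() on a Pydantic model makes FastAPI populate its fields
    # (after, limit) from query parameters.
    return await impl.list_batches(request)

A provider implementation then reads request.input_file_id, request.endpoint, request.after, and request.limit directly, as the commit message describes.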