Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-21 09:23:13 +00:00)
Add complete batches API implementation with protocol, providers, and tests.

Core Infrastructure:
- Add batches API protocol using OpenAI Batch types directly
- Add Api.batches enum value and protocol mapping in resolver
- Add OpenAI "batch" file purpose support
- Include proper error handling (ConflictError, ResourceNotFoundError)

Reference Provider:
- Add ReferenceBatchesImpl with full CRUD operations (create, retrieve, cancel, list)
- Implement background batch processing with configurable concurrency
- Add SQLite KVStore backend for persistence
- Support /v1/chat/completions endpoint with request validation

Comprehensive Test Suite:
- Add unit tests for provider implementation with validation
- Add integration tests for end-to-end batch processing workflows
- Add error handling tests for validation, malformed inputs, and edge cases

Configuration:
- Add max_concurrent_batches and max_concurrent_requests_per_batch options
- Add provider documentation with sample configurations

Test with -
```
$ uv run llama stack build --image-type venv --providers inference=YOU_PICK,files=inline::localfs,batches=inline::reference --run &
$ LLAMA_STACK_CONFIG=http://localhost:8321 uv run pytest tests/unit/providers/batches tests/integration/batches --text-model YOU_PICK
```

addresses #3066

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
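For orientation, below is a minimal sketch of driving the new batches API end to end once a stack is running. The JSONL request shape (`custom_id`, `method`, `url`, `body`) and the `batches.create` parameters mirror what the reference provider validates in the file below; the base URL path, the placeholder model name, and the use of the stock `openai` client against the server's OpenAI-compatible routes are assumptions for illustration, not part of this change.

```python
# Sketch only: assumes a running stack at http://localhost:8321 exposing
# OpenAI-compatible routes under /v1/openai/v1 (adjust base_url and the model
# name to your deployment).
import json
import time

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

# One request per line; the provider requires custom_id, method, url, and body,
# and the url must match the batch endpoint (/v1/chat/completions).
requests = [
    {
        "custom_id": "req-1",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {"model": "YOUR_MODEL", "messages": [{"role": "user", "content": "Hello"}]},
    },
]
jsonl = "\n".join(json.dumps(r) for r in requests).encode("utf-8")

# Upload with the new "batch" file purpose, then create and poll the batch.
input_file = client.files.create(file=("batch_input.jsonl", jsonl), purpose="batch")
batch = client.batches.create(
    input_file_id=input_file.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
)

while batch.status not in ("completed", "failed", "cancelled", "expired"):
    time.sleep(5)
    batch = client.batches.retrieve(batch.id)

# Completed batches expose output_file_id (successes) and error_file_id (failures).
if batch.output_file_id:
    print(client.files.content(batch.output_file_id).text)
```

Each output line carries the request's `custom_id` plus either a `response` (with the chat completion body) or an `error` object, matching `_process_single_request` in the provider below.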
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import itertools
import json
import time
import uuid
from io import BytesIO
from typing import Any, Literal

from openai.types.batch import BatchError, Errors
from pydantic import BaseModel

from llama_stack.apis.batches import Batches, BatchObject, ListBatchesResponse
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
from llama_stack.apis.files import Files, OpenAIFilePurpose
from llama_stack.apis.inference import (
    Inference,
    OpenAIAssistantMessageParam,
    OpenAIDeveloperMessageParam,
    OpenAIMessageParam,
    OpenAISystemMessageParam,
    OpenAIToolMessageParam,
    OpenAIUserMessageParam,
)
from llama_stack.apis.models import Models
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore

from .config import ReferenceBatchesImplConfig

BATCH_PREFIX = "batch:"

logger = get_logger(__name__)


class AsyncBytesIO:
    """
    Async-compatible BytesIO wrapper to allow async file-like operations.

    We use this when uploading files to the Files API, as it expects an
    async file-like object.
    """

    def __init__(self, data: bytes):
        self._buffer = BytesIO(data)

    async def read(self, n=-1):
        return self._buffer.read(n)

    async def seek(self, pos, whence=0):
        return self._buffer.seek(pos, whence)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._buffer.close()

    def __getattr__(self, name):
        return getattr(self._buffer, name)


class BatchRequest(BaseModel):
    line_num: int
    custom_id: str
    method: str
    url: str
    body: dict[str, Any]


def convert_to_openai_message_param(msg: dict[str, Any]) -> OpenAIMessageParam:
    """Convert a message dictionary to OpenAIMessageParam based on role."""
    role = msg.get("role")

    if role == "user":
        return OpenAIUserMessageParam(**msg)
    elif role == "system":
        return OpenAISystemMessageParam(**msg)
    elif role == "assistant":
        return OpenAIAssistantMessageParam(**msg)
    elif role == "tool":
        return OpenAIToolMessageParam(**msg)
    elif role == "developer":
        return OpenAIDeveloperMessageParam(**msg)
    else:
        raise ValueError(f"Unknown message role: {role}")


class ReferenceBatchesImpl(Batches):
    """Reference implementation of the Batches API.

    This implementation processes batch files by making individual requests
    to the inference API and generates output files with results.
    """

    def __init__(
        self,
        config: ReferenceBatchesImplConfig,
        inference_api: Inference,
        files_api: Files,
        models_api: Models,
        kvstore: KVStore,
    ) -> None:
        self.config = config
        self.kvstore = kvstore
        self.inference_api = inference_api
        self.files_api = files_api
        self.models_api = models_api
        self._processing_tasks: dict[str, asyncio.Task] = {}
        self._batch_semaphore = asyncio.Semaphore(config.max_concurrent_batches)
        self._update_batch_lock = asyncio.Lock()

        # this is to allow tests to disable background processing
        self.process_batches = True

    async def initialize(self) -> None:
        # TODO: start background processing of existing tasks
        pass

    async def shutdown(self) -> None:
        """Shutdown the batches provider."""
        if self._processing_tasks:
            # don't cancel tasks - just let them stop naturally on shutdown
            # cancelling would mark batches as "cancelled" in the database
            logger.info(f"Shutdown initiated with {len(self._processing_tasks)} active batch processing tasks")

    # TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions
    async def create_batch(
        self,
        input_file_id: str,
        endpoint: str,
        completion_window: Literal["24h"],
        metadata: dict[str, str] | None = None,
    ) -> BatchObject:
        """
        Create a new batch for processing multiple API requests.

        Error handling by levels -
         0. Input param handling, results in 40x errors before processing, e.g.
            - Wrong completion_window
            - Invalid metadata types
            - Unknown endpoint
            -> no batch created
         1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
            - input_file_id missing
            - invalid json in file
            - missing custom_id, method, url, body
            - invalid model
            - streaming
            -> batch created, validation sends to failed status
         2. Processing errors, result in error_file_id entries, e.g.
            - Any error returned from inference endpoint
            -> batch created, goes to completed status
        """
        # TODO: set expiration time for garbage collection

        if endpoint not in ["/v1/chat/completions"]:
            raise ValueError(
                f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint",
            )

        if completion_window != "24h":
            raise ValueError(
                f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window",
            )

        batch_id = f"batch_{uuid.uuid4().hex[:16]}"
        current_time = int(time.time())

        batch = BatchObject(
            id=batch_id,
            object="batch",
            endpoint=endpoint,
            input_file_id=input_file_id,
            completion_window=completion_window,
            status="validating",
            created_at=current_time,
            metadata=metadata,
        )

        await self.kvstore.set(f"batch:{batch_id}", batch.to_json())

        if self.process_batches:
            task = asyncio.create_task(self._process_batch(batch_id))
            self._processing_tasks[batch_id] = task

        return batch

    async def cancel_batch(self, batch_id: str) -> BatchObject:
        """Cancel a batch that is in progress."""
        batch = await self.retrieve_batch(batch_id)

        if batch.status in ["cancelled", "cancelling"]:
            return batch

        if batch.status in ["completed", "failed", "expired"]:
            raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'")

        await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time()))

        if batch_id in self._processing_tasks:
            self._processing_tasks[batch_id].cancel()
            # note: task removal and status="cancelled" handled in finally block of _process_batch

        return await self.retrieve_batch(batch_id)

    async def list_batches(
        self,
        after: str | None = None,
        limit: int = 20,
    ) -> ListBatchesResponse:
        """
        List all batches, eventually only for the current user.

        With no notion of user, we return all batches.
        """
        batch_values = await self.kvstore.values_in_range("batch:", "batch:\xff")

        batches = []
        for batch_data in batch_values:
            if batch_data:
                batches.append(BatchObject.model_validate_json(batch_data))

        batches.sort(key=lambda b: b.created_at, reverse=True)

        start_idx = 0
        if after:
            for i, batch in enumerate(batches):
                if batch.id == after:
                    start_idx = i + 1
                    break

        page_batches = batches[start_idx : start_idx + limit]
        has_more = (start_idx + limit) < len(batches)

        first_id = page_batches[0].id if page_batches else None
        last_id = page_batches[-1].id if page_batches else None

        return ListBatchesResponse(
            data=page_batches,
            first_id=first_id,
            last_id=last_id,
            has_more=has_more,
        )

    async def retrieve_batch(self, batch_id: str) -> BatchObject:
        """Retrieve information about a specific batch."""
        batch_data = await self.kvstore.get(f"batch:{batch_id}")
        if not batch_data:
            raise ResourceNotFoundError(batch_id, "Batch", "batches.list()")

        return BatchObject.model_validate_json(batch_data)

    async def _update_batch(self, batch_id: str, **updates) -> None:
        """Update batch fields in kvstore."""
        async with self._update_batch_lock:
            try:
                batch = await self.retrieve_batch(batch_id)

                # batch processing is async. once cancelling, only allow "cancelled" status updates
                if batch.status == "cancelling" and updates.get("status") != "cancelled":
                    logger.info(
                        f"Skipping status update for cancelled batch {batch_id}: attempted {updates.get('status')}"
                    )
                    return

                if "errors" in updates:
                    updates["errors"] = updates["errors"].model_dump()

                batch_dict = batch.model_dump()
                batch_dict.update(updates)

                await self.kvstore.set(f"batch:{batch_id}", json.dumps(batch_dict))
            except Exception as e:
                logger.error(f"Failed to update batch {batch_id}: {e}")

    async def _validate_input(self, batch: BatchObject) -> tuple[list[BatchError], list[BatchRequest]]:
        """
        Read & validate input, return errors and valid input.

        Validation of
         - input_file_id existence
         - valid json
         - custom_id, method, url, body presence and valid
         - no streaming
        """
        requests: list[BatchRequest] = []
        errors: list[BatchError] = []
        try:
            await self.files_api.openai_retrieve_file(batch.input_file_id)
        except Exception:
            errors.append(
                BatchError(
                    code="invalid_request",
                    line=None,
                    message=f"Cannot find file {batch.input_file_id}.",
                    param="input_file_id",
                )
            )
            return errors, requests

        # TODO(SECURITY): do something about large files
        file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id)
        file_content = file_content_response.body.decode("utf-8")
        for line_num, line in enumerate(file_content.strip().split("\n"), 1):
            if line.strip():  # skip empty lines
                try:
                    request = json.loads(line)

                    if not isinstance(request, dict):
                        errors.append(
                            BatchError(
                                code="invalid_request",
                                line=line_num,
                                message="Each line must be a JSON dictionary object",
                            )
                        )
                        continue

                    valid = True

                    for param, expected_type, type_string in [
                        ("custom_id", str, "string"),
                        ("method", str, "string"),
                        ("url", str, "string"),
                        ("body", dict, "JSON dictionary object"),
                    ]:
                        if param not in request:
                            errors.append(
                                BatchError(
                                    code="missing_required_parameter",
                                    line=line_num,
                                    message=f"Missing required parameter: {param}",
                                    param=param,
                                )
                            )
                            valid = False
                        elif not isinstance(request[param], expected_type):
                            param_name = "URL" if param == "url" else param.capitalize()
                            errors.append(
                                BatchError(
                                    code="invalid_request",
                                    line=line_num,
                                    message=f"{param_name} must be a {type_string}",
                                    param=param,
                                )
                            )
                            valid = False

                    if (url := request.get("url")) and isinstance(url, str) and url != batch.endpoint:
                        errors.append(
                            BatchError(
                                code="invalid_url",
                                line=line_num,
                                message="URL provided for this request does not match the batch endpoint",
                                param="url",
                            )
                        )
                        valid = False

                    if (body := request.get("body")) and isinstance(body, dict):
                        if body.get("stream", False):
                            errors.append(
                                BatchError(
                                    code="streaming_unsupported",
                                    line=line_num,
                                    message="Streaming is not supported in batch processing",
                                    param="body.stream",
                                )
                            )
                            valid = False

                        for param, expected_type, type_string in [
                            ("model", str, "a string"),
                            # messages is specific to /v1/chat/completions
                            # we could skip validating messages here and let inference fail. however,
                            # that would be a very expensive way to find out messages is wrong.
                            ("messages", list, "an array"),  # TODO: allow messages to be a string?
                        ]:
                            if param not in body:
                                errors.append(
                                    BatchError(
                                        code="invalid_request",
                                        line=line_num,
                                        message=f"{param.capitalize()} parameter is required",
                                        param=f"body.{param}",
                                    )
                                )
                                valid = False
                            elif not isinstance(body[param], expected_type):
                                errors.append(
                                    BatchError(
                                        code="invalid_request",
                                        line=line_num,
                                        message=f"{param.capitalize()} must be {type_string}",
                                        param=f"body.{param}",
                                    )
                                )
                                valid = False

                        if "model" in body and isinstance(body["model"], str):
                            try:
                                await self.models_api.get_model(body["model"])
                            except Exception:
                                errors.append(
                                    BatchError(
                                        code="model_not_found",
                                        line=line_num,
                                        message=f"Model '{body['model']}' does not exist or is not supported",
                                        param="body.model",
                                    )
                                )
                                valid = False

                    if valid:
                        assert isinstance(url, str), "URL must be a string"  # for mypy
                        assert isinstance(body, dict), "Body must be a dictionary"  # for mypy
                        requests.append(
                            BatchRequest(
                                line_num=line_num,
                                url=url,
                                method=request["method"],
                                custom_id=request["custom_id"],
                                body=body,
                            ),
                        )
                except json.JSONDecodeError:
                    errors.append(
                        BatchError(
                            code="invalid_json_line",
                            line=line_num,
                            message="This line is not parseable as valid JSON.",
                        )
                    )

        return errors, requests

    async def _process_batch(self, batch_id: str) -> None:
        """Background task to process a batch of requests."""
        try:
            logger.info(f"Starting batch processing for {batch_id}")
            async with self._batch_semaphore:  # semaphore to limit concurrency
                logger.info(f"Acquired semaphore for batch {batch_id}")
                await self._process_batch_impl(batch_id)
        except asyncio.CancelledError:
            logger.info(f"Batch processing cancelled for {batch_id}")
            await self._update_batch(batch_id, status="cancelled", cancelled_at=int(time.time()))
        except Exception as e:
            logger.error(f"Batch processing failed for {batch_id}: {e}")
            await self._update_batch(
                batch_id,
                status="failed",
                failed_at=int(time.time()),
                errors=Errors(data=[BatchError(code="internal_error", message=str(e))]),
            )
        finally:
            self._processing_tasks.pop(batch_id, None)

    async def _process_batch_impl(self, batch_id: str) -> None:
        """Implementation of batch processing logic."""
        errors: list[BatchError] = []
        batch = await self.retrieve_batch(batch_id)

        errors, requests = await self._validate_input(batch)
        if errors:
            await self._update_batch(batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors))
            logger.info(f"Batch validation failed for {batch_id} with {len(errors)} errors")
            return

        logger.info(f"Processing {len(requests)} requests for batch {batch_id}")

        total_requests = len(requests)
        await self._update_batch(
            batch_id,
            status="in_progress",
            request_counts={"total": total_requests, "completed": 0, "failed": 0},
        )

        error_results = []
        success_results = []
        completed_count = 0
        failed_count = 0

        for chunk in itertools.batched(requests, self.config.max_concurrent_requests_per_batch):
            # we use a TaskGroup to ensure all process-single-request tasks are canceled when process-batch is cancelled
            async with asyncio.TaskGroup() as tg:
                chunk_tasks = [tg.create_task(self._process_single_request(batch_id, request)) for request in chunk]

                chunk_results = await asyncio.gather(*chunk_tasks, return_exceptions=True)

            for result in chunk_results:
                if isinstance(result, dict) and result.get("error") is not None:  # error response from inference
                    failed_count += 1
                    error_results.append(result)
                elif isinstance(result, dict) and result.get("response") is not None:  # successful inference
                    completed_count += 1
                    success_results.append(result)
                else:  # unexpected result
                    failed_count += 1
                    errors.append(BatchError(code="internal_error", message=f"Unexpected result: {result}"))

            await self._update_batch(
                batch_id,
                request_counts={"total": total_requests, "completed": completed_count, "failed": failed_count},
            )

            if errors:
                await self._update_batch(
                    batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors)
                )
                return

        try:
            output_file_id = await self._create_output_file(batch_id, success_results, "success")
            await self._update_batch(batch_id, output_file_id=output_file_id)

            error_file_id = await self._create_output_file(batch_id, error_results, "error")
            await self._update_batch(batch_id, error_file_id=error_file_id)

            await self._update_batch(batch_id, status="completed", completed_at=int(time.time()))

            logger.info(
                f"Batch processing completed for {batch_id}: {completed_count} completed, {failed_count} failed"
            )
        except Exception as e:
            # note: errors is empty at this point, so we don't lose anything by ignoring it
            await self._update_batch(
                batch_id,
                status="failed",
                failed_at=int(time.time()),
                errors=Errors(data=[BatchError(code="output_failed", message=str(e))]),
            )

    async def _process_single_request(self, batch_id: str, request: BatchRequest) -> dict:
        """Process a single request from the batch."""
        request_id = f"batch_req_{batch_id}_{request.line_num}"

        try:
            # TODO(SECURITY): review body for security issues
            request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
            chat_response = await self.inference_api.openai_chat_completion(**request.body)

            # this is for mypy, we don't allow streaming so we'll get the right type
            assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
            return {
                "id": request_id,
                "custom_id": request.custom_id,
                "response": {
                    "status_code": 200,
                    "request_id": request_id,  # TODO: should this be different?
                    "body": chat_response.model_dump_json(),
                },
            }
        except Exception as e:
            logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
            return {
                "id": request_id,
                "custom_id": request.custom_id,
                "error": {"type": "request_failed", "message": str(e)},
            }

    async def _create_output_file(self, batch_id: str, results: list[dict], file_type: str) -> str:
        """
        Create an output file with batch results.

        The given results are serialized as JSONL and uploaded to the Files API;
        file_type ("success" or "error") is used to name the uploaded file.
        """
        output_lines = [json.dumps(result) for result in results]

        with AsyncBytesIO("\n".join(output_lines).encode("utf-8")) as file_buffer:
            file_buffer.filename = f"{batch_id}_{file_type}.jsonl"
            uploaded_file = await self.files_api.openai_upload_file(file=file_buffer, purpose=OpenAIFilePurpose.BATCH)
            return uploaded_file.id