# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import asyncio import itertools import json import time import uuid from io import BytesIO from typing import Any, Literal from openai.types.batch import BatchError, Errors from pydantic import BaseModel from llama_stack.apis.batches import Batches, BatchObject, ListBatchesResponse from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError from llama_stack.apis.files import Files, OpenAIFilePurpose from llama_stack.apis.inference import ( Inference, OpenAIAssistantMessageParam, OpenAIDeveloperMessageParam, OpenAIMessageParam, OpenAISystemMessageParam, OpenAIToolMessageParam, OpenAIUserMessageParam, ) from llama_stack.apis.models import Models from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore import KVStore from .config import ReferenceBatchesImplConfig BATCH_PREFIX = "batch:" logger = get_logger(__name__) class AsyncBytesIO: """ Async-compatible BytesIO wrapper to allow async file-like operations. We use this when uploading files to the Files API, as it expects an async file-like object. """ def __init__(self, data: bytes): self._buffer = BytesIO(data) async def read(self, n=-1): return self._buffer.read(n) async def seek(self, pos, whence=0): return self._buffer.seek(pos, whence) def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self._buffer.close() def __getattr__(self, name): return getattr(self._buffer, name) class BatchRequest(BaseModel): line_num: int custom_id: str method: str url: str body: dict[str, Any] def convert_to_openai_message_param(msg: dict[str, Any]) -> OpenAIMessageParam: """Convert a message dictionary to OpenAIMessageParam based on role.""" role = msg.get("role") if role == "user": return OpenAIUserMessageParam(**msg) elif role == "system": return OpenAISystemMessageParam(**msg) elif role == "assistant": return OpenAIAssistantMessageParam(**msg) elif role == "tool": return OpenAIToolMessageParam(**msg) elif role == "developer": return OpenAIDeveloperMessageParam(**msg) else: raise ValueError(f"Unknown message role: {role}") class ReferenceBatchesImpl(Batches): """Reference implementation of the Batches API. This implementation processes batch files by making individual requests to the inference API and generates output files with results. """ def __init__( self, config: ReferenceBatchesImplConfig, inference_api: Inference, files_api: Files, models_api: Models, kvstore: KVStore, ) -> None: self.config = config self.kvstore = kvstore self.inference_api = inference_api self.files_api = files_api self.models_api = models_api self._processing_tasks: dict[str, asyncio.Task] = {} self._batch_semaphore = asyncio.Semaphore(config.max_concurrent_batches) self._update_batch_lock = asyncio.Lock() # this is to allow tests to disable background processing self.process_batches = True async def initialize(self) -> None: # TODO: start background processing of existing tasks pass async def shutdown(self) -> None: """Shutdown the batches provider.""" if self._processing_tasks: # don't cancel tasks - just let them stop naturally on shutdown # cancelling would mark batches as "cancelled" in the database logger.info(f"Shutdown initiated with {len(self._processing_tasks)} active batch processing tasks") # TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions async def create_batch( self, input_file_id: str, endpoint: str, completion_window: Literal["24h"], metadata: dict[str, str] | None = None, ) -> BatchObject: """ Create a new batch for processing multiple API requests. Error handling by levels - 0. Input param handling, results in 40x errors before processing, e.g. - Wrong completion_window - Invalid metadata types - Unknown endpoint -> no batch created 1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g. - input_file_id missing - invalid json in file - missing custom_id, method, url, body - invalid model - streaming -> batch created, validation sends to failed status 2. Processing errors, result in error_file_id entries, e.g. - Any error returned from inference endpoint -> batch created, goes to completed status """ # TODO: set expiration time for garbage collection if endpoint not in ["/v1/chat/completions"]: raise ValueError( f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint", ) if completion_window != "24h": raise ValueError( f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window", ) batch_id = f"batch_{uuid.uuid4().hex[:16]}" current_time = int(time.time()) batch = BatchObject( id=batch_id, object="batch", endpoint=endpoint, input_file_id=input_file_id, completion_window=completion_window, status="validating", created_at=current_time, metadata=metadata, ) await self.kvstore.set(f"batch:{batch_id}", batch.to_json()) if self.process_batches: task = asyncio.create_task(self._process_batch(batch_id)) self._processing_tasks[batch_id] = task return batch async def cancel_batch(self, batch_id: str) -> BatchObject: """Cancel a batch that is in progress.""" batch = await self.retrieve_batch(batch_id) if batch.status in ["cancelled", "cancelling"]: return batch if batch.status in ["completed", "failed", "expired"]: raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'") await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time())) if batch_id in self._processing_tasks: self._processing_tasks[batch_id].cancel() # note: task removal and status="cancelled" handled in finally block of _process_batch return await self.retrieve_batch(batch_id) async def list_batches( self, after: str | None = None, limit: int = 20, ) -> ListBatchesResponse: """ List all batches, eventually only for the current user. With no notion of user, we return all batches. """ batch_values = await self.kvstore.values_in_range("batch:", "batch:\xff") batches = [] for batch_data in batch_values: if batch_data: batches.append(BatchObject.model_validate_json(batch_data)) batches.sort(key=lambda b: b.created_at, reverse=True) start_idx = 0 if after: for i, batch in enumerate(batches): if batch.id == after: start_idx = i + 1 break page_batches = batches[start_idx : start_idx + limit] has_more = (start_idx + limit) < len(batches) first_id = page_batches[0].id if page_batches else None last_id = page_batches[-1].id if page_batches else None return ListBatchesResponse( data=page_batches, first_id=first_id, last_id=last_id, has_more=has_more, ) async def retrieve_batch(self, batch_id: str) -> BatchObject: """Retrieve information about a specific batch.""" batch_data = await self.kvstore.get(f"batch:{batch_id}") if not batch_data: raise ResourceNotFoundError(batch_id, "Batch", "batches.list()") return BatchObject.model_validate_json(batch_data) async def _update_batch(self, batch_id: str, **updates) -> None: """Update batch fields in kvstore.""" async with self._update_batch_lock: try: batch = await self.retrieve_batch(batch_id) # batch processing is async. once cancelling, only allow "cancelled" status updates if batch.status == "cancelling" and updates.get("status") != "cancelled": logger.info( f"Skipping status update for cancelled batch {batch_id}: attempted {updates.get('status')}" ) return if "errors" in updates: updates["errors"] = updates["errors"].model_dump() batch_dict = batch.model_dump() batch_dict.update(updates) await self.kvstore.set(f"batch:{batch_id}", json.dumps(batch_dict)) except Exception as e: logger.error(f"Failed to update batch {batch_id}: {e}") async def _validate_input(self, batch: BatchObject) -> tuple[list[BatchError], list[BatchRequest]]: """ Read & validate input, return errors and valid input. Validation of - input_file_id existance - valid json - custom_id, method, url, body presence and valid - no streaming """ requests: list[BatchRequest] = [] errors: list[BatchError] = [] try: await self.files_api.openai_retrieve_file(batch.input_file_id) except Exception: errors.append( BatchError( code="invalid_request", line=None, message=f"Cannot find file {batch.input_file_id}.", param="input_file_id", ) ) return errors, requests # TODO(SECURITY): do something about large files file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id) file_content = file_content_response.body.decode("utf-8") for line_num, line in enumerate(file_content.strip().split("\n"), 1): if line.strip(): # skip empty lines try: request = json.loads(line) if not isinstance(request, dict): errors.append( BatchError( code="invalid_request", line=line_num, message="Each line must be a JSON dictionary object", ) ) continue valid = True for param, expected_type, type_string in [ ("custom_id", str, "string"), ("method", str, "string"), ("url", str, "string"), ("body", dict, "JSON dictionary object"), ]: if param not in request: errors.append( BatchError( code="missing_required_parameter", line=line_num, message=f"Missing required parameter: {param}", param=param, ) ) valid = False elif not isinstance(request[param], expected_type): param_name = "URL" if param == "url" else param.capitalize() errors.append( BatchError( code="invalid_request", line=line_num, message=f"{param_name} must be a {type_string}", param=param, ) ) valid = False if (url := request.get("url")) and isinstance(url, str) and url != batch.endpoint: errors.append( BatchError( code="invalid_url", line=line_num, message="URL provided for this request does not match the batch endpoint", param="url", ) ) valid = False if (body := request.get("body")) and isinstance(body, dict): if body.get("stream", False): errors.append( BatchError( code="streaming_unsupported", line=line_num, message="Streaming is not supported in batch processing", param="body.stream", ) ) valid = False for param, expected_type, type_string in [ ("model", str, "a string"), # messages is specific to /v1/chat/completions # we could skip validating messages here and let inference fail. however, # that would be a very expensive way to find out messages is wrong. ("messages", list, "an array"), # TODO: allow messages to be a string? ]: if param not in body: errors.append( BatchError( code="invalid_request", line=line_num, message=f"{param.capitalize()} parameter is required", param=f"body.{param}", ) ) valid = False elif not isinstance(body[param], expected_type): errors.append( BatchError( code="invalid_request", line=line_num, message=f"{param.capitalize()} must be {type_string}", param=f"body.{param}", ) ) valid = False if "model" in body and isinstance(body["model"], str): try: await self.models_api.get_model(body["model"]) except Exception: errors.append( BatchError( code="model_not_found", line=line_num, message=f"Model '{body['model']}' does not exist or is not supported", param="body.model", ) ) valid = False if valid: assert isinstance(url, str), "URL must be a string" # for mypy assert isinstance(body, dict), "Body must be a dictionary" # for mypy requests.append( BatchRequest( line_num=line_num, url=url, method=request["method"], custom_id=request["custom_id"], body=body, ), ) except json.JSONDecodeError: errors.append( BatchError( code="invalid_json_line", line=line_num, message="This line is not parseable as valid JSON.", ) ) return errors, requests async def _process_batch(self, batch_id: str) -> None: """Background task to process a batch of requests.""" try: logger.info(f"Starting batch processing for {batch_id}") async with self._batch_semaphore: # semaphore to limit concurrency logger.info(f"Acquired semaphore for batch {batch_id}") await self._process_batch_impl(batch_id) except asyncio.CancelledError: logger.info(f"Batch processing cancelled for {batch_id}") await self._update_batch(batch_id, status="cancelled", cancelled_at=int(time.time())) except Exception as e: logger.error(f"Batch processing failed for {batch_id}: {e}") await self._update_batch( batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=[BatchError(code="internal_error", message=str(e))]), ) finally: self._processing_tasks.pop(batch_id, None) async def _process_batch_impl(self, batch_id: str) -> None: """Implementation of batch processing logic.""" errors: list[BatchError] = [] batch = await self.retrieve_batch(batch_id) errors, requests = await self._validate_input(batch) if errors: await self._update_batch(batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors)) logger.info(f"Batch validation failed for {batch_id} with {len(errors)} errors") return logger.info(f"Processing {len(requests)} requests for batch {batch_id}") total_requests = len(requests) await self._update_batch( batch_id, status="in_progress", request_counts={"total": total_requests, "completed": 0, "failed": 0}, ) error_results = [] success_results = [] completed_count = 0 failed_count = 0 for chunk in itertools.batched(requests, self.config.max_concurrent_requests_per_batch): # we use a TaskGroup to ensure all process-single-request tasks are canceled when process-batch is cancelled async with asyncio.TaskGroup() as tg: chunk_tasks = [tg.create_task(self._process_single_request(batch_id, request)) for request in chunk] chunk_results = await asyncio.gather(*chunk_tasks, return_exceptions=True) for result in chunk_results: if isinstance(result, dict) and result.get("error") is not None: # error response from inference failed_count += 1 error_results.append(result) elif isinstance(result, dict) and result.get("response") is not None: # successful inference completed_count += 1 success_results.append(result) else: # unexpected result failed_count += 1 errors.append(BatchError(code="internal_error", message=f"Unexpected result: {result}")) await self._update_batch( batch_id, request_counts={"total": total_requests, "completed": completed_count, "failed": failed_count}, ) if errors: await self._update_batch( batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors) ) return try: output_file_id = await self._create_output_file(batch_id, success_results, "success") await self._update_batch(batch_id, output_file_id=output_file_id) error_file_id = await self._create_output_file(batch_id, error_results, "error") await self._update_batch(batch_id, error_file_id=error_file_id) await self._update_batch(batch_id, status="completed", completed_at=int(time.time())) logger.info( f"Batch processing completed for {batch_id}: {completed_count} completed, {failed_count} failed" ) except Exception as e: # note: errors is empty at this point, so we don't lose anything by ignoring it await self._update_batch( batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=[BatchError(code="output_failed", message=str(e))]), ) async def _process_single_request(self, batch_id: str, request: BatchRequest) -> dict: """Process a single request from the batch.""" request_id = f"batch_req_{batch_id}_{request.line_num}" try: # TODO(SECURITY): review body for security issues request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]] chat_response = await self.inference_api.openai_chat_completion(**request.body) # this is for mypy, we don't allow streaming so we'll get the right type assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method" return { "id": request_id, "custom_id": request.custom_id, "response": { "status_code": 200, "request_id": request_id, # TODO: should this be different? "body": chat_response.model_dump_json(), }, } except Exception as e: logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}") return { "id": request_id, "custom_id": request.custom_id, "error": {"type": "request_failed", "message": str(e)}, } async def _create_output_file(self, batch_id: str, results: list[dict], file_type: str) -> str: """ Create an output file with batch results. This function filters results based on the specified file_type and uploads the file to the Files API. """ output_lines = [json.dumps(result) for result in results] with AsyncBytesIO("\n".join(output_lines).encode("utf-8")) as file_buffer: file_buffer.filename = f"{batch_id}_{file_type}.jsonl" uploaded_file = await self.files_api.openai_upload_file(file=file_buffer, purpose=OpenAIFilePurpose.BATCH) return uploaded_file.id