mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-25 17:11:12 +00:00 
			
		
		
		
	
		
			Some checks failed
		
		
	
	Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 0s
				
			Python Package Build Test / build (3.12) (push) Failing after 1s
				
			Unit Tests / unit-tests (3.13) (push) Failing after 4s
				
			SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
				
			SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
				
			Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
				
			Python Package Build Test / build (3.13) (push) Failing after 1s
				
			Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 3s
				
			Vector IO Integration Tests / test-matrix (push) Failing after 5s
				
			Test External API and Providers / test-external (venv) (push) Failing after 5s
				
			Unit Tests / unit-tests (3.12) (push) Failing after 4s
				
			API Conformance Tests / check-schema-compatibility (push) Successful in 10s
				
			UI Tests / ui-tests (22) (push) Successful in 40s
				
			Pre-commit / pre-commit (push) Successful in 1m23s
				
			Applies the same pattern from https://github.com/llamastack/llama-stack/pull/3777 to embeddings and vector_stores.create() endpoints. This should _not_ be a breaking change since (a) our tests were already using the `extra_body` parameter when passing in to the backend (b) but the backend probably wasn't extracting the parameters correctly. This PR will fix that. Updated APIs: `openai_embeddings(), openai_create_vector_store(), openai_create_vector_store_file_batch()`
		
			
				
	
	
		
			679 lines
		
	
	
	
		
			28 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			679 lines
		
	
	
	
		
			28 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import asyncio
 | |
| import hashlib
 | |
| import itertools
 | |
| import json
 | |
| import time
 | |
| import uuid
 | |
| from io import BytesIO
 | |
| from typing import Any, Literal
 | |
| 
 | |
| from openai.types.batch import BatchError, Errors
 | |
| from pydantic import BaseModel
 | |
| 
 | |
| from llama_stack.apis.batches import Batches, BatchObject, ListBatchesResponse
 | |
| from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
 | |
| from llama_stack.apis.files import Files, OpenAIFilePurpose
 | |
| from llama_stack.apis.inference import (
 | |
|     Inference,
 | |
|     OpenAIAssistantMessageParam,
 | |
|     OpenAIChatCompletionRequestWithExtraBody,
 | |
|     OpenAICompletionRequestWithExtraBody,
 | |
|     OpenAIDeveloperMessageParam,
 | |
|     OpenAIEmbeddingsRequestWithExtraBody,
 | |
|     OpenAIMessageParam,
 | |
|     OpenAISystemMessageParam,
 | |
|     OpenAIToolMessageParam,
 | |
|     OpenAIUserMessageParam,
 | |
| )
 | |
| from llama_stack.apis.models import Models
 | |
| from llama_stack.log import get_logger
 | |
| from llama_stack.providers.utils.kvstore import KVStore
 | |
| 
 | |
| from .config import ReferenceBatchesImplConfig
 | |
| 
 | |
| BATCH_PREFIX = "batch:"
 | |
| 
 | |
| logger = get_logger(__name__)
 | |
| 
 | |
| 
 | |
| class AsyncBytesIO:
 | |
|     """
 | |
|     Async-compatible BytesIO wrapper to allow async file-like operations.
 | |
| 
 | |
|     We use this when uploading files to the Files API, as it expects an
 | |
|     async file-like object.
 | |
|     """
 | |
| 
 | |
|     def __init__(self, data: bytes):
 | |
|         self._buffer = BytesIO(data)
 | |
| 
 | |
|     async def read(self, n=-1):
 | |
|         return self._buffer.read(n)
 | |
| 
 | |
|     async def seek(self, pos, whence=0):
 | |
|         return self._buffer.seek(pos, whence)
 | |
| 
 | |
|     def __enter__(self):
 | |
|         return self
 | |
| 
 | |
|     def __exit__(self, exc_type, exc_val, exc_tb):
 | |
|         self._buffer.close()
 | |
| 
 | |
|     def __getattr__(self, name):
 | |
|         return getattr(self._buffer, name)
 | |
| 
 | |
| 
 | |
| class BatchRequest(BaseModel):
 | |
|     line_num: int
 | |
|     custom_id: str
 | |
|     method: str
 | |
|     url: str
 | |
|     body: dict[str, Any]
 | |
| 
 | |
| 
 | |
| def convert_to_openai_message_param(msg: dict[str, Any]) -> OpenAIMessageParam:
 | |
|     """Convert a message dictionary to OpenAIMessageParam based on role."""
 | |
|     role = msg.get("role")
 | |
| 
 | |
|     if role == "user":
 | |
|         return OpenAIUserMessageParam(**msg)
 | |
|     elif role == "system":
 | |
|         return OpenAISystemMessageParam(**msg)
 | |
|     elif role == "assistant":
 | |
|         return OpenAIAssistantMessageParam(**msg)
 | |
|     elif role == "tool":
 | |
|         return OpenAIToolMessageParam(**msg)
 | |
|     elif role == "developer":
 | |
|         return OpenAIDeveloperMessageParam(**msg)
 | |
|     else:
 | |
|         raise ValueError(f"Unknown message role: {role}")
 | |
| 
 | |
| 
 | |
| class ReferenceBatchesImpl(Batches):
 | |
|     """Reference implementation of the Batches API.
 | |
| 
 | |
|     This implementation processes batch files by making individual requests
 | |
|     to the inference API and generates output files with results.
 | |
|     """
 | |
| 
 | |
|     def __init__(
 | |
|         self,
 | |
|         config: ReferenceBatchesImplConfig,
 | |
|         inference_api: Inference,
 | |
|         files_api: Files,
 | |
|         models_api: Models,
 | |
|         kvstore: KVStore,
 | |
|     ) -> None:
 | |
|         self.config = config
 | |
|         self.kvstore = kvstore
 | |
|         self.inference_api = inference_api
 | |
|         self.files_api = files_api
 | |
|         self.models_api = models_api
 | |
|         self._processing_tasks: dict[str, asyncio.Task] = {}
 | |
|         self._batch_semaphore = asyncio.Semaphore(config.max_concurrent_batches)
 | |
|         self._update_batch_lock = asyncio.Lock()
 | |
| 
 | |
|         # this is to allow tests to disable background processing
 | |
|         self.process_batches = True
 | |
| 
 | |
|     async def initialize(self) -> None:
 | |
|         # TODO: start background processing of existing tasks
 | |
|         pass
 | |
| 
 | |
|     async def shutdown(self) -> None:
 | |
|         """Shutdown the batches provider."""
 | |
|         if self._processing_tasks:
 | |
|             # don't cancel tasks - just let them stop naturally on shutdown
 | |
|             # cancelling would mark batches as "cancelled" in the database
 | |
|             logger.info(f"Shutdown initiated with {len(self._processing_tasks)} active batch processing tasks")
 | |
| 
 | |
|     # TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions
 | |
|     async def create_batch(
 | |
|         self,
 | |
|         input_file_id: str,
 | |
|         endpoint: str,
 | |
|         completion_window: Literal["24h"],
 | |
|         metadata: dict[str, str] | None = None,
 | |
|         idempotency_key: str | None = None,
 | |
|     ) -> BatchObject:
 | |
|         """
 | |
|         Create a new batch for processing multiple API requests.
 | |
| 
 | |
|         This implementation provides optional idempotency: when an idempotency key
 | |
|         (idempotency_key) is provided, a deterministic ID is generated based on the input
 | |
|         parameters. If a batch with the same parameters already exists, it will be
 | |
|         returned instead of creating a duplicate. Without an idempotency key,
 | |
|         each request creates a new batch with a unique ID.
 | |
| 
 | |
|         Args:
 | |
|             input_file_id: The ID of an uploaded file containing requests for the batch.
 | |
|             endpoint: The endpoint to be used for all requests in the batch.
 | |
|             completion_window: The time window within which the batch should be processed.
 | |
|             metadata: Optional metadata for the batch.
 | |
|             idempotency_key: Optional idempotency key for enabling idempotent behavior.
 | |
| 
 | |
|         Returns:
 | |
|             The created or existing batch object.
 | |
|         """
 | |
| 
 | |
|         # Error handling by levels -
 | |
|         #  0. Input param handling, results in 40x errors before processing, e.g.
 | |
|         #    - Wrong completion_window
 | |
|         #    - Invalid metadata types
 | |
|         #    - Unknown endpoint
 | |
|         #   -> no batch created
 | |
|         #  1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
 | |
|         #    - input_file_id missing
 | |
|         #    - invalid json in file
 | |
|         #    - missing custom_id, method, url, body
 | |
|         #    - invalid model
 | |
|         #    - streaming
 | |
|         #   -> batch created, validation sends to failed status
 | |
|         #  2. Processing errors, result in error_file_id entries, e.g.
 | |
|         #    - Any error returned from inference endpoint
 | |
|         #   -> batch created, goes to completed status
 | |
| 
 | |
|         # TODO: set expiration time for garbage collection
 | |
| 
 | |
|         if endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
 | |
|             raise ValueError(
 | |
|                 f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint",
 | |
|             )
 | |
| 
 | |
|         if completion_window != "24h":
 | |
|             raise ValueError(
 | |
|                 f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window",
 | |
|             )
 | |
| 
 | |
|         batch_id = f"batch_{uuid.uuid4().hex[:16]}"
 | |
| 
 | |
|         # For idempotent requests, use the idempotency key for the batch ID
 | |
|         # This ensures the same key always maps to the same batch ID,
 | |
|         # allowing us to detect parameter conflicts
 | |
|         if idempotency_key is not None:
 | |
|             hash_input = idempotency_key.encode("utf-8")
 | |
|             hash_digest = hashlib.sha256(hash_input).hexdigest()[:24]
 | |
|             batch_id = f"batch_{hash_digest}"
 | |
| 
 | |
|             try:
 | |
|                 existing_batch = await self.retrieve_batch(batch_id)
 | |
| 
 | |
|                 if (
 | |
|                     existing_batch.input_file_id != input_file_id
 | |
|                     or existing_batch.endpoint != endpoint
 | |
|                     or existing_batch.completion_window != completion_window
 | |
|                     or existing_batch.metadata != metadata
 | |
|                 ):
 | |
|                     raise ConflictError(
 | |
|                         f"Idempotency key '{idempotency_key}' was previously used with different parameters. "
 | |
|                         "Either use a new idempotency key or ensure all parameters match the original request."
 | |
|                     )
 | |
| 
 | |
|                 logger.info(f"Returning existing batch with ID: {batch_id}")
 | |
|                 return existing_batch
 | |
|             except ResourceNotFoundError:
 | |
|                 # Batch doesn't exist, continue with creation
 | |
|                 pass
 | |
| 
 | |
|         current_time = int(time.time())
 | |
| 
 | |
|         batch = BatchObject(
 | |
|             id=batch_id,
 | |
|             object="batch",
 | |
|             endpoint=endpoint,
 | |
|             input_file_id=input_file_id,
 | |
|             completion_window=completion_window,
 | |
|             status="validating",
 | |
|             created_at=current_time,
 | |
|             metadata=metadata,
 | |
|         )
 | |
| 
 | |
|         await self.kvstore.set(f"batch:{batch_id}", batch.to_json())
 | |
|         logger.info(f"Created new batch with ID: {batch_id}")
 | |
| 
 | |
|         if self.process_batches:
 | |
|             task = asyncio.create_task(self._process_batch(batch_id))
 | |
|             self._processing_tasks[batch_id] = task
 | |
| 
 | |
|         return batch
 | |
| 
 | |
|     async def cancel_batch(self, batch_id: str) -> BatchObject:
 | |
|         """Cancel a batch that is in progress."""
 | |
|         batch = await self.retrieve_batch(batch_id)
 | |
| 
 | |
|         if batch.status in ["cancelled", "cancelling"]:
 | |
|             return batch
 | |
| 
 | |
|         if batch.status in ["completed", "failed", "expired"]:
 | |
|             raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'")
 | |
| 
 | |
|         await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time()))
 | |
| 
 | |
|         if batch_id in self._processing_tasks:
 | |
|             self._processing_tasks[batch_id].cancel()
 | |
|             # note: task removal and status="cancelled" handled in finally block of _process_batch
 | |
| 
 | |
|         return await self.retrieve_batch(batch_id)
 | |
| 
 | |
|     async def list_batches(
 | |
|         self,
 | |
|         after: str | None = None,
 | |
|         limit: int = 20,
 | |
|     ) -> ListBatchesResponse:
 | |
|         """
 | |
|         List all batches, eventually only for the current user.
 | |
| 
 | |
|         With no notion of user, we return all batches.
 | |
|         """
 | |
|         batch_values = await self.kvstore.values_in_range("batch:", "batch:\xff")
 | |
| 
 | |
|         batches = []
 | |
|         for batch_data in batch_values:
 | |
|             if batch_data:
 | |
|                 batches.append(BatchObject.model_validate_json(batch_data))
 | |
| 
 | |
|         batches.sort(key=lambda b: b.created_at, reverse=True)
 | |
| 
 | |
|         start_idx = 0
 | |
|         if after:
 | |
|             for i, batch in enumerate(batches):
 | |
|                 if batch.id == after:
 | |
|                     start_idx = i + 1
 | |
|                     break
 | |
| 
 | |
|         page_batches = batches[start_idx : start_idx + limit]
 | |
|         has_more = (start_idx + limit) < len(batches)
 | |
| 
 | |
|         first_id = page_batches[0].id if page_batches else None
 | |
|         last_id = page_batches[-1].id if page_batches else None
 | |
| 
 | |
|         return ListBatchesResponse(
 | |
|             data=page_batches,
 | |
|             first_id=first_id,
 | |
|             last_id=last_id,
 | |
|             has_more=has_more,
 | |
|         )
 | |
| 
 | |
|     async def retrieve_batch(self, batch_id: str) -> BatchObject:
 | |
|         """Retrieve information about a specific batch."""
 | |
|         batch_data = await self.kvstore.get(f"batch:{batch_id}")
 | |
|         if not batch_data:
 | |
|             raise ResourceNotFoundError(batch_id, "Batch", "batches.list()")
 | |
| 
 | |
|         return BatchObject.model_validate_json(batch_data)
 | |
| 
 | |
|     async def _update_batch(self, batch_id: str, **updates) -> None:
 | |
|         """Update batch fields in kvstore."""
 | |
|         async with self._update_batch_lock:
 | |
|             try:
 | |
|                 batch = await self.retrieve_batch(batch_id)
 | |
| 
 | |
|                 # batch processing is async. once cancelling, only allow "cancelled" status updates
 | |
|                 if batch.status == "cancelling" and updates.get("status") != "cancelled":
 | |
|                     logger.info(
 | |
|                         f"Skipping status update for cancelled batch {batch_id}: attempted {updates.get('status')}"
 | |
|                     )
 | |
|                     return
 | |
| 
 | |
|                 if "errors" in updates:
 | |
|                     updates["errors"] = updates["errors"].model_dump()
 | |
| 
 | |
|                 batch_dict = batch.model_dump()
 | |
|                 batch_dict.update(updates)
 | |
| 
 | |
|                 await self.kvstore.set(f"batch:{batch_id}", json.dumps(batch_dict))
 | |
|             except Exception as e:
 | |
|                 logger.error(f"Failed to update batch {batch_id}: {e}")
 | |
| 
 | |
|     async def _validate_input(self, batch: BatchObject) -> tuple[list[BatchError], list[BatchRequest]]:
 | |
|         """
 | |
|         Read & validate input, return errors and valid input.
 | |
| 
 | |
|         Validation of
 | |
|         - input_file_id existance
 | |
|         - valid json
 | |
|         - custom_id, method, url, body presence and valid
 | |
|         - no streaming
 | |
|         """
 | |
|         requests: list[BatchRequest] = []
 | |
|         errors: list[BatchError] = []
 | |
|         try:
 | |
|             await self.files_api.openai_retrieve_file(batch.input_file_id)
 | |
|         except Exception:
 | |
|             errors.append(
 | |
|                 BatchError(
 | |
|                     code="invalid_request",
 | |
|                     line=None,
 | |
|                     message=f"Cannot find file {batch.input_file_id}.",
 | |
|                     param="input_file_id",
 | |
|                 )
 | |
|             )
 | |
|             return errors, requests
 | |
| 
 | |
|         # TODO(SECURITY): do something about large files
 | |
|         file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id)
 | |
|         file_content = file_content_response.body.decode("utf-8")
 | |
|         for line_num, line in enumerate(file_content.strip().split("\n"), 1):
 | |
|             if line.strip():  # skip empty lines
 | |
|                 try:
 | |
|                     request = json.loads(line)
 | |
| 
 | |
|                     if not isinstance(request, dict):
 | |
|                         errors.append(
 | |
|                             BatchError(
 | |
|                                 code="invalid_request",
 | |
|                                 line=line_num,
 | |
|                                 message="Each line must be a JSON dictionary object",
 | |
|                             )
 | |
|                         )
 | |
|                         continue
 | |
| 
 | |
|                     valid = True
 | |
| 
 | |
|                     for param, expected_type, type_string in [
 | |
|                         ("custom_id", str, "string"),
 | |
|                         ("method", str, "string"),
 | |
|                         ("url", str, "string"),
 | |
|                         ("body", dict, "JSON dictionary object"),
 | |
|                     ]:
 | |
|                         if param not in request:
 | |
|                             errors.append(
 | |
|                                 BatchError(
 | |
|                                     code="missing_required_parameter",
 | |
|                                     line=line_num,
 | |
|                                     message=f"Missing required parameter: {param}",
 | |
|                                     param=param,
 | |
|                                 )
 | |
|                             )
 | |
|                             valid = False
 | |
|                         elif not isinstance(request[param], expected_type):
 | |
|                             param_name = "URL" if param == "url" else param.capitalize()
 | |
|                             errors.append(
 | |
|                                 BatchError(
 | |
|                                     code="invalid_request",
 | |
|                                     line=line_num,
 | |
|                                     message=f"{param_name} must be a {type_string}",
 | |
|                                     param=param,
 | |
|                                 )
 | |
|                             )
 | |
|                             valid = False
 | |
| 
 | |
|                     if (url := request.get("url")) and isinstance(url, str) and url != batch.endpoint:
 | |
|                         errors.append(
 | |
|                             BatchError(
 | |
|                                 code="invalid_url",
 | |
|                                 line=line_num,
 | |
|                                 message="URL provided for this request does not match the batch endpoint",
 | |
|                                 param="url",
 | |
|                             )
 | |
|                         )
 | |
|                         valid = False
 | |
| 
 | |
|                     if (body := request.get("body")) and isinstance(body, dict):
 | |
|                         if body.get("stream", False):
 | |
|                             errors.append(
 | |
|                                 BatchError(
 | |
|                                     code="streaming_unsupported",
 | |
|                                     line=line_num,
 | |
|                                     message="Streaming is not supported in batch processing",
 | |
|                                     param="body.stream",
 | |
|                                 )
 | |
|                             )
 | |
|                             valid = False
 | |
| 
 | |
|                         if batch.endpoint == "/v1/chat/completions":
 | |
|                             required_params: list[tuple[str, Any, str]] = [
 | |
|                                 ("model", str, "a string"),
 | |
|                                 # messages is specific to /v1/chat/completions
 | |
|                                 # we could skip validating messages here and let inference fail. however,
 | |
|                                 # that would be a very expensive way to find out messages is wrong.
 | |
|                                 ("messages", list, "an array"),  # TODO: allow messages to be a string?
 | |
|                             ]
 | |
|                         elif batch.endpoint == "/v1/completions":
 | |
|                             required_params = [
 | |
|                                 ("model", str, "a string"),
 | |
|                                 ("prompt", str, "a string"),  # TODO: allow prompt to be a list of strings??
 | |
|                             ]
 | |
|                         else:  # /v1/embeddings
 | |
|                             required_params = [
 | |
|                                 ("model", str, "a string"),
 | |
|                                 ("input", (str, list), "a string or array of strings"),
 | |
|                             ]
 | |
| 
 | |
|                         for param, expected_type, type_string in required_params:
 | |
|                             if param not in body:
 | |
|                                 errors.append(
 | |
|                                     BatchError(
 | |
|                                         code="invalid_request",
 | |
|                                         line=line_num,
 | |
|                                         message=f"{param.capitalize()} parameter is required",
 | |
|                                         param=f"body.{param}",
 | |
|                                     )
 | |
|                                 )
 | |
|                                 valid = False
 | |
|                             elif not isinstance(body[param], expected_type):
 | |
|                                 errors.append(
 | |
|                                     BatchError(
 | |
|                                         code="invalid_request",
 | |
|                                         line=line_num,
 | |
|                                         message=f"{param.capitalize()} must be {type_string}",
 | |
|                                         param=f"body.{param}",
 | |
|                                     )
 | |
|                                 )
 | |
|                                 valid = False
 | |
| 
 | |
|                         if "model" in body and isinstance(body["model"], str):
 | |
|                             try:
 | |
|                                 await self.models_api.get_model(body["model"])
 | |
|                             except Exception:
 | |
|                                 errors.append(
 | |
|                                     BatchError(
 | |
|                                         code="model_not_found",
 | |
|                                         line=line_num,
 | |
|                                         message=f"Model '{body['model']}' does not exist or is not supported",
 | |
|                                         param="body.model",
 | |
|                                     )
 | |
|                                 )
 | |
|                                 valid = False
 | |
| 
 | |
|                     if valid:
 | |
|                         assert isinstance(url, str), "URL must be a string"  # for mypy
 | |
|                         assert isinstance(body, dict), "Body must be a dictionary"  # for mypy
 | |
|                         requests.append(
 | |
|                             BatchRequest(
 | |
|                                 line_num=line_num,
 | |
|                                 url=url,
 | |
|                                 method=request["method"],
 | |
|                                 custom_id=request["custom_id"],
 | |
|                                 body=body,
 | |
|                             ),
 | |
|                         )
 | |
|                 except json.JSONDecodeError:
 | |
|                     errors.append(
 | |
|                         BatchError(
 | |
|                             code="invalid_json_line",
 | |
|                             line=line_num,
 | |
|                             message="This line is not parseable as valid JSON.",
 | |
|                         )
 | |
|                     )
 | |
| 
 | |
|         return errors, requests
 | |
| 
 | |
|     async def _process_batch(self, batch_id: str) -> None:
 | |
|         """Background task to process a batch of requests."""
 | |
|         try:
 | |
|             logger.info(f"Starting batch processing for {batch_id}")
 | |
|             async with self._batch_semaphore:  # semaphore to limit concurrency
 | |
|                 logger.info(f"Acquired semaphore for batch {batch_id}")
 | |
|                 await self._process_batch_impl(batch_id)
 | |
|         except asyncio.CancelledError:
 | |
|             logger.info(f"Batch processing cancelled for {batch_id}")
 | |
|             await self._update_batch(batch_id, status="cancelled", cancelled_at=int(time.time()))
 | |
|         except Exception as e:
 | |
|             logger.error(f"Batch processing failed for {batch_id}: {e}")
 | |
|             await self._update_batch(
 | |
|                 batch_id,
 | |
|                 status="failed",
 | |
|                 failed_at=int(time.time()),
 | |
|                 errors=Errors(data=[BatchError(code="internal_error", message=str(e))]),
 | |
|             )
 | |
|         finally:
 | |
|             self._processing_tasks.pop(batch_id, None)
 | |
| 
 | |
|     async def _process_batch_impl(self, batch_id: str) -> None:
 | |
|         """Implementation of batch processing logic."""
 | |
|         errors: list[BatchError] = []
 | |
|         batch = await self.retrieve_batch(batch_id)
 | |
| 
 | |
|         errors, requests = await self._validate_input(batch)
 | |
|         if errors:
 | |
|             await self._update_batch(batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors))
 | |
|             logger.info(f"Batch validation failed for {batch_id} with {len(errors)} errors")
 | |
|             return
 | |
| 
 | |
|         logger.info(f"Processing {len(requests)} requests for batch {batch_id}")
 | |
| 
 | |
|         total_requests = len(requests)
 | |
|         await self._update_batch(
 | |
|             batch_id,
 | |
|             status="in_progress",
 | |
|             request_counts={"total": total_requests, "completed": 0, "failed": 0},
 | |
|         )
 | |
| 
 | |
|         error_results = []
 | |
|         success_results = []
 | |
|         completed_count = 0
 | |
|         failed_count = 0
 | |
| 
 | |
|         for chunk in itertools.batched(requests, self.config.max_concurrent_requests_per_batch):
 | |
|             # we use a TaskGroup to ensure all process-single-request tasks are canceled when process-batch is cancelled
 | |
|             async with asyncio.TaskGroup() as tg:
 | |
|                 chunk_tasks = [tg.create_task(self._process_single_request(batch_id, request)) for request in chunk]
 | |
| 
 | |
|                 chunk_results = await asyncio.gather(*chunk_tasks, return_exceptions=True)
 | |
| 
 | |
|             for result in chunk_results:
 | |
|                 if isinstance(result, dict) and result.get("error") is not None:  # error response from inference
 | |
|                     failed_count += 1
 | |
|                     error_results.append(result)
 | |
|                 elif isinstance(result, dict) and result.get("response") is not None:  # successful inference
 | |
|                     completed_count += 1
 | |
|                     success_results.append(result)
 | |
|                 else:  # unexpected result
 | |
|                     failed_count += 1
 | |
|                     errors.append(BatchError(code="internal_error", message=f"Unexpected result: {result}"))
 | |
| 
 | |
|             await self._update_batch(
 | |
|                 batch_id,
 | |
|                 request_counts={"total": total_requests, "completed": completed_count, "failed": failed_count},
 | |
|             )
 | |
| 
 | |
|             if errors:
 | |
|                 await self._update_batch(
 | |
|                     batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors)
 | |
|                 )
 | |
|                 return
 | |
| 
 | |
|         try:
 | |
|             output_file_id = await self._create_output_file(batch_id, success_results, "success")
 | |
|             await self._update_batch(batch_id, output_file_id=output_file_id)
 | |
| 
 | |
|             error_file_id = await self._create_output_file(batch_id, error_results, "error")
 | |
|             await self._update_batch(batch_id, error_file_id=error_file_id)
 | |
| 
 | |
|             await self._update_batch(batch_id, status="completed", completed_at=int(time.time()))
 | |
| 
 | |
|             logger.info(
 | |
|                 f"Batch processing completed for {batch_id}: {completed_count} completed, {failed_count} failed"
 | |
|             )
 | |
|         except Exception as e:
 | |
|             # note: errors is empty at this point, so we don't lose anything by ignoring it
 | |
|             await self._update_batch(
 | |
|                 batch_id,
 | |
|                 status="failed",
 | |
|                 failed_at=int(time.time()),
 | |
|                 errors=Errors(data=[BatchError(code="output_failed", message=str(e))]),
 | |
|             )
 | |
| 
 | |
|     async def _process_single_request(self, batch_id: str, request: BatchRequest) -> dict:
 | |
|         """Process a single request from the batch."""
 | |
|         request_id = f"batch_req_{batch_id}_{request.line_num}"
 | |
| 
 | |
|         try:
 | |
|             # TODO(SECURITY): review body for security issues
 | |
|             if request.url == "/v1/chat/completions":
 | |
|                 request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
 | |
|                 chat_params = OpenAIChatCompletionRequestWithExtraBody(**request.body)
 | |
|                 chat_response = await self.inference_api.openai_chat_completion(chat_params)
 | |
| 
 | |
|                 # this is for mypy, we don't allow streaming so we'll get the right type
 | |
|                 assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
 | |
|                 return {
 | |
|                     "id": request_id,
 | |
|                     "custom_id": request.custom_id,
 | |
|                     "response": {
 | |
|                         "status_code": 200,
 | |
|                         "request_id": request_id,  # TODO: should this be different?
 | |
|                         "body": chat_response.model_dump_json(),
 | |
|                     },
 | |
|                 }
 | |
|             elif request.url == "/v1/completions":
 | |
|                 completion_params = OpenAICompletionRequestWithExtraBody(**request.body)
 | |
|                 completion_response = await self.inference_api.openai_completion(completion_params)
 | |
| 
 | |
|                 # this is for mypy, we don't allow streaming so we'll get the right type
 | |
|                 assert hasattr(completion_response, "model_dump_json"), (
 | |
|                     "Completion response must have model_dump_json method"
 | |
|                 )
 | |
|                 return {
 | |
|                     "id": request_id,
 | |
|                     "custom_id": request.custom_id,
 | |
|                     "response": {
 | |
|                         "status_code": 200,
 | |
|                         "request_id": request_id,
 | |
|                         "body": completion_response.model_dump_json(),
 | |
|                     },
 | |
|                 }
 | |
|             else:  # /v1/embeddings
 | |
|                 embeddings_response = await self.inference_api.openai_embeddings(
 | |
|                     OpenAIEmbeddingsRequestWithExtraBody(**request.body)
 | |
|                 )
 | |
|                 assert hasattr(embeddings_response, "model_dump_json"), (
 | |
|                     "Embeddings response must have model_dump_json method"
 | |
|                 )
 | |
|                 return {
 | |
|                     "id": request_id,
 | |
|                     "custom_id": request.custom_id,
 | |
|                     "response": {
 | |
|                         "status_code": 200,
 | |
|                         "request_id": request_id,  # TODO: should this be different?
 | |
|                         "body": embeddings_response.model_dump_json(),
 | |
|                     },
 | |
|                 }
 | |
|         except Exception as e:
 | |
|             logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
 | |
|             return {
 | |
|                 "id": request_id,
 | |
|                 "custom_id": request.custom_id,
 | |
|                 "error": {"type": "request_failed", "message": str(e)},
 | |
|             }
 | |
| 
 | |
|     async def _create_output_file(self, batch_id: str, results: list[dict], file_type: str) -> str:
 | |
|         """
 | |
|         Create an output file with batch results.
 | |
| 
 | |
|         This function filters results based on the specified file_type
 | |
|         and uploads the file to the Files API.
 | |
|         """
 | |
|         output_lines = [json.dumps(result) for result in results]
 | |
| 
 | |
|         with AsyncBytesIO("\n".join(output_lines).encode("utf-8")) as file_buffer:
 | |
|             file_buffer.filename = f"{batch_id}_{file_type}.jsonl"
 | |
|             uploaded_file = await self.files_api.openai_upload_file(file=file_buffer, purpose=OpenAIFilePurpose.BATCH)
 | |
|             return uploaded_file.id
 |