llama-stack-mirror/tests/integration/batches/conftest.py
Matthew Farrellee 8e678912ec feat: add batches API with OpenAI compatibility
Add complete batches API implementation with protocol, providers, and tests:

Core Infrastructure:
- Add batches API protocol using OpenAI Batch types directly
- Add Api.batches enum value and protocol mapping in resolver
- Add OpenAI "batch" file purpose support
- Include proper error handling (ConflictError, ResourceNotFoundError)

Reference Provider:
- Add ReferenceBatchesImpl with full CRUD operations (create, retrieve, cancel, list)
- Implement background batch processing with configurable concurrency
- Add SQLite KVStore backend for persistence
- Support /v1/chat/completions endpoint with request validation

Comprehensive Test Suite:
- Add unit tests for provider implementation with validation
- Add integration tests for end-to-end batch processing workflows
- Add error handling tests for validation, malformed inputs, and edge cases

Configuration:
- Add max_concurrent_batches and max_concurrent_requests_per_batch options
- Add provider documentation with sample configurations

Test with -

```
$ uv run llama stack build --image-type venv --providers inference=YOU_PICK,files=inline::localfs,batches=inline::reference --run &
$ LLAMA_STACK_CONFIG=http://localhost:8321 uv run pytest tests/unit/providers/batches tests/integration/batches --text-model YOU_PICK
```
2025-08-08 08:08:08 -04:00

122 lines
4.4 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""Shared pytest fixtures for batch tests."""
import json
import time
import warnings
from contextlib import contextmanager
from io import BytesIO
import pytest
from llama_stack.apis.files import OpenAIFilePurpose
class BatchHelper:
"""Helper class for creating and managing batch input files."""
def __init__(self, client):
"""Initialize with either a batch_client or openai_client."""
self.client = client
@contextmanager
def create_file(self, content: str | list[dict], filename_prefix="batch_input"):
"""Context manager for creating and cleaning up batch input files.
Args:
content: Either a list of batch request dictionaries or raw string content
filename_prefix: Prefix for the generated filename (or full filename if content is string)
Yields:
The uploaded file object
"""
if isinstance(content, str):
# Handle raw string content (e.g., malformed JSONL, empty files)
file_content = content.encode("utf-8")
else:
# Handle list of batch request dictionaries
jsonl_content = "\n".join(json.dumps(req) for req in content)
file_content = jsonl_content.encode("utf-8")
filename = filename_prefix if filename_prefix.endswith(".jsonl") else f"{filename_prefix}.jsonl"
with BytesIO(file_content) as file_buffer:
file_buffer.name = filename
uploaded_file = self.client.files.create(file=file_buffer, purpose=OpenAIFilePurpose.BATCH)
try:
yield uploaded_file
finally:
try:
self.client.files.delete(uploaded_file.id)
except Exception:
warnings.warn(
f"Failed to cleanup file {uploaded_file.id}: {uploaded_file.filename}",
stacklevel=2,
)
def wait_for(
self,
batch_id: str,
max_wait_time: int = 60,
sleep_interval: int | None = None,
expected_statuses: set[str] | None = None,
timeout_action: str = "fail",
):
"""Wait for a batch to reach a terminal status.
Args:
batch_id: The batch ID to monitor
max_wait_time: Maximum time to wait in seconds (default: 60 seconds)
sleep_interval: Time to sleep between checks in seconds (default: 1/10th of max_wait_time, min 1s, max 15s)
expected_statuses: Set of expected terminal statuses (default: {"completed"})
timeout_action: Action on timeout - "fail" (pytest.fail) or "skip" (pytest.skip)
Returns:
The final batch object
Raises:
pytest.Failed: If batch reaches an unexpected status or timeout_action is "fail"
pytest.Skipped: If timeout_action is "skip" on timeout or unexpected status
"""
if sleep_interval is None:
# Default to 1/10th of max_wait_time, with min 1s and max 15s
sleep_interval = max(1, min(15, max_wait_time // 10))
if expected_statuses is None:
expected_statuses = {"completed"}
terminal_statuses = {"completed", "failed", "cancelled", "expired"}
unexpected_statuses = terminal_statuses - expected_statuses
start_time = time.time()
while time.time() - start_time < max_wait_time:
current_batch = self.client.batches.retrieve(batch_id)
if current_batch.status in expected_statuses:
return current_batch
elif current_batch.status in unexpected_statuses:
error_msg = f"Batch reached unexpected status: {current_batch.status}"
if timeout_action == "skip":
pytest.skip(error_msg)
else:
pytest.fail(error_msg)
time.sleep(sleep_interval)
timeout_msg = f"Batch did not reach expected status {expected_statuses} within {max_wait_time} seconds"
if timeout_action == "skip":
pytest.skip(timeout_msg)
else:
pytest.fail(timeout_msg)
@pytest.fixture
def batch_helper(openai_client):
    """Provide a BatchHelper wrapping the ``openai_client`` fixture."""
    helper = BatchHelper(client=openai_client)
    return helper