llama-stack-mirror/tests/unit/providers/batches/test_reference.py
Matthew Farrellee 8e678912ec feat: add batches API with OpenAI compatibility
Add complete batches API implementation with protocol, providers, and tests:

Core Infrastructure:
- Add batches API protocol using OpenAI Batch types directly
- Add Api.batches enum value and protocol mapping in resolver
- Add OpenAI "batch" file purpose support
- Include proper error handling (ConflictError, ResourceNotFoundError)

Reference Provider:
- Add ReferenceBatchesImpl with full CRUD operations (create, retrieve, cancel, list)
- Implement background batch processing with configurable concurrency
- Add SQLite KVStore backend for persistence
- Support /v1/chat/completions endpoint with request validation

Comprehensive Test Suite:
- Add unit tests for provider implementation with validation
- Add integration tests for end-to-end batch processing workflows
- Add error handling tests for validation, malformed inputs, and edge cases

Configuration:
- Add max_concurrent_batches and max_concurrent_requests_per_batch options
- Add provider documentation with sample configurations

Test with -

```
$ uv run llama stack build --image-type venv --providers inference=YOU_PICK,files=inline::localfs,batches=inline::reference --run &
$ LLAMA_STACK_CONFIG=http://localhost:8321 uv run pytest tests/unit/providers/batches tests/integration/batches --text-model YOU_PICK
```
2025-08-08 08:08:08 -04:00

753 lines
32 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Test suite for the reference implementation of the Batches API.
The tests are categorized and outlined below; keep this list up to date:
- Batch creation with various parameters and validation:
* test_create_and_retrieve_batch_success (positive)
* test_create_batch_without_metadata (positive)
* test_create_batch_completion_window (negative)
* test_create_batch_invalid_endpoints (negative)
* test_create_batch_invalid_metadata (negative)
- Batch retrieval and error handling for non-existent batches:
* test_retrieve_batch_not_found (negative)
- Batch cancellation with proper status transitions:
* test_cancel_batch_success (positive)
* test_cancel_batch_invalid_statuses (negative)
* test_cancel_batch_not_found (negative)
- Batch listing with pagination and filtering:
* test_list_batches_empty (positive)
* test_list_batches_single_batch (positive)
* test_list_batches_multiple_batches (positive)
* test_list_batches_with_limit (positive)
* test_list_batches_with_pagination (positive)
* test_list_batches_invalid_after (negative)
- Data persistence in the underlying key-value store:
* test_kvstore_persistence (positive)
- Batch processing concurrency control:
* test_max_concurrent_batches (positive)
- Input validation testing (direct _validate_input method tests):
* test_validate_input_file_not_found (negative)
* test_validate_input_file_exists_empty_content (positive)
* test_validate_input_file_mixed_valid_invalid_json (mixed)
* test_validate_input_invalid_model (negative)
* test_validate_input_url_mismatch (negative)
* test_validate_input_multiple_errors_per_request (negative)
* test_validate_input_invalid_request_format (negative)
* test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation)
* test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation)
The tests use temporary SQLite databases for isolation and mock external
dependencies like inference, files, and models APIs.
"""
import json
import tempfile
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import pytest
from llama_stack.apis.batches import BatchObject
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class TestReferenceBatchesImpl:
"""Test the reference implementation of the Batches API."""
@pytest.fixture
async def provider(self):
    """
    Yield an initialized ReferenceBatchesImpl backed by a temporary SQLite kvstore.

    The database file is created inside a temporary directory. The inference,
    files and models APIs handed to the provider are AsyncMock instances.
    Background batch processing is switched off after initialization; a test
    that needs it has to set ``process_batches`` back to True itself.

    Yields:
        The initialized provider. ``shutdown()`` is awaited once the test is done.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        db_path = Path(tmpdir) / "test_batches.db"
        kvstore_config = SqliteKVStoreConfig(db_path=str(db_path))
        config = ReferenceBatchesImplConfig(kvstore=kvstore_config)

        # Create kvstore and mock APIs (AsyncMock comes from the module-level import)
        from llama_stack.providers.utils.kvstore import kvstore_impl

        kvstore = await kvstore_impl(config.kvstore)
        mock_inference = AsyncMock()
        mock_files = AsyncMock()
        mock_models = AsyncMock()

        provider = ReferenceBatchesImpl(config, mock_inference, mock_files, mock_models, kvstore)
        await provider.initialize()

        # unit tests should not require background processing
        provider.process_batches = False

        yield provider

        await provider.shutdown()
@pytest.fixture
def sample_batch_data(self):
    """Provide keyword arguments for a create_batch call, metadata included."""
    metadata = {"test": "true", "priority": "high"}
    return dict(
        input_file_id="file_abc123",
        endpoint="/v1/chat/completions",
        completion_window="24h",
        metadata=metadata,
    )
def _validate_batch_type(self, batch, expected_metadata=None):
"""
Helper function to validate batch object structure and field types.
Note: This validates the direct BatchObject from the provider, not the
client library response which has a different structure.
Args:
batch: The BatchObject instance to validate.
expected_metadata: Optional expected metadata dictionary to validate against.
"""
assert isinstance(batch.id, str)
assert isinstance(batch.completion_window, str)
assert isinstance(batch.created_at, int)
assert isinstance(batch.endpoint, str)
assert isinstance(batch.input_file_id, str)
assert batch.object == "batch"
assert batch.status in [
"validating",
"failed",
"in_progress",
"finalizing",
"completed",
"expired",
"cancelling",
"cancelled",
]
if expected_metadata is not None:
assert batch.metadata == expected_metadata
timestamp_fields = [
"cancelled_at",
"cancelling_at",
"completed_at",
"expired_at",
"expires_at",
"failed_at",
"finalizing_at",
"in_progress_at",
]
for field in timestamp_fields:
field_value = getattr(batch, field, None)
if field_value is not None:
assert isinstance(field_value, int), f"{field} should be int or None, got {type(field_value)}"
file_id_fields = ["error_file_id", "output_file_id"]
for field in file_id_fields:
field_value = getattr(batch, field, None)
if field_value is not None:
assert isinstance(field_value, str), f"{field} should be str or None, got {type(field_value)}"
if hasattr(batch, "request_counts") and batch.request_counts is not None:
assert isinstance(batch.request_counts.completed, int), (
f"request_counts.completed should be int, got {type(batch.request_counts.completed)}"
)
assert isinstance(batch.request_counts.failed, int), (
f"request_counts.failed should be int, got {type(batch.request_counts.failed)}"
)
assert isinstance(batch.request_counts.total, int), (
f"request_counts.total should be int, got {type(batch.request_counts.total)}"
)
if hasattr(batch, "errors") and batch.errors is not None:
assert isinstance(batch.errors, dict), f"errors should be object or dict, got {type(batch.errors)}"
if hasattr(batch.errors, "data") and batch.errors.data is not None:
assert isinstance(batch.errors.data, list), (
f"errors.data should be list or None, got {type(batch.errors.data)}"
)
for i, error_item in enumerate(batch.errors.data):
assert isinstance(error_item, dict), (
f"errors.data[{i}] should be object or dict, got {type(error_item)}"
)
if hasattr(error_item, "code") and error_item.code is not None:
assert isinstance(error_item.code, str), (
f"errors.data[{i}].code should be str or None, got {type(error_item.code)}"
)
if hasattr(error_item, "line") and error_item.line is not None:
assert isinstance(error_item.line, int), (
f"errors.data[{i}].line should be int or None, got {type(error_item.line)}"
)
if hasattr(error_item, "message") and error_item.message is not None:
assert isinstance(error_item.message, str), (
f"errors.data[{i}].message should be str or None, got {type(error_item.message)}"
)
if hasattr(error_item, "param") and error_item.param is not None:
assert isinstance(error_item.param, str), (
f"errors.data[{i}].param should be str or None, got {type(error_item.param)}"
)
if hasattr(batch.errors, "object") and batch.errors.object is not None:
assert isinstance(batch.errors.object, str), (
f"errors.object should be str or None, got {type(batch.errors.object)}"
)
assert batch.errors.object == "list", f"errors.object should be 'list', got {batch.errors.object}"
async def test_create_and_retrieve_batch_success(self, provider, sample_batch_data):
    """A created batch echoes its inputs, starts in 'validating', and can be fetched back."""
    expected_metadata = sample_batch_data["metadata"]

    created = await provider.create_batch(**sample_batch_data)
    self._validate_batch_type(created, expected_metadata=expected_metadata)

    # Generated id: "batch_" prefix followed by a suffix.
    assert created.id.startswith("batch_")
    assert len(created.id) > 13
    assert created.object == "batch"

    # Request parameters are echoed back on the created batch.
    for key in ("endpoint", "input_file_id", "completion_window"):
        assert getattr(created, key) == sample_batch_data[key]
    assert created.status == "validating"
    assert created.metadata == expected_metadata
    assert isinstance(created.created_at, int)
    assert created.created_at > 0

    fetched = await provider.retrieve_batch(created.id)
    self._validate_batch_type(fetched, expected_metadata=expected_metadata)

    # Retrieval returns the same values as creation.
    for key in ("id", "input_file_id", "endpoint", "status", "metadata"):
        assert getattr(fetched, key) == getattr(created, key)
async def test_create_batch_without_metadata(self, provider):
    """Omitting the optional metadata argument leaves batch.metadata as None."""
    request = {
        "input_file_id": "file_123",
        "endpoint": "/v1/chat/completions",
        "completion_window": "24h",
    }
    created = await provider.create_batch(**request)
    assert created.metadata is None
async def test_create_batch_completion_window(self, provider):
    """A completion window of "now" is rejected with ValueError."""
    request = {
        "input_file_id": "file_123",
        "endpoint": "/v1/chat/completions",
        "completion_window": "now",
    }
    with pytest.raises(ValueError, match="Invalid completion_window"):
        await provider.create_batch(**request)
@pytest.mark.parametrize(
    "endpoint",
    ["/v1/embeddings", "/v1/completions", "/v1/invalid/endpoint", ""],
)
async def test_create_batch_invalid_endpoints(self, provider, endpoint):
    """Each parametrized endpoint is rejected with ValueError."""
    request = {
        "input_file_id": "file_123",
        "endpoint": endpoint,
        "completion_window": "24h",
    }
    with pytest.raises(ValueError, match="Invalid endpoint"):
        await provider.create_batch(**request)
async def test_create_batch_invalid_metadata(self, provider):
    """Metadata with a non-string key or a non-string value is rejected with ValueError."""
    invalid_metadata_cases = (
        {123: "invalid_key"},  # Non-string key
        {"valid_key": 456},  # Non-string value
    )
    for bad_metadata in invalid_metadata_cases:
        with pytest.raises(ValueError, match="should be a valid string"):
            await provider.create_batch(
                input_file_id="file_123",
                endpoint="/v1/chat/completions",
                completion_window="24h",
                metadata=bad_metadata,
            )
async def test_retrieve_batch_not_found(self, provider):
    """Retrieving an unknown batch id raises ResourceNotFoundError."""
    missing_id = "nonexistent_batch"
    expected_error = pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found")
    with expected_error:
        await provider.retrieve_batch(missing_id)
async def test_cancel_batch_success(self, provider, sample_batch_data):
    """Cancelling a freshly created batch moves it to 'cancelling' or 'cancelled'."""
    original = await provider.create_batch(**sample_batch_data)
    assert original.status == "validating"

    result = await provider.cancel_batch(original.id)

    assert result.id == original.id
    assert result.status in ["cancelling", "cancelled"]
    # The cancelling timestamp is an int and not earlier than creation.
    cancelling_at = result.cancelling_at
    assert isinstance(cancelling_at, int)
    assert cancelling_at >= original.created_at
@pytest.mark.parametrize("status", ["failed", "expired", "completed"])
async def test_cancel_batch_invalid_statuses(self, provider, sample_batch_data, status):
    """
    Test error when cancelling batch in final states.

    Args:
        provider: Provider fixture under test.
        sample_batch_data: Keyword arguments passed to create_batch.
        status: Status written onto the batch before the cancel attempt.
    """
    import re

    provider.process_batches = False
    created_batch = await provider.create_batch(**sample_batch_data)

    # directly update status in kvstore
    await provider._update_batch(created_batch.id, status=status)

    # pytest.raises(match=...) treats its argument as a regular expression, so the
    # literal message (which embeds runtime values) is escaped before matching.
    expected_message = re.escape(f"Cannot cancel batch '{created_batch.id}' with status '{status}'")
    with pytest.raises(ConflictError, match=expected_message):
        await provider.cancel_batch(created_batch.id)
async def test_cancel_batch_not_found(self, provider):
    """Cancelling an unknown batch id raises ResourceNotFoundError."""
    missing_id = "nonexistent_batch"
    expected_error = pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found")
    with expected_error:
        await provider.cancel_batch(missing_id)
async def test_list_batches_empty(self, provider):
    """With no batches created, the listing is empty and has no cursors."""
    listing = await provider.list_batches()

    assert listing.object == "list"
    assert listing.data == []
    for cursor in (listing.first_id, listing.last_id):
        assert cursor is None
    assert listing.has_more is False
async def test_list_batches_single_batch(self, provider, sample_batch_data):
    """With one batch created, the listing contains exactly that batch."""
    created = await provider.create_batch(**sample_batch_data)

    listing = await provider.list_batches()

    assert len(listing.data) == 1
    only_entry = listing.data[0]
    self._validate_batch_type(only_entry, expected_metadata=sample_batch_data["metadata"])
    assert only_entry.id == created.id
    # Both cursors point at the single entry.
    assert listing.first_id == created.id
    assert listing.last_id == created.id
    assert listing.has_more is False
async def test_list_batches_multiple_batches(self, provider):
    """With three batches created, an unbounded listing returns all of them."""
    created = []
    for i in range(3):
        new_batch = await provider.create_batch(
            input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
        )
        created.append(new_batch)

    listing = await provider.list_batches()

    assert len(listing.data) == 3
    listed_ids = {entry.id for entry in listing.data}
    created_ids = {entry.id for entry in created}
    assert listed_ids == created_ids
    assert listing.has_more is False
    assert listing.first_id in created_ids
    assert listing.last_id in created_ids
async def test_list_batches_with_limit(self, provider):
    """A limit smaller than the number of batches truncates the page and sets has_more."""
    created = []
    for i in range(3):
        new_batch = await provider.create_batch(
            input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
        )
        created.append(new_batch)

    page = await provider.list_batches(limit=2)

    assert len(page.data) == 2
    assert page.has_more is True
    # Cursors correspond to the first and last entry of the page.
    first_entry, second_entry = page.data[0], page.data[1]
    assert page.first_id == first_entry.id
    assert page.last_id == second_entry.id

    page_ids = {entry.id for entry in page.data}
    created_ids = {entry.id for entry in created}
    assert page_ids.issubset(created_ids)
async def test_list_batches_with_pagination(self, provider):
    """Passing the first page's batch id as 'after' yields the next batch in listing order."""
    for i in range(3):
        await provider.create_batch(
            input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
        )

    # Page one: a single entry, more to come.
    page_one = await provider.list_batches(limit=1)
    assert len(page_one.data) == 1
    assert page_one.has_more is True
    cursor = page_one.data[0].id

    # Page two: continue after the entry returned by page one.
    page_two = await provider.list_batches(limit=1, after=cursor)
    assert len(page_two.data) == 1
    next_id = page_two.data[0].id
    assert next_id != cursor

    # The entry on page two is the second entry of the full listing.
    full_listing = await provider.list_batches()
    assert next_id == full_listing.data[1].id
async def test_list_batches_invalid_after(self, provider, sample_batch_data):
    """An 'after' id that matches no batch does not filter the listing."""
    await provider.create_batch(**sample_batch_data)

    listing = await provider.list_batches(after="nonexistent_batch")

    # The one existing batch is still returned.
    returned = listing.data
    assert len(returned) == 1
async def test_kvstore_persistence(self, provider, sample_batch_data):
    """A created batch is stored as JSON in the kvstore under 'batch:<id>'."""
    created = await provider.create_batch(**sample_batch_data)

    kv_key = f"batch:{created.id}"
    raw_record = await provider.kvstore.get(kv_key)
    assert raw_record is not None

    record = json.loads(raw_record)
    assert record["id"] == created.id
    assert record["input_file_id"] == sample_batch_data["input_file_id"]
async def test_validate_input_file_not_found(self, provider):
    """_validate_input reports a single input_file_id error when file lookup raises."""
    lookup_failure = Exception("File not found")
    provider.files_api.openai_retrieve_file = AsyncMock(side_effect=lookup_failure)

    batch_fields = {
        "id": "batch_test",
        "object": "batch",
        "endpoint": "/v1/chat/completions",
        "input_file_id": "nonexistent_file",
        "completion_window": "24h",
        "status": "validating",
        "created_at": 1234567890,
    }
    batch = BatchObject(**batch_fields)

    errors, requests = await provider._validate_input(batch)

    assert len(errors) == 1
    assert len(requests) == 0
    only_error = errors[0]
    assert only_error.code == "invalid_request"
    assert only_error.message == "Cannot find file nonexistent_file."
    assert only_error.param == "input_file_id"
    assert only_error.line is None
async def test_validate_input_file_exists_empty_content(self, provider):
    """_validate_input yields no errors and no requests for an empty file body."""
    provider.files_api.openai_retrieve_file = AsyncMock()
    empty_content = MagicMock()
    empty_content.body = b""
    provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=empty_content)

    batch_fields = {
        "id": "batch_test",
        "object": "batch",
        "endpoint": "/v1/chat/completions",
        "input_file_id": "empty_file",
        "completion_window": "24h",
        "status": "validating",
        "created_at": 1234567890,
    }
    batch = BatchObject(**batch_fields)

    errors, requests = await provider._validate_input(batch)

    assert len(errors) == 0
    assert len(requests) == 0
async def test_validate_input_file_mixed_valid_invalid_json(self, provider):
    """_validate_input keeps the parseable line as a request and flags the unparseable one."""
    provider.files_api.openai_retrieve_file = AsyncMock()
    # Line 1: valid JSON with proper body args, Line 2: invalid JSON
    file_body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}\n{invalid json'
    file_content = MagicMock()
    file_content.body = file_body
    provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=file_content)

    batch_fields = {
        "id": "batch_test",
        "object": "batch",
        "endpoint": "/v1/chat/completions",
        "input_file_id": "mixed_file",
        "completion_window": "24h",
        "status": "validating",
        "created_at": 1234567890,
    }
    batch = BatchObject(**batch_fields)

    errors, requests = await provider._validate_input(batch)

    # Should have 1 JSON parsing error from line 2, and 1 valid request from line 1
    assert len(errors) == 1
    assert len(requests) == 1

    parse_error = errors[0]
    assert parse_error.code == "invalid_json_line"
    assert parse_error.line == 2
    assert parse_error.message == "This line is not parseable as valid JSON."

    parsed_request = requests[0]
    assert parsed_request.custom_id == "req-1"
    assert parsed_request.method == "POST"
    assert parsed_request.url == "/v1/chat/completions"
    assert parsed_request.body["model"] == "test-model"
    assert parsed_request.body["messages"] == [{"role": "user", "content": "Hello"}]
async def test_validate_input_invalid_model(self, provider):
    """_validate_input reports model_not_found when the model lookup raises."""
    provider.files_api.openai_retrieve_file = AsyncMock()
    file_content = MagicMock()
    file_content.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "nonexistent-model", "messages": [{"role": "user", "content": "Hello"}]}}'
    provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=file_content)
    # Any model lookup fails.
    provider.models_api.get_model = AsyncMock(side_effect=Exception("Model not found"))

    batch_fields = {
        "id": "batch_test",
        "object": "batch",
        "endpoint": "/v1/chat/completions",
        "input_file_id": "invalid_model_file",
        "completion_window": "24h",
        "status": "validating",
        "created_at": 1234567890,
    }
    batch = BatchObject(**batch_fields)

    errors, requests = await provider._validate_input(batch)

    assert len(errors) == 1
    assert len(requests) == 0
    model_error = errors[0]
    assert model_error.code == "model_not_found"
    assert model_error.line == 1
    assert model_error.message == "Model 'nonexistent-model' does not exist or is not supported"
    assert model_error.param == "body.model"
@pytest.mark.parametrize(
    "param_name,param_path,error_code,error_message",
    [
        ("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"),
        ("method", "method", "missing_required_parameter", "Missing required parameter: method"),
        ("url", "url", "missing_required_parameter", "Missing required parameter: url"),
        ("body", "body", "missing_required_parameter", "Missing required parameter: body"),
        ("model", "body.model", "invalid_request", "Model parameter is required"),
        ("messages", "body.messages", "invalid_request", "Messages parameter is required"),
    ],
)
async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message):
    """_validate_input reports exactly one error when a single required parameter is absent."""
    provider.files_api.openai_retrieve_file = AsyncMock()

    request_line = {
        "custom_id": "req-1",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]},
    }

    # Drop the parameter under test; a dotted path addresses a key nested one level down.
    parent_key, dot, child_key = param_path.partition(".")
    if dot:
        del request_line[parent_key][child_key]
    else:
        del request_line[param_name]

    file_content = MagicMock()
    file_content.body = json.dumps(request_line).encode()
    provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=file_content)

    batch_fields = {
        "id": "batch_test",
        "object": "batch",
        "endpoint": "/v1/chat/completions",
        "input_file_id": f"missing_{param_name}_file",
        "completion_window": "24h",
        "status": "validating",
        "created_at": 1234567890,
    }
    batch = BatchObject(**batch_fields)

    errors, requests = await provider._validate_input(batch)

    assert len(errors) == 1
    assert len(requests) == 0
    reported = errors[0]
    assert reported.code == error_code
    assert reported.line == 1
    assert reported.message == error_message
    assert reported.param == param_path
async def test_validate_input_url_mismatch(self, provider):
    """_validate_input reports invalid_url when the request URL differs from the batch endpoint."""
    provider.files_api.openai_retrieve_file = AsyncMock()
    file_content = MagicMock()
    file_content.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}'
    provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=file_content)

    batch_fields = {
        "id": "batch_test",
        "object": "batch",
        "endpoint": "/v1/chat/completions",  # This doesn't match the URL in the request
        "input_file_id": "url_mismatch_file",
        "completion_window": "24h",
        "status": "validating",
        "created_at": 1234567890,
    }
    batch = BatchObject(**batch_fields)

    errors, requests = await provider._validate_input(batch)

    assert len(errors) == 1
    assert len(requests) == 0
    url_error = errors[0]
    assert url_error.code == "invalid_url"
    assert url_error.line == 1
    assert url_error.message == "URL provided for this request does not match the batch endpoint"
    assert url_error.param == "url"
async def test_validate_input_multiple_errors_per_request(self, provider):
    """_validate_input reports several errors, all on line 1, for one request with several faults."""
    provider.files_api.openai_retrieve_file = AsyncMock()
    # Request missing custom_id, has invalid URL, and missing model in body
    file_body = (
        b'{"method": "POST", "url": "/v1/embeddings", "body": {"messages": [{"role": "user", "content": "Hello"}]}}'
    )
    file_content = MagicMock()
    file_content.body = file_body
    provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=file_content)

    batch_fields = {
        "id": "batch_test",
        "object": "batch",
        "endpoint": "/v1/chat/completions",  # Doesn't match /v1/embeddings in request
        "input_file_id": "multiple_errors_file",
        "completion_window": "24h",
        "status": "validating",
        "created_at": 1234567890,
    }
    batch = BatchObject(**batch_fields)

    errors, requests = await provider._validate_input(batch)

    assert len(errors) >= 2  # At least missing custom_id and URL mismatch
    assert len(requests) == 0

    for reported in errors:
        assert reported.line == 1

    seen_codes = {reported.code for reported in errors}
    assert "missing_required_parameter" in seen_codes  # missing custom_id
    assert "invalid_url" in seen_codes  # URL mismatch
async def test_validate_input_invalid_request_format(self, provider):
    """_validate_input rejects a line whose JSON value is an array rather than an object."""
    provider.files_api.openai_retrieve_file = AsyncMock()
    file_content = MagicMock()
    file_content.body = b'["not", "a", "request", "object"]'
    provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=file_content)

    batch_fields = {
        "id": "batch_test",
        "object": "batch",
        "endpoint": "/v1/chat/completions",
        "input_file_id": "invalid_format_file",
        "completion_window": "24h",
        "status": "validating",
        "created_at": 1234567890,
    }
    batch = BatchObject(**batch_fields)

    errors, requests = await provider._validate_input(batch)

    assert len(errors) == 1
    assert len(requests) == 0
    format_error = errors[0]
    assert format_error.code == "invalid_request"
    assert format_error.line == 1
    assert format_error.message == "Each line must be a JSON dictionary object"
@pytest.mark.parametrize(
    "param_name,param_path,invalid_value,error_message",
    [
        ("custom_id", "custom_id", 12345, "Custom_id must be a string"),
        ("url", "url", 123, "URL must be a string"),
        ("method", "method", ["POST"], "Method must be a string"),
        ("body", "body", ["not", "valid"], "Body must be a JSON dictionary object"),
        ("model", "body.model", 123, "Model must be a string"),
        ("messages", "body.messages", "invalid messages format", "Messages must be an array"),
    ],
)
async def test_validate_input_invalid_parameter_types(
    self, provider, param_name, param_path, invalid_value, error_message
):
    """_validate_input reports exactly one invalid_request error for a wrongly typed parameter."""
    provider.files_api.openai_retrieve_file = AsyncMock()

    request_line = {
        "custom_id": "req-1",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]},
    }

    # Overwrite the parameter under test; a dotted path addresses a key nested one level down.
    parent_key, dot, child_key = param_path.partition(".")
    if dot:
        request_line[parent_key][child_key] = invalid_value
    else:
        request_line[param_name] = invalid_value

    file_content = MagicMock()
    file_content.body = json.dumps(request_line).encode()
    provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=file_content)

    batch_fields = {
        "id": "batch_test",
        "object": "batch",
        "endpoint": "/v1/chat/completions",
        "input_file_id": f"invalid_{param_name}_type_file",
        "completion_window": "24h",
        "status": "validating",
        "created_at": 1234567890,
    }
    batch = BatchObject(**batch_fields)

    errors, requests = await provider._validate_input(batch)

    assert len(errors) == 1
    assert len(requests) == 0
    type_error = errors[0]
    assert type_error.code == "invalid_request"
    assert type_error.line == 1
    assert type_error.message == error_message
    assert type_error.param == param_path
async def test_max_concurrent_batches(self, provider):
    """With a semaphore of 2 and three batches created, only two reach the processing body."""
    import asyncio

    started = 0

    async def count_and_block(batch_id: str):
        # Record that processing began, then never finish so the slot stays held.
        nonlocal started
        started += 1
        await asyncio.sleep(float("inf"))

    provider._batch_semaphore = asyncio.Semaphore(2)
    provider.process_batches = True  # enable because we're testing background processing

    # the first thing done in _process_batch is to acquire the semaphore, then call _process_batch_impl,
    # so we can replace _process_batch_impl with our mock to control concurrency
    provider._process_batch_impl = count_and_block

    request = {
        "input_file_id": "file_id",
        "endpoint": "/v1/chat/completions",
        "completion_window": "24h",
    }
    for _ in range(3):
        await provider.create_batch(**request)

    await asyncio.sleep(0.042)  # let tasks start

    assert started == 2, f"Expected 2 active batches, got {started}"