mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-25 09:05:37 +00:00 
			
		
		
		
	# What does this PR do? This PR extends the Llama Stack Batches API to support the /v1/embeddings endpoint, enabling efficient batch processing of embedding requests alongside the existing /v1/chat/completions and /v1/completions support. <!-- If resolving an issue, uncomment and update the line below --> <!-- Closes #[issue-number] --> Closes: https://github.com/llamastack/llama-stack/issues/3145 ## Test Plan <!-- Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.* --> ``` (stack-client) ➜ llama-stack git:(support/embeddings-api) conda activate stack-client && python -m pytest tests/unit/providers/batches/test_reference.py -v ============================================================================================================================================ test session starts ============================================================================================================================================= platform darwin -- Python 3.12.11, pytest-7.4.4, pluggy-1.5.0 -- /Users/vnarsing/miniconda3/envs/stack-client/bin/python cachedir: .pytest_cache metadata: {'Python': '3.12.11', 'Platform': 'macOS-15.6.1-arm64-arm-64bit', 'Packages': {'pytest': '7.4.4', 'pluggy': '1.5.0'}, 'Plugins': {'asyncio': '0.23.8', 'cov': '6.0.0', 'timeout': '2.2.0', 'socket': '0.7.0', 'xdist': '3.8.0', 'html': '3.1.1', 'langsmith': '0.3.39', 'anyio': '4.8.0', 'metadata': '3.0.0'}} rootdir: /Users/vnarsing/go/src/github/meta-llama/llama-stack configfile: pyproject.toml plugins: asyncio-0.23.8, cov-6.0.0, timeout-2.2.0, socket-0.7.0, xdist-3.8.0, html-3.1.1, langsmith-0.3.39, anyio-4.8.0, metadata-3.0.0 asyncio: mode=Mode.AUTO collected 46 items tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_create_and_retrieve_batch_success PASSED [ 2%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_create_batch_without_metadata PASSED [ 4%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_create_batch_completion_window PASSED [ 6%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_create_batch_invalid_endpoints[/v1/invalid/endpoint] PASSED [ 8%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_create_batch_invalid_endpoints[] PASSED [ 10%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_create_batch_invalid_metadata PASSED [ 13%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_retrieve_batch_not_found PASSED [ 15%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_cancel_batch_success PASSED [ 17%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_cancel_batch_invalid_statuses[failed] PASSED [ 19%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_cancel_batch_invalid_statuses[expired] PASSED [ 21%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_cancel_batch_invalid_statuses[completed] PASSED [ 23%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_cancel_batch_not_found PASSED [ 26%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_list_batches_empty PASSED [ 28%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_list_batches_single_batch PASSED [ 30%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_list_batches_multiple_batches PASSED [ 32%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_list_batches_with_limit PASSED [ 34%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_list_batches_with_pagination PASSED [ 36%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_list_batches_invalid_after PASSED [ 39%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_kvstore_persistence PASSED [ 41%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_validate_input_file_not_found PASSED [ 43%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_validate_input_file_exists_empty_content PASSED [ 45%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_validate_input_file_mixed_valid_invalid_json PASSED [ 47%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_validate_input_invalid_model PASSED [ 50%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_validate_input_missing_parameters_chat_completions[custom_id-custom_id-missing_required_parameter-Missing required parameter: custom_id] PASSED [ 52%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_validate_input_missing_parameters_chat_completions[method-method-missing_required_parameter-Missing required parameter: method] PASSED [ 54%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_validate_input_missing_parameters_chat_completions[url-url-missing_required_parameter-Missing required parameter: url] PASSED [ 56%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_validate_input_missing_parameters_chat_completions[body-body-missing_required_parameter-Missing required parameter: body] PASSED [ 58%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_validate_input_missing_parameters_chat_completions[model-body.model-invalid_request-Model parameter is required] PASSED [ 60%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_validate_input_missing_parameters_chat_completions[messages-body.messages-invalid_request-Messages parameter is required] PASSED [ 63%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_validate_input_missing_parameters_completions[custom_id-custom_id-missing_required_parameter-Missing required parameter: custom_id] PASSED [ 65%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_validate_input_missing_parameters_completions[method-method-missing_required_parameter-Missing required parameter: method] PASSED [ 67%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_validate_input_missing_parameters_completions[url-url-missing_required_parameter-Missing required parameter: url] PASSED [ 69%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_validate_input_missing_parameters_completions[body-body-missing_required_parameter-Missing required parameter: body] PASSED [ 71%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_validate_input_missing_parameters_completions[model-body.model-invalid_request-Model parameter is required] PASSED [ 73%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_validate_input_missing_parameters_completions[prompt-body.prompt-invalid_request-Prompt parameter is required] PASSED [ 76%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_validate_input_url_mismatch PASSED [ 78%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_validate_input_multiple_errors_per_request PASSED [ 80%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_validate_input_invalid_request_format PASSED [ 82%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_validate_input_invalid_parameter_types[custom_id-custom_id-12345-Custom_id must be a string] PASSED [ 84%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_validate_input_invalid_parameter_types[url-url-123-URL must be a string] PASSED [ 86%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_validate_input_invalid_parameter_types[method-method-invalid_value2-Method must be a string] PASSED [ 89%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_validate_input_invalid_parameter_types[body-body-invalid_value3-Body must be a JSON dictionary object] PASSED [ 91%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_validate_input_invalid_parameter_types[model-body.model-123-Model must be a string] PASSED [ 93%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_validate_input_invalid_parameter_types[messages-body.messages-invalid messages format-Messages must be an array] PASSED [ 95%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_max_concurrent_batches PASSED [ 97%] tests/unit/providers/batches/test_reference.py::TestReferenceBatchesImpl::test_create_batch_embeddings_endpoint PASSED [100%] ``` --------- Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com> Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
		
			
				
	
	
		
			775 lines
		
	
	
	
		
			33 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			775 lines
		
	
	
	
		
			33 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| """
 | |
| Test suite for the reference implementation of the Batches API.
 | |
| 
 | |
| The tests are categorized and outlined below, keep this updated:
 | |
| 
 | |
| - Batch creation with various parameters and validation:
 | |
|   * test_create_and_retrieve_batch_success (positive)
 | |
|   * test_create_batch_without_metadata (positive)
 | |
|   * test_create_batch_completion_window (negative)
 | |
|   * test_create_batch_invalid_endpoints (negative)
 | |
|   * test_create_batch_invalid_metadata (negative)
 | |
| 
 | |
| - Batch retrieval and error handling for non-existent batches:
 | |
|   * test_retrieve_batch_not_found (negative)
 | |
| 
 | |
| - Batch cancellation with proper status transitions:
 | |
|   * test_cancel_batch_success (positive)
 | |
|   * test_cancel_batch_invalid_statuses (negative)
 | |
|   * test_cancel_batch_not_found (negative)
 | |
| 
 | |
| - Batch listing with pagination and filtering:
 | |
|   * test_list_batches_empty (positive)
 | |
|   * test_list_batches_single_batch (positive)
 | |
|   * test_list_batches_multiple_batches (positive)
 | |
|   * test_list_batches_with_limit (positive)
 | |
|   * test_list_batches_with_pagination (positive)
 | |
|   * test_list_batches_invalid_after (negative)
 | |
| 
 | |
| - Data persistence in the underlying key-value store:
 | |
|   * test_kvstore_persistence (positive)
 | |
| 
 | |
| - Batch processing concurrency control:
 | |
|   * test_max_concurrent_batches (positive)
 | |
| 
 | |
| - Input validation testing (direct _validate_input method tests):
 | |
|   * test_validate_input_file_not_found (negative)
 | |
|   * test_validate_input_file_exists_empty_content (positive)
 | |
|   * test_validate_input_file_mixed_valid_invalid_json (mixed)
 | |
|   * test_validate_input_invalid_model (negative)
 | |
|   * test_validate_input_url_mismatch (negative)
 | |
|   * test_validate_input_multiple_errors_per_request (negative)
 | |
|   * test_validate_input_invalid_request_format (negative)
 | |
|   * test_validate_input_missing_parameters_chat_completions (parametrized negative - custom_id, method, url, body, model, messages missing validation for chat/completions)
 | |
|   * test_validate_input_missing_parameters_completions (parametrized negative - custom_id, method, url, body, model, prompt missing validation for completions)
 | |
|   * test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation)
 | |
| 
 | |
| The tests use temporary SQLite databases for isolation and mock external
 | |
| dependencies like inference, files, and models APIs.
 | |
| """
 | |
| 
 | |
| import json
 | |
| from unittest.mock import AsyncMock, MagicMock
 | |
| 
 | |
| import pytest
 | |
| 
 | |
| from llama_stack.apis.batches import BatchObject
 | |
| from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
 | |
| 
 | |
| 
 | |
| class TestReferenceBatchesImpl:
 | |
|     """Test the reference implementation of the Batches API."""
 | |
| 
 | |
|     def _validate_batch_type(self, batch, expected_metadata=None):
 | |
|         """
 | |
|         Helper function to validate batch object structure and field types.
 | |
| 
 | |
|         Note: This validates the direct BatchObject from the provider, not the
 | |
|               client library response which has a different structure.
 | |
| 
 | |
|         Args:
 | |
|             batch: The BatchObject instance to validate.
 | |
|             expected_metadata: Optional expected metadata dictionary to validate against.
 | |
|         """
 | |
|         assert isinstance(batch.id, str)
 | |
|         assert isinstance(batch.completion_window, str)
 | |
|         assert isinstance(batch.created_at, int)
 | |
|         assert isinstance(batch.endpoint, str)
 | |
|         assert isinstance(batch.input_file_id, str)
 | |
|         assert batch.object == "batch"
 | |
|         assert batch.status in [
 | |
|             "validating",
 | |
|             "failed",
 | |
|             "in_progress",
 | |
|             "finalizing",
 | |
|             "completed",
 | |
|             "expired",
 | |
|             "cancelling",
 | |
|             "cancelled",
 | |
|         ]
 | |
| 
 | |
|         if expected_metadata is not None:
 | |
|             assert batch.metadata == expected_metadata
 | |
| 
 | |
|         timestamp_fields = [
 | |
|             "cancelled_at",
 | |
|             "cancelling_at",
 | |
|             "completed_at",
 | |
|             "expired_at",
 | |
|             "expires_at",
 | |
|             "failed_at",
 | |
|             "finalizing_at",
 | |
|             "in_progress_at",
 | |
|         ]
 | |
|         for field in timestamp_fields:
 | |
|             field_value = getattr(batch, field, None)
 | |
|             if field_value is not None:
 | |
|                 assert isinstance(field_value, int), f"{field} should be int or None, got {type(field_value)}"
 | |
| 
 | |
|         file_id_fields = ["error_file_id", "output_file_id"]
 | |
|         for field in file_id_fields:
 | |
|             field_value = getattr(batch, field, None)
 | |
|             if field_value is not None:
 | |
|                 assert isinstance(field_value, str), f"{field} should be str or None, got {type(field_value)}"
 | |
| 
 | |
|         if hasattr(batch, "request_counts") and batch.request_counts is not None:
 | |
|             assert isinstance(batch.request_counts.completed, int), (
 | |
|                 f"request_counts.completed should be int, got {type(batch.request_counts.completed)}"
 | |
|             )
 | |
|             assert isinstance(batch.request_counts.failed, int), (
 | |
|                 f"request_counts.failed should be int, got {type(batch.request_counts.failed)}"
 | |
|             )
 | |
|             assert isinstance(batch.request_counts.total, int), (
 | |
|                 f"request_counts.total should be int, got {type(batch.request_counts.total)}"
 | |
|             )
 | |
| 
 | |
|         if hasattr(batch, "errors") and batch.errors is not None:
 | |
|             assert isinstance(batch.errors, dict), f"errors should be object or dict, got {type(batch.errors)}"
 | |
| 
 | |
|             if hasattr(batch.errors, "data") and batch.errors.data is not None:
 | |
|                 assert isinstance(batch.errors.data, list), (
 | |
|                     f"errors.data should be list or None, got {type(batch.errors.data)}"
 | |
|                 )
 | |
| 
 | |
|                 for i, error_item in enumerate(batch.errors.data):
 | |
|                     assert isinstance(error_item, dict), (
 | |
|                         f"errors.data[{i}] should be object or dict, got {type(error_item)}"
 | |
|                     )
 | |
| 
 | |
|                     if hasattr(error_item, "code") and error_item.code is not None:
 | |
|                         assert isinstance(error_item.code, str), (
 | |
|                             f"errors.data[{i}].code should be str or None, got {type(error_item.code)}"
 | |
|                         )
 | |
| 
 | |
|                     if hasattr(error_item, "line") and error_item.line is not None:
 | |
|                         assert isinstance(error_item.line, int), (
 | |
|                             f"errors.data[{i}].line should be int or None, got {type(error_item.line)}"
 | |
|                         )
 | |
| 
 | |
|                     if hasattr(error_item, "message") and error_item.message is not None:
 | |
|                         assert isinstance(error_item.message, str), (
 | |
|                             f"errors.data[{i}].message should be str or None, got {type(error_item.message)}"
 | |
|                         )
 | |
| 
 | |
|                     if hasattr(error_item, "param") and error_item.param is not None:
 | |
|                         assert isinstance(error_item.param, str), (
 | |
|                             f"errors.data[{i}].param should be str or None, got {type(error_item.param)}"
 | |
|                         )
 | |
| 
 | |
|             if hasattr(batch.errors, "object") and batch.errors.object is not None:
 | |
|                 assert isinstance(batch.errors.object, str), (
 | |
|                     f"errors.object should be str or None, got {type(batch.errors.object)}"
 | |
|                 )
 | |
|                 assert batch.errors.object == "list", f"errors.object should be 'list', got {batch.errors.object}"
 | |
| 
 | |
|     async def test_create_and_retrieve_batch_success(self, provider, sample_batch_data):
 | |
|         """Test successful batch creation and retrieval."""
 | |
|         created_batch = await provider.create_batch(**sample_batch_data)
 | |
| 
 | |
|         self._validate_batch_type(created_batch, expected_metadata=sample_batch_data["metadata"])
 | |
| 
 | |
|         assert created_batch.id.startswith("batch_")
 | |
|         assert len(created_batch.id) > 13
 | |
|         assert created_batch.object == "batch"
 | |
|         assert created_batch.endpoint == sample_batch_data["endpoint"]
 | |
|         assert created_batch.input_file_id == sample_batch_data["input_file_id"]
 | |
|         assert created_batch.completion_window == sample_batch_data["completion_window"]
 | |
|         assert created_batch.status == "validating"
 | |
|         assert created_batch.metadata == sample_batch_data["metadata"]
 | |
|         assert isinstance(created_batch.created_at, int)
 | |
|         assert created_batch.created_at > 0
 | |
| 
 | |
|         retrieved_batch = await provider.retrieve_batch(created_batch.id)
 | |
| 
 | |
|         self._validate_batch_type(retrieved_batch, expected_metadata=sample_batch_data["metadata"])
 | |
| 
 | |
|         assert retrieved_batch.id == created_batch.id
 | |
|         assert retrieved_batch.input_file_id == created_batch.input_file_id
 | |
|         assert retrieved_batch.endpoint == created_batch.endpoint
 | |
|         assert retrieved_batch.status == created_batch.status
 | |
|         assert retrieved_batch.metadata == created_batch.metadata
 | |
| 
 | |
|     async def test_create_batch_without_metadata(self, provider):
 | |
|         """Test batch creation without optional metadata."""
 | |
|         batch = await provider.create_batch(
 | |
|             input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h"
 | |
|         )
 | |
| 
 | |
|         assert batch.metadata is None
 | |
| 
 | |
|     async def test_create_batch_completion_window(self, provider):
 | |
|         """Test batch creation with invalid completion window."""
 | |
|         with pytest.raises(ValueError, match="Invalid completion_window"):
 | |
|             await provider.create_batch(
 | |
|                 input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now"
 | |
|             )
 | |
| 
 | |
|     @pytest.mark.parametrize(
 | |
|         "endpoint",
 | |
|         [
 | |
|             "/v1/invalid/endpoint",
 | |
|             "",
 | |
|         ],
 | |
|     )
 | |
|     async def test_create_batch_invalid_endpoints(self, provider, endpoint):
 | |
|         """Test batch creation with various invalid endpoints."""
 | |
|         with pytest.raises(ValueError, match="Invalid endpoint"):
 | |
|             await provider.create_batch(input_file_id="file_123", endpoint=endpoint, completion_window="24h")
 | |
| 
 | |
|     async def test_create_batch_invalid_metadata(self, provider):
 | |
|         """Test that batch creation fails with invalid metadata."""
 | |
|         with pytest.raises(ValueError, match="should be a valid string"):
 | |
|             await provider.create_batch(
 | |
|                 input_file_id="file_123",
 | |
|                 endpoint="/v1/chat/completions",
 | |
|                 completion_window="24h",
 | |
|                 metadata={123: "invalid_key"},  # Non-string key
 | |
|             )
 | |
| 
 | |
|         with pytest.raises(ValueError, match="should be a valid string"):
 | |
|             await provider.create_batch(
 | |
|                 input_file_id="file_123",
 | |
|                 endpoint="/v1/chat/completions",
 | |
|                 completion_window="24h",
 | |
|                 metadata={"valid_key": 456},  # Non-string value
 | |
|             )
 | |
| 
 | |
|     async def test_retrieve_batch_not_found(self, provider):
 | |
|         """Test error when retrieving non-existent batch."""
 | |
|         with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
 | |
|             await provider.retrieve_batch("nonexistent_batch")
 | |
| 
 | |
|     async def test_cancel_batch_success(self, provider, sample_batch_data):
 | |
|         """Test successful batch cancellation."""
 | |
|         created_batch = await provider.create_batch(**sample_batch_data)
 | |
|         assert created_batch.status == "validating"
 | |
| 
 | |
|         cancelled_batch = await provider.cancel_batch(created_batch.id)
 | |
| 
 | |
|         assert cancelled_batch.id == created_batch.id
 | |
|         assert cancelled_batch.status in ["cancelling", "cancelled"]
 | |
|         assert isinstance(cancelled_batch.cancelling_at, int)
 | |
|         assert cancelled_batch.cancelling_at >= created_batch.created_at
 | |
| 
 | |
|     @pytest.mark.parametrize("status", ["failed", "expired", "completed"])
 | |
|     async def test_cancel_batch_invalid_statuses(self, provider, sample_batch_data, status):
 | |
|         """Test error when cancelling batch in final states."""
 | |
|         provider.process_batches = False
 | |
|         created_batch = await provider.create_batch(**sample_batch_data)
 | |
| 
 | |
|         # directly update status in kvstore
 | |
|         await provider._update_batch(created_batch.id, status=status)
 | |
| 
 | |
|         with pytest.raises(ConflictError, match=f"Cannot cancel batch '{created_batch.id}' with status '{status}'"):
 | |
|             await provider.cancel_batch(created_batch.id)
 | |
| 
 | |
|     async def test_cancel_batch_not_found(self, provider):
 | |
|         """Test error when cancelling non-existent batch."""
 | |
|         with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
 | |
|             await provider.cancel_batch("nonexistent_batch")
 | |
| 
 | |
|     async def test_list_batches_empty(self, provider):
 | |
|         """Test listing batches when none exist."""
 | |
|         response = await provider.list_batches()
 | |
| 
 | |
|         assert response.object == "list"
 | |
|         assert response.data == []
 | |
|         assert response.first_id is None
 | |
|         assert response.last_id is None
 | |
|         assert response.has_more is False
 | |
| 
 | |
|     async def test_list_batches_single_batch(self, provider, sample_batch_data):
 | |
|         """Test listing batches with single batch."""
 | |
|         created_batch = await provider.create_batch(**sample_batch_data)
 | |
| 
 | |
|         response = await provider.list_batches()
 | |
| 
 | |
|         assert len(response.data) == 1
 | |
|         self._validate_batch_type(response.data[0], expected_metadata=sample_batch_data["metadata"])
 | |
|         assert response.data[0].id == created_batch.id
 | |
|         assert response.first_id == created_batch.id
 | |
|         assert response.last_id == created_batch.id
 | |
|         assert response.has_more is False
 | |
| 
 | |
|     async def test_list_batches_multiple_batches(self, provider):
 | |
|         """Test listing multiple batches."""
 | |
|         batches = [
 | |
|             await provider.create_batch(
 | |
|                 input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
 | |
|             )
 | |
|             for i in range(3)
 | |
|         ]
 | |
| 
 | |
|         response = await provider.list_batches()
 | |
| 
 | |
|         assert len(response.data) == 3
 | |
| 
 | |
|         batch_ids = {batch.id for batch in response.data}
 | |
|         expected_ids = {batch.id for batch in batches}
 | |
|         assert batch_ids == expected_ids
 | |
|         assert response.has_more is False
 | |
| 
 | |
|         assert response.first_id in expected_ids
 | |
|         assert response.last_id in expected_ids
 | |
| 
 | |
|     async def test_list_batches_with_limit(self, provider):
 | |
|         """Test listing batches with limit parameter."""
 | |
|         batches = [
 | |
|             await provider.create_batch(
 | |
|                 input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
 | |
|             )
 | |
|             for i in range(3)
 | |
|         ]
 | |
| 
 | |
|         response = await provider.list_batches(limit=2)
 | |
| 
 | |
|         assert len(response.data) == 2
 | |
|         assert response.has_more is True
 | |
|         assert response.first_id == response.data[0].id
 | |
|         assert response.last_id == response.data[1].id
 | |
|         batch_ids = {batch.id for batch in response.data}
 | |
|         expected_ids = {batch.id for batch in batches}
 | |
|         assert batch_ids.issubset(expected_ids)
 | |
| 
 | |
|     async def test_list_batches_with_pagination(self, provider):
 | |
|         """Test listing batches with pagination using 'after' parameter."""
 | |
|         for i in range(3):
 | |
|             await provider.create_batch(
 | |
|                 input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
 | |
|             )
 | |
| 
 | |
|         # Get first page
 | |
|         first_page = await provider.list_batches(limit=1)
 | |
|         assert len(first_page.data) == 1
 | |
|         assert first_page.has_more is True
 | |
| 
 | |
|         # Get second page using 'after'
 | |
|         second_page = await provider.list_batches(limit=1, after=first_page.data[0].id)
 | |
|         assert len(second_page.data) == 1
 | |
|         assert second_page.data[0].id != first_page.data[0].id
 | |
| 
 | |
|         # Verify we got the next batch in order
 | |
|         all_batches = await provider.list_batches()
 | |
|         expected_second_batch_id = all_batches.data[1].id
 | |
|         assert second_page.data[0].id == expected_second_batch_id
 | |
| 
 | |
|     async def test_list_batches_invalid_after(self, provider, sample_batch_data):
 | |
|         """Test listing batches with invalid 'after' parameter."""
 | |
|         await provider.create_batch(**sample_batch_data)
 | |
| 
 | |
|         response = await provider.list_batches(after="nonexistent_batch")
 | |
| 
 | |
|         # Should return all batches (no filtering when 'after' batch not found)
 | |
|         assert len(response.data) == 1
 | |
| 
 | |
|     async def test_kvstore_persistence(self, provider, sample_batch_data):
 | |
|         """Test that batches are properly persisted in kvstore."""
 | |
|         batch = await provider.create_batch(**sample_batch_data)
 | |
| 
 | |
|         stored_data = await provider.kvstore.get(f"batch:{batch.id}")
 | |
|         assert stored_data is not None
 | |
| 
 | |
|         stored_batch_dict = json.loads(stored_data)
 | |
|         assert stored_batch_dict["id"] == batch.id
 | |
|         assert stored_batch_dict["input_file_id"] == sample_batch_data["input_file_id"]
 | |
| 
 | |
|     async def test_validate_input_file_not_found(self, provider):
 | |
|         """Test _validate_input when input file does not exist."""
 | |
|         provider.files_api.openai_retrieve_file = AsyncMock(side_effect=Exception("File not found"))
 | |
| 
 | |
|         batch = BatchObject(
 | |
|             id="batch_test",
 | |
|             object="batch",
 | |
|             endpoint="/v1/chat/completions",
 | |
|             input_file_id="nonexistent_file",
 | |
|             completion_window="24h",
 | |
|             status="validating",
 | |
|             created_at=1234567890,
 | |
|         )
 | |
| 
 | |
|         errors, requests = await provider._validate_input(batch)
 | |
| 
 | |
|         assert len(errors) == 1
 | |
|         assert len(requests) == 0
 | |
|         assert errors[0].code == "invalid_request"
 | |
|         assert errors[0].message == "Cannot find file nonexistent_file."
 | |
|         assert errors[0].param == "input_file_id"
 | |
|         assert errors[0].line is None
 | |
| 
 | |
|     async def test_validate_input_file_exists_empty_content(self, provider):
 | |
|         """Test _validate_input when file exists but is empty."""
 | |
|         provider.files_api.openai_retrieve_file = AsyncMock()
 | |
|         mock_response = MagicMock()
 | |
|         mock_response.body = b""
 | |
|         provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
 | |
| 
 | |
|         batch = BatchObject(
 | |
|             id="batch_test",
 | |
|             object="batch",
 | |
|             endpoint="/v1/chat/completions",
 | |
|             input_file_id="empty_file",
 | |
|             completion_window="24h",
 | |
|             status="validating",
 | |
|             created_at=1234567890,
 | |
|         )
 | |
| 
 | |
|         errors, requests = await provider._validate_input(batch)
 | |
| 
 | |
|         assert len(errors) == 0
 | |
|         assert len(requests) == 0
 | |
| 
 | |
|     async def test_validate_input_file_mixed_valid_invalid_json(self, provider):
 | |
|         """Test _validate_input when file contains valid and invalid JSON lines."""
 | |
|         provider.files_api.openai_retrieve_file = AsyncMock()
 | |
|         mock_response = MagicMock()
 | |
|         # Line 1: valid JSON with proper body args, Line 2: invalid JSON
 | |
|         mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}\n{invalid json'
 | |
|         provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
 | |
| 
 | |
|         batch = BatchObject(
 | |
|             id="batch_test",
 | |
|             object="batch",
 | |
|             endpoint="/v1/chat/completions",
 | |
|             input_file_id="mixed_file",
 | |
|             completion_window="24h",
 | |
|             status="validating",
 | |
|             created_at=1234567890,
 | |
|         )
 | |
| 
 | |
|         errors, requests = await provider._validate_input(batch)
 | |
| 
 | |
|         # Should have 1 JSON parsing error from line 2, and 1 valid request from line 1
 | |
|         assert len(errors) == 1
 | |
|         assert len(requests) == 1
 | |
| 
 | |
|         assert errors[0].code == "invalid_json_line"
 | |
|         assert errors[0].line == 2
 | |
|         assert errors[0].message == "This line is not parseable as valid JSON."
 | |
| 
 | |
|         assert requests[0].custom_id == "req-1"
 | |
|         assert requests[0].method == "POST"
 | |
|         assert requests[0].url == "/v1/chat/completions"
 | |
|         assert requests[0].body["model"] == "test-model"
 | |
|         assert requests[0].body["messages"] == [{"role": "user", "content": "Hello"}]
 | |
| 
 | |
|     async def test_validate_input_invalid_model(self, provider):
 | |
|         """Test _validate_input when file contains request with non-existent model."""
 | |
|         provider.files_api.openai_retrieve_file = AsyncMock()
 | |
|         mock_response = MagicMock()
 | |
|         mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "nonexistent-model", "messages": [{"role": "user", "content": "Hello"}]}}'
 | |
|         provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
 | |
| 
 | |
|         provider.models_api.get_model = AsyncMock(side_effect=Exception("Model not found"))
 | |
| 
 | |
|         batch = BatchObject(
 | |
|             id="batch_test",
 | |
|             object="batch",
 | |
|             endpoint="/v1/chat/completions",
 | |
|             input_file_id="invalid_model_file",
 | |
|             completion_window="24h",
 | |
|             status="validating",
 | |
|             created_at=1234567890,
 | |
|         )
 | |
| 
 | |
|         errors, requests = await provider._validate_input(batch)
 | |
| 
 | |
|         assert len(errors) == 1
 | |
|         assert len(requests) == 0
 | |
| 
 | |
|         assert errors[0].code == "model_not_found"
 | |
|         assert errors[0].line == 1
 | |
|         assert errors[0].message == "Model 'nonexistent-model' does not exist or is not supported"
 | |
|         assert errors[0].param == "body.model"
 | |
| 
 | |
|     @pytest.mark.parametrize(
 | |
|         "param_name,param_path,error_code,error_message",
 | |
|         [
 | |
|             ("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"),
 | |
|             ("method", "method", "missing_required_parameter", "Missing required parameter: method"),
 | |
|             ("url", "url", "missing_required_parameter", "Missing required parameter: url"),
 | |
|             ("body", "body", "missing_required_parameter", "Missing required parameter: body"),
 | |
|             ("model", "body.model", "invalid_request", "Model parameter is required"),
 | |
|             ("messages", "body.messages", "invalid_request", "Messages parameter is required"),
 | |
|         ],
 | |
|     )
 | |
|     async def test_validate_input_missing_parameters_chat_completions(
 | |
|         self, provider, param_name, param_path, error_code, error_message
 | |
|     ):
 | |
|         """Test _validate_input when file contains request with missing required parameters for chat completions."""
 | |
|         provider.files_api.openai_retrieve_file = AsyncMock()
 | |
|         mock_response = MagicMock()
 | |
| 
 | |
|         base_request = {
 | |
|             "custom_id": "req-1",
 | |
|             "method": "POST",
 | |
|             "url": "/v1/chat/completions",
 | |
|             "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]},
 | |
|         }
 | |
| 
 | |
|         # Remove the specific parameter being tested
 | |
|         if "." in param_path:
 | |
|             top_level, nested_param = param_path.split(".", 1)
 | |
|             del base_request[top_level][nested_param]
 | |
|         else:
 | |
|             del base_request[param_name]
 | |
| 
 | |
|         mock_response.body = json.dumps(base_request).encode()
 | |
|         provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
 | |
| 
 | |
|         batch = BatchObject(
 | |
|             id="batch_test",
 | |
|             object="batch",
 | |
|             endpoint="/v1/chat/completions",
 | |
|             input_file_id=f"missing_{param_name}_file",
 | |
|             completion_window="24h",
 | |
|             status="validating",
 | |
|             created_at=1234567890,
 | |
|         )
 | |
| 
 | |
|         errors, requests = await provider._validate_input(batch)
 | |
| 
 | |
|         assert len(errors) == 1
 | |
|         assert len(requests) == 0
 | |
| 
 | |
|         assert errors[0].code == error_code
 | |
|         assert errors[0].line == 1
 | |
|         assert errors[0].message == error_message
 | |
|         assert errors[0].param == param_path
 | |
| 
 | |
|     @pytest.mark.parametrize(
 | |
|         "param_name,param_path,error_code,error_message",
 | |
|         [
 | |
|             ("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"),
 | |
|             ("method", "method", "missing_required_parameter", "Missing required parameter: method"),
 | |
|             ("url", "url", "missing_required_parameter", "Missing required parameter: url"),
 | |
|             ("body", "body", "missing_required_parameter", "Missing required parameter: body"),
 | |
|             ("model", "body.model", "invalid_request", "Model parameter is required"),
 | |
|             ("prompt", "body.prompt", "invalid_request", "Prompt parameter is required"),
 | |
|         ],
 | |
|     )
 | |
|     async def test_validate_input_missing_parameters_completions(
 | |
|         self, provider, param_name, param_path, error_code, error_message
 | |
|     ):
 | |
|         """Test _validate_input when file contains request with missing required parameters for text completions."""
 | |
|         provider.files_api.openai_retrieve_file = AsyncMock()
 | |
|         mock_response = MagicMock()
 | |
| 
 | |
|         base_request = {
 | |
|             "custom_id": "req-1",
 | |
|             "method": "POST",
 | |
|             "url": "/v1/completions",
 | |
|             "body": {"model": "test-model", "prompt": "Hello"},
 | |
|         }
 | |
| 
 | |
|         # Remove the specific parameter being tested
 | |
|         if "." in param_path:
 | |
|             top_level, nested_param = param_path.split(".", 1)
 | |
|             del base_request[top_level][nested_param]
 | |
|         else:
 | |
|             del base_request[param_name]
 | |
| 
 | |
|         mock_response.body = json.dumps(base_request).encode()
 | |
|         provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
 | |
| 
 | |
|         batch = BatchObject(
 | |
|             id="batch_test",
 | |
|             object="batch",
 | |
|             endpoint="/v1/completions",
 | |
|             input_file_id=f"missing_{param_name}_file",
 | |
|             completion_window="24h",
 | |
|             status="validating",
 | |
|             created_at=1234567890,
 | |
|         )
 | |
| 
 | |
|         errors, requests = await provider._validate_input(batch)
 | |
| 
 | |
|         assert len(errors) == 1
 | |
|         assert len(requests) == 0
 | |
| 
 | |
|         assert errors[0].code == error_code
 | |
|         assert errors[0].line == 1
 | |
|         assert errors[0].message == error_message
 | |
|         assert errors[0].param == param_path
 | |
| 
 | |
|     async def test_validate_input_url_mismatch(self, provider):
 | |
|         """Test _validate_input when file contains request with URL that doesn't match batch endpoint."""
 | |
|         provider.files_api.openai_retrieve_file = AsyncMock()
 | |
|         mock_response = MagicMock()
 | |
|         mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}'
 | |
|         provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
 | |
| 
 | |
|         batch = BatchObject(
 | |
|             id="batch_test",
 | |
|             object="batch",
 | |
|             endpoint="/v1/chat/completions",  # This doesn't match the URL in the request
 | |
|             input_file_id="url_mismatch_file",
 | |
|             completion_window="24h",
 | |
|             status="validating",
 | |
|             created_at=1234567890,
 | |
|         )
 | |
| 
 | |
|         errors, requests = await provider._validate_input(batch)
 | |
| 
 | |
|         assert len(errors) == 1
 | |
|         assert len(requests) == 0
 | |
| 
 | |
|         assert errors[0].code == "invalid_url"
 | |
|         assert errors[0].line == 1
 | |
|         assert errors[0].message == "URL provided for this request does not match the batch endpoint"
 | |
|         assert errors[0].param == "url"
 | |
| 
 | |
|     async def test_validate_input_multiple_errors_per_request(self, provider):
 | |
|         """Test _validate_input when a single request has multiple validation errors."""
 | |
|         provider.files_api.openai_retrieve_file = AsyncMock()
 | |
|         mock_response = MagicMock()
 | |
|         # Request missing custom_id, has invalid URL, and missing model in body
 | |
|         mock_response.body = (
 | |
|             b'{"method": "POST", "url": "/v1/embeddings", "body": {"messages": [{"role": "user", "content": "Hello"}]}}'
 | |
|         )
 | |
|         provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
 | |
| 
 | |
|         batch = BatchObject(
 | |
|             id="batch_test",
 | |
|             object="batch",
 | |
|             endpoint="/v1/chat/completions",  # Doesn't match /v1/embeddings in request
 | |
|             input_file_id="multiple_errors_file",
 | |
|             completion_window="24h",
 | |
|             status="validating",
 | |
|             created_at=1234567890,
 | |
|         )
 | |
| 
 | |
|         errors, requests = await provider._validate_input(batch)
 | |
| 
 | |
|         assert len(errors) >= 2  # At least missing custom_id and URL mismatch
 | |
|         assert len(requests) == 0
 | |
| 
 | |
|         for error in errors:
 | |
|             assert error.line == 1
 | |
| 
 | |
|         error_codes = {error.code for error in errors}
 | |
|         assert "missing_required_parameter" in error_codes  # missing custom_id
 | |
|         assert "invalid_url" in error_codes  # URL mismatch
 | |
| 
 | |
|     async def test_validate_input_invalid_request_format(self, provider):
 | |
|         """Test _validate_input when file contains non-object JSON (array, string, number)."""
 | |
|         provider.files_api.openai_retrieve_file = AsyncMock()
 | |
|         mock_response = MagicMock()
 | |
|         mock_response.body = b'["not", "a", "request", "object"]'
 | |
|         provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
 | |
| 
 | |
|         batch = BatchObject(
 | |
|             id="batch_test",
 | |
|             object="batch",
 | |
|             endpoint="/v1/chat/completions",
 | |
|             input_file_id="invalid_format_file",
 | |
|             completion_window="24h",
 | |
|             status="validating",
 | |
|             created_at=1234567890,
 | |
|         )
 | |
| 
 | |
|         errors, requests = await provider._validate_input(batch)
 | |
| 
 | |
|         assert len(errors) == 1
 | |
|         assert len(requests) == 0
 | |
| 
 | |
|         assert errors[0].code == "invalid_request"
 | |
|         assert errors[0].line == 1
 | |
|         assert errors[0].message == "Each line must be a JSON dictionary object"
 | |
| 
 | |
|     @pytest.mark.parametrize(
 | |
|         "param_name,param_path,invalid_value,error_message",
 | |
|         [
 | |
|             ("custom_id", "custom_id", 12345, "Custom_id must be a string"),
 | |
|             ("url", "url", 123, "URL must be a string"),
 | |
|             ("method", "method", ["POST"], "Method must be a string"),
 | |
|             ("body", "body", ["not", "valid"], "Body must be a JSON dictionary object"),
 | |
|             ("model", "body.model", 123, "Model must be a string"),
 | |
|             ("messages", "body.messages", "invalid messages format", "Messages must be an array"),
 | |
|         ],
 | |
|     )
 | |
|     async def test_validate_input_invalid_parameter_types(
 | |
|         self, provider, param_name, param_path, invalid_value, error_message
 | |
|     ):
 | |
|         """Test _validate_input when file contains request with parameters that have invalid types."""
 | |
|         provider.files_api.openai_retrieve_file = AsyncMock()
 | |
|         mock_response = MagicMock()
 | |
| 
 | |
|         base_request = {
 | |
|             "custom_id": "req-1",
 | |
|             "method": "POST",
 | |
|             "url": "/v1/chat/completions",
 | |
|             "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]},
 | |
|         }
 | |
| 
 | |
|         # Override the specific parameter with invalid value
 | |
|         if "." in param_path:
 | |
|             top_level, nested_param = param_path.split(".", 1)
 | |
|             base_request[top_level][nested_param] = invalid_value
 | |
|         else:
 | |
|             base_request[param_name] = invalid_value
 | |
| 
 | |
|         mock_response.body = json.dumps(base_request).encode()
 | |
|         provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
 | |
| 
 | |
|         batch = BatchObject(
 | |
|             id="batch_test",
 | |
|             object="batch",
 | |
|             endpoint="/v1/chat/completions",
 | |
|             input_file_id=f"invalid_{param_name}_type_file",
 | |
|             completion_window="24h",
 | |
|             status="validating",
 | |
|             created_at=1234567890,
 | |
|         )
 | |
| 
 | |
|         errors, requests = await provider._validate_input(batch)
 | |
| 
 | |
|         assert len(errors) == 1
 | |
|         assert len(requests) == 0
 | |
| 
 | |
|         assert errors[0].code == "invalid_request"
 | |
|         assert errors[0].line == 1
 | |
|         assert errors[0].message == error_message
 | |
|         assert errors[0].param == param_path
 | |
| 
 | |
|     async def test_max_concurrent_batches(self, provider):
 | |
|         """Test max_concurrent_batches configuration and concurrency control."""
 | |
|         import asyncio
 | |
| 
 | |
|         provider._batch_semaphore = asyncio.Semaphore(2)
 | |
| 
 | |
|         provider.process_batches = True  # enable because we're testing background processing
 | |
| 
 | |
|         active_batches = 0
 | |
| 
 | |
|         async def add_and_wait(batch_id: str):
 | |
|             nonlocal active_batches
 | |
|             active_batches += 1
 | |
|             await asyncio.sleep(float("inf"))
 | |
| 
 | |
|         # the first thing done in _process_batch is to acquire the semaphore, then call _process_batch_impl,
 | |
|         # so we can replace _process_batch_impl with our mock to control concurrency
 | |
|         provider._process_batch_impl = add_and_wait
 | |
| 
 | |
|         for _ in range(3):
 | |
|             await provider.create_batch(
 | |
|                 input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h"
 | |
|             )
 | |
| 
 | |
|         await asyncio.sleep(0.042)  # let tasks start
 | |
| 
 | |
|         assert active_batches == 2, f"Expected 2 active batches, got {active_batches}"
 | |
| 
 | |
|     async def test_create_batch_embeddings_endpoint(self, provider):
 | |
|         """Test that batch creation succeeds with embeddings endpoint."""
 | |
|         batch = await provider.create_batch(
 | |
|             input_file_id="file_123",
 | |
|             endpoint="/v1/embeddings",
 | |
|             completion_window="24h",
 | |
|         )
 | |
|         assert batch.endpoint == "/v1/embeddings"
 |