diff --git a/tests/unit/providers/batches/test_reference.py b/tests/unit/providers/batches/test_reference.py
index 32d59234d..bff286fc7 100644
--- a/tests/unit/providers/batches/test_reference.py
+++ b/tests/unit/providers/batches/test_reference.py
@@ -58,8 +58,15 @@ import json
 from unittest.mock import AsyncMock, MagicMock
 
 import pytest
+from pydantic import ValidationError
 
 from llama_stack_api import BatchObject, ConflictError, ResourceNotFoundError
+from llama_stack_api.batches.models import (
+    CancelBatchRequest,
+    CreateBatchRequest,
+    ListBatchesRequest,
+    RetrieveBatchRequest,
+)
 
 
 class TestReferenceBatchesImpl:
@@ -169,7 +176,7 @@
     async def test_create_and_retrieve_batch_success(self, provider, sample_batch_data):
         """Test successful batch creation and retrieval."""
 
-        created_batch = await provider.create_batch(**sample_batch_data)
+        created_batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
 
         self._validate_batch_type(created_batch, expected_metadata=sample_batch_data["metadata"])
 
@@ -184,7 +191,7 @@
         assert isinstance(created_batch.created_at, int)
         assert created_batch.created_at > 0
 
-        retrieved_batch = await provider.retrieve_batch(created_batch.id)
+        retrieved_batch = await provider.retrieve_batch(RetrieveBatchRequest(batch_id=created_batch.id))
 
         self._validate_batch_type(retrieved_batch, expected_metadata=sample_batch_data["metadata"])
 
@@ -197,17 +204,15 @@
     async def test_create_batch_without_metadata(self, provider):
         """Test batch creation without optional metadata."""
         batch = await provider.create_batch(
-            input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h"
+            CreateBatchRequest(input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h")
         )
         assert batch.metadata is None
 
     async def test_create_batch_completion_window(self, provider):
         """Test batch creation with invalid completion window."""
-        with pytest.raises(ValueError, match="Invalid completion_window"):
-            await provider.create_batch(
-                input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now"
-            )
+        with pytest.raises(ValidationError, match="completion_window"):
+            CreateBatchRequest(input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now")
 
     @pytest.mark.parametrize(
         "endpoint",
         [
@@ -219,37 +224,43 @@
     async def test_create_batch_invalid_endpoints(self, provider, endpoint):
         """Test batch creation with various invalid endpoints."""
         with pytest.raises(ValueError, match="Invalid endpoint"):
-            await provider.create_batch(input_file_id="file_123", endpoint=endpoint, completion_window="24h")
+            await provider.create_batch(
+                CreateBatchRequest(input_file_id="file_123", endpoint=endpoint, completion_window="24h")
+            )
 
     async def test_create_batch_invalid_metadata(self, provider):
         """Test that batch creation fails with invalid metadata."""
         with pytest.raises(ValueError, match="should be a valid string"):
             await provider.create_batch(
-                input_file_id="file_123",
-                endpoint="/v1/chat/completions",
-                completion_window="24h",
-                metadata={123: "invalid_key"},  # Non-string key
+                CreateBatchRequest(
+                    input_file_id="file_123",
+                    endpoint="/v1/chat/completions",
+                    completion_window="24h",
+                    metadata={123: "invalid_key"},  # Non-string key
+                )
             )
 
         with pytest.raises(ValueError, match="should be a valid string"):
             await provider.create_batch(
-                input_file_id="file_123",
-                endpoint="/v1/chat/completions",
-                completion_window="24h",
-                metadata={"valid_key": 456},  # Non-string value
+                CreateBatchRequest(
+                    input_file_id="file_123",
+                    endpoint="/v1/chat/completions",
+                    completion_window="24h",
+                    metadata={"valid_key": 456},  # Non-string value
+                )
             )
 
     async def test_retrieve_batch_not_found(self, provider):
         """Test error when retrieving non-existent batch."""
         with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
-            await provider.retrieve_batch("nonexistent_batch")
+            await provider.retrieve_batch(RetrieveBatchRequest(batch_id="nonexistent_batch"))
 
     async def test_cancel_batch_success(self, provider, sample_batch_data):
         """Test successful batch cancellation."""
-        created_batch = await provider.create_batch(**sample_batch_data)
+        created_batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
         assert created_batch.status == "validating"
 
-        cancelled_batch = await provider.cancel_batch(created_batch.id)
+        cancelled_batch = await provider.cancel_batch(CancelBatchRequest(batch_id=created_batch.id))
         assert cancelled_batch.id == created_batch.id
         assert cancelled_batch.status in ["cancelling", "cancelled"]
 
@@ -260,22 +271,22 @@
     async def test_cancel_batch_invalid_statuses(self, provider, sample_batch_data, status):
         """Test error when cancelling batch in final states."""
         provider.process_batches = False
-        created_batch = await provider.create_batch(**sample_batch_data)
+        created_batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
 
         # directly update status in kvstore
         await provider._update_batch(created_batch.id, status=status)
 
         with pytest.raises(ConflictError, match=f"Cannot cancel batch '{created_batch.id}' with status '{status}'"):
-            await provider.cancel_batch(created_batch.id)
+            await provider.cancel_batch(CancelBatchRequest(batch_id=created_batch.id))
 
     async def test_cancel_batch_not_found(self, provider):
         """Test error when cancelling non-existent batch."""
         with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
-            await provider.cancel_batch("nonexistent_batch")
+            await provider.cancel_batch(CancelBatchRequest(batch_id="nonexistent_batch"))
 
     async def test_list_batches_empty(self, provider):
         """Test listing batches when none exist."""
-        response = await provider.list_batches()
+        response = await provider.list_batches(ListBatchesRequest())
 
         assert response.object == "list"
         assert response.data == []
@@ -285,9 +296,9 @@
     async def test_list_batches_single_batch(self, provider, sample_batch_data):
         """Test listing batches with single batch."""
-        created_batch = await provider.create_batch(**sample_batch_data)
+        created_batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
 
-        response = await provider.list_batches()
+        response = await provider.list_batches(ListBatchesRequest())
 
         assert len(response.data) == 1
         self._validate_batch_type(response.data[0], expected_metadata=sample_batch_data["metadata"])
 
@@ -300,12 +311,12 @@
         """Test listing multiple batches."""
         batches = [
             await provider.create_batch(
-                input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
+                CreateBatchRequest(input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h")
             )
             for i in range(3)
         ]
 
-        response = await provider.list_batches()
+        response = await provider.list_batches(ListBatchesRequest())
 
         assert len(response.data) == 3
 
@@ -321,12 +332,12 @@ class TestReferenceBatchesImpl:
         """Test listing batches with limit parameter."""
         batches = [
             await provider.create_batch(
-                input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
+                CreateBatchRequest(input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h")
             )
             for i in range(3)
         ]
 
-        response = await provider.list_batches(limit=2)
+        response = await provider.list_batches(ListBatchesRequest(limit=2))
 
         assert len(response.data) == 2
         assert response.has_more is True
@@ -340,36 +351,36 @@
         """Test listing batches with pagination using 'after' parameter."""
         for i in range(3):
             await provider.create_batch(
-                input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
+                CreateBatchRequest(input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h")
             )
 
         # Get first page
-        first_page = await provider.list_batches(limit=1)
+        first_page = await provider.list_batches(ListBatchesRequest(limit=1))
         assert len(first_page.data) == 1
         assert first_page.has_more is True
 
         # Get second page using 'after'
-        second_page = await provider.list_batches(limit=1, after=first_page.data[0].id)
+        second_page = await provider.list_batches(ListBatchesRequest(limit=1, after=first_page.data[0].id))
         assert len(second_page.data) == 1
         assert second_page.data[0].id != first_page.data[0].id
 
         # Verify we got the next batch in order
-        all_batches = await provider.list_batches()
+        all_batches = await provider.list_batches(ListBatchesRequest())
         expected_second_batch_id = all_batches.data[1].id
         assert second_page.data[0].id == expected_second_batch_id
 
     async def test_list_batches_invalid_after(self, provider, sample_batch_data):
         """Test listing batches with invalid 'after' parameter."""
-        await provider.create_batch(**sample_batch_data)
+        await provider.create_batch(CreateBatchRequest(**sample_batch_data))
 
-        response = await provider.list_batches(after="nonexistent_batch")
+        response = await provider.list_batches(ListBatchesRequest(after="nonexistent_batch"))
 
         # Should return all batches (no filtering when 'after' batch not found)
         assert len(response.data) == 1
 
     async def test_kvstore_persistence(self, provider, sample_batch_data):
         """Test that batches are properly persisted in kvstore."""
-        batch = await provider.create_batch(**sample_batch_data)
+        batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
 
         stored_data = await provider.kvstore.get(f"batch:{batch.id}")
         assert stored_data is not None
@@ -757,7 +768,7 @@
         for _ in range(3):
             await provider.create_batch(
-                input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h"
+                CreateBatchRequest(input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h")
             )
 
         await asyncio.sleep(0.042)  # let tasks start
 
@@ -767,8 +778,10 @@
     async def test_create_batch_embeddings_endpoint(self, provider):
         """Test that batch creation succeeds with embeddings endpoint."""
         batch = await provider.create_batch(
-            input_file_id="file_123",
-            endpoint="/v1/embeddings",
-            completion_window="24h",
+            CreateBatchRequest(
+                input_file_id="file_123",
+                endpoint="/v1/embeddings",
+                completion_window="24h",
+            )
         )
         assert batch.endpoint == "/v1/embeddings"
diff --git a/tests/unit/providers/batches/test_reference_idempotency.py b/tests/unit/providers/batches/test_reference_idempotency.py
index acb7ca01c..0ac73841e 100644
--- a/tests/unit/providers/batches/test_reference_idempotency.py
+++ b/tests/unit/providers/batches/test_reference_idempotency.py
@@ -45,6 +45,7 @@ import asyncio
 import pytest
 
 from llama_stack_api import ConflictError
+from llama_stack_api.batches.models import CreateBatchRequest, RetrieveBatchRequest
 
 
 class TestReferenceBatchesIdempotency:
@@ -56,18 +57,22 @@
         del sample_batch_data["metadata"]
 
         batch1 = await provider.create_batch(
-            **sample_batch_data,
-            metadata={"test": "value1", "other": "value2"},
-            idempotency_key="unique-token-1",
+            CreateBatchRequest(
+                **sample_batch_data,
+                metadata={"test": "value1", "other": "value2"},
+                idempotency_key="unique-token-1",
+            )
         )
 
         # sleep for 1 second to allow created_at timestamps to be different
         await asyncio.sleep(1)
 
         batch2 = await provider.create_batch(
-            **sample_batch_data,
-            metadata={"other": "value2", "test": "value1"},  # Different order
-            idempotency_key="unique-token-1",
+            CreateBatchRequest(
+                **sample_batch_data,
+                metadata={"other": "value2", "test": "value1"},  # Different order
+                idempotency_key="unique-token-1",
+            )
         )
 
         assert batch1.id == batch2.id
@@ -77,23 +82,17 @@
 
     async def test_different_idempotency_keys_create_different_batches(self, provider, sample_batch_data):
         """Test that different idempotency keys create different batches even with same params."""
-        batch1 = await provider.create_batch(
-            **sample_batch_data,
-            idempotency_key="token-A",
-        )
+        batch1 = await provider.create_batch(CreateBatchRequest(**sample_batch_data, idempotency_key="token-A"))
 
-        batch2 = await provider.create_batch(
-            **sample_batch_data,
-            idempotency_key="token-B",
-        )
+        batch2 = await provider.create_batch(CreateBatchRequest(**sample_batch_data, idempotency_key="token-B"))
 
         assert batch1.id != batch2.id
 
     async def test_non_idempotent_behavior_without_key(self, provider, sample_batch_data):
         """Test that batches without idempotency key create unique batches even with identical parameters."""
-        batch1 = await provider.create_batch(**sample_batch_data)
+        batch1 = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
 
-        batch2 = await provider.create_batch(**sample_batch_data)
+        batch2 = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
 
         assert batch1.id != batch2.id
         assert batch1.input_file_id == batch2.input_file_id
@@ -117,12 +116,12 @@
 
         sample_batch_data[param_name] = first_value
 
-        batch1 = await provider.create_batch(**sample_batch_data)
+        batch1 = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
 
         with pytest.raises(ConflictError, match="Idempotency key.*was previously used with different parameters"):
             sample_batch_data[param_name] = second_value
-            await provider.create_batch(**sample_batch_data)
+            await provider.create_batch(CreateBatchRequest(**sample_batch_data))
 
-        retrieved_batch = await provider.retrieve_batch(batch1.id)
+        retrieved_batch = await provider.retrieve_batch(RetrieveBatchRequest(batch_id=batch1.id))
         assert retrieved_batch.id == batch1.id
         assert getattr(retrieved_batch, param_name) == first_value