chore: update unit test to use previously created Class

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-11-20 16:40:49 +01:00
parent 9595619b9f
commit f62c6044b3
No known key found for this signature in database
2 changed files with 71 additions and 59 deletions

View file

@ -58,8 +58,15 @@ import json
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
from pydantic import ValidationError
from llama_stack_api import BatchObject, ConflictError, ResourceNotFoundError from llama_stack_api import BatchObject, ConflictError, ResourceNotFoundError
from llama_stack_api.batches.models import (
CancelBatchRequest,
CreateBatchRequest,
ListBatchesRequest,
RetrieveBatchRequest,
)
class TestReferenceBatchesImpl: class TestReferenceBatchesImpl:
@ -169,7 +176,7 @@ class TestReferenceBatchesImpl:
async def test_create_and_retrieve_batch_success(self, provider, sample_batch_data): async def test_create_and_retrieve_batch_success(self, provider, sample_batch_data):
"""Test successful batch creation and retrieval.""" """Test successful batch creation and retrieval."""
created_batch = await provider.create_batch(**sample_batch_data) created_batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
self._validate_batch_type(created_batch, expected_metadata=sample_batch_data["metadata"]) self._validate_batch_type(created_batch, expected_metadata=sample_batch_data["metadata"])
@ -184,7 +191,7 @@ class TestReferenceBatchesImpl:
assert isinstance(created_batch.created_at, int) assert isinstance(created_batch.created_at, int)
assert created_batch.created_at > 0 assert created_batch.created_at > 0
retrieved_batch = await provider.retrieve_batch(created_batch.id) retrieved_batch = await provider.retrieve_batch(RetrieveBatchRequest(batch_id=created_batch.id))
self._validate_batch_type(retrieved_batch, expected_metadata=sample_batch_data["metadata"]) self._validate_batch_type(retrieved_batch, expected_metadata=sample_batch_data["metadata"])
@ -197,17 +204,15 @@ class TestReferenceBatchesImpl:
async def test_create_batch_without_metadata(self, provider): async def test_create_batch_without_metadata(self, provider):
"""Test batch creation without optional metadata.""" """Test batch creation without optional metadata."""
batch = await provider.create_batch( batch = await provider.create_batch(
input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h" CreateBatchRequest(input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h")
) )
assert batch.metadata is None assert batch.metadata is None
async def test_create_batch_completion_window(self, provider): async def test_create_batch_completion_window(self, provider):
"""Test batch creation with invalid completion window.""" """Test batch creation with invalid completion window."""
with pytest.raises(ValueError, match="Invalid completion_window"): with pytest.raises(ValidationError, match="completion_window"):
await provider.create_batch( CreateBatchRequest(input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now")
input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now"
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"endpoint", "endpoint",
@ -219,37 +224,43 @@ class TestReferenceBatchesImpl:
async def test_create_batch_invalid_endpoints(self, provider, endpoint): async def test_create_batch_invalid_endpoints(self, provider, endpoint):
"""Test batch creation with various invalid endpoints.""" """Test batch creation with various invalid endpoints."""
with pytest.raises(ValueError, match="Invalid endpoint"): with pytest.raises(ValueError, match="Invalid endpoint"):
await provider.create_batch(input_file_id="file_123", endpoint=endpoint, completion_window="24h") await provider.create_batch(
CreateBatchRequest(input_file_id="file_123", endpoint=endpoint, completion_window="24h")
)
async def test_create_batch_invalid_metadata(self, provider): async def test_create_batch_invalid_metadata(self, provider):
"""Test that batch creation fails with invalid metadata.""" """Test that batch creation fails with invalid metadata."""
with pytest.raises(ValueError, match="should be a valid string"): with pytest.raises(ValueError, match="should be a valid string"):
await provider.create_batch( await provider.create_batch(
input_file_id="file_123", CreateBatchRequest(
endpoint="/v1/chat/completions", input_file_id="file_123",
completion_window="24h", endpoint="/v1/chat/completions",
metadata={123: "invalid_key"}, # Non-string key completion_window="24h",
metadata={123: "invalid_key"}, # Non-string key
)
) )
with pytest.raises(ValueError, match="should be a valid string"): with pytest.raises(ValueError, match="should be a valid string"):
await provider.create_batch( await provider.create_batch(
input_file_id="file_123", CreateBatchRequest(
endpoint="/v1/chat/completions", input_file_id="file_123",
completion_window="24h", endpoint="/v1/chat/completions",
metadata={"valid_key": 456}, # Non-string value completion_window="24h",
metadata={"valid_key": 456}, # Non-string value
)
) )
async def test_retrieve_batch_not_found(self, provider): async def test_retrieve_batch_not_found(self, provider):
"""Test error when retrieving non-existent batch.""" """Test error when retrieving non-existent batch."""
with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"): with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
await provider.retrieve_batch("nonexistent_batch") await provider.retrieve_batch(RetrieveBatchRequest(batch_id="nonexistent_batch"))
async def test_cancel_batch_success(self, provider, sample_batch_data): async def test_cancel_batch_success(self, provider, sample_batch_data):
"""Test successful batch cancellation.""" """Test successful batch cancellation."""
created_batch = await provider.create_batch(**sample_batch_data) created_batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
assert created_batch.status == "validating" assert created_batch.status == "validating"
cancelled_batch = await provider.cancel_batch(created_batch.id) cancelled_batch = await provider.cancel_batch(CancelBatchRequest(batch_id=created_batch.id))
assert cancelled_batch.id == created_batch.id assert cancelled_batch.id == created_batch.id
assert cancelled_batch.status in ["cancelling", "cancelled"] assert cancelled_batch.status in ["cancelling", "cancelled"]
@ -260,22 +271,22 @@ class TestReferenceBatchesImpl:
async def test_cancel_batch_invalid_statuses(self, provider, sample_batch_data, status): async def test_cancel_batch_invalid_statuses(self, provider, sample_batch_data, status):
"""Test error when cancelling batch in final states.""" """Test error when cancelling batch in final states."""
provider.process_batches = False provider.process_batches = False
created_batch = await provider.create_batch(**sample_batch_data) created_batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
# directly update status in kvstore # directly update status in kvstore
await provider._update_batch(created_batch.id, status=status) await provider._update_batch(created_batch.id, status=status)
with pytest.raises(ConflictError, match=f"Cannot cancel batch '{created_batch.id}' with status '{status}'"): with pytest.raises(ConflictError, match=f"Cannot cancel batch '{created_batch.id}' with status '{status}'"):
await provider.cancel_batch(created_batch.id) await provider.cancel_batch(CancelBatchRequest(batch_id=created_batch.id))
async def test_cancel_batch_not_found(self, provider): async def test_cancel_batch_not_found(self, provider):
"""Test error when cancelling non-existent batch.""" """Test error when cancelling non-existent batch."""
with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"): with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
await provider.cancel_batch("nonexistent_batch") await provider.cancel_batch(CancelBatchRequest(batch_id="nonexistent_batch"))
async def test_list_batches_empty(self, provider): async def test_list_batches_empty(self, provider):
"""Test listing batches when none exist.""" """Test listing batches when none exist."""
response = await provider.list_batches() response = await provider.list_batches(ListBatchesRequest())
assert response.object == "list" assert response.object == "list"
assert response.data == [] assert response.data == []
@ -285,9 +296,9 @@ class TestReferenceBatchesImpl:
async def test_list_batches_single_batch(self, provider, sample_batch_data): async def test_list_batches_single_batch(self, provider, sample_batch_data):
"""Test listing batches with single batch.""" """Test listing batches with single batch."""
created_batch = await provider.create_batch(**sample_batch_data) created_batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
response = await provider.list_batches() response = await provider.list_batches(ListBatchesRequest())
assert len(response.data) == 1 assert len(response.data) == 1
self._validate_batch_type(response.data[0], expected_metadata=sample_batch_data["metadata"]) self._validate_batch_type(response.data[0], expected_metadata=sample_batch_data["metadata"])
@ -300,12 +311,12 @@ class TestReferenceBatchesImpl:
"""Test listing multiple batches.""" """Test listing multiple batches."""
batches = [ batches = [
await provider.create_batch( await provider.create_batch(
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h" CreateBatchRequest(input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h")
) )
for i in range(3) for i in range(3)
] ]
response = await provider.list_batches() response = await provider.list_batches(ListBatchesRequest())
assert len(response.data) == 3 assert len(response.data) == 3
@ -321,12 +332,12 @@ class TestReferenceBatchesImpl:
"""Test listing batches with limit parameter.""" """Test listing batches with limit parameter."""
batches = [ batches = [
await provider.create_batch( await provider.create_batch(
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h" CreateBatchRequest(input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h")
) )
for i in range(3) for i in range(3)
] ]
response = await provider.list_batches(limit=2) response = await provider.list_batches(ListBatchesRequest(limit=2))
assert len(response.data) == 2 assert len(response.data) == 2
assert response.has_more is True assert response.has_more is True
@ -340,36 +351,36 @@ class TestReferenceBatchesImpl:
"""Test listing batches with pagination using 'after' parameter.""" """Test listing batches with pagination using 'after' parameter."""
for i in range(3): for i in range(3):
await provider.create_batch( await provider.create_batch(
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h" CreateBatchRequest(input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h")
) )
# Get first page # Get first page
first_page = await provider.list_batches(limit=1) first_page = await provider.list_batches(ListBatchesRequest(limit=1))
assert len(first_page.data) == 1 assert len(first_page.data) == 1
assert first_page.has_more is True assert first_page.has_more is True
# Get second page using 'after' # Get second page using 'after'
second_page = await provider.list_batches(limit=1, after=first_page.data[0].id) second_page = await provider.list_batches(ListBatchesRequest(limit=1, after=first_page.data[0].id))
assert len(second_page.data) == 1 assert len(second_page.data) == 1
assert second_page.data[0].id != first_page.data[0].id assert second_page.data[0].id != first_page.data[0].id
# Verify we got the next batch in order # Verify we got the next batch in order
all_batches = await provider.list_batches() all_batches = await provider.list_batches(ListBatchesRequest())
expected_second_batch_id = all_batches.data[1].id expected_second_batch_id = all_batches.data[1].id
assert second_page.data[0].id == expected_second_batch_id assert second_page.data[0].id == expected_second_batch_id
async def test_list_batches_invalid_after(self, provider, sample_batch_data): async def test_list_batches_invalid_after(self, provider, sample_batch_data):
"""Test listing batches with invalid 'after' parameter.""" """Test listing batches with invalid 'after' parameter."""
await provider.create_batch(**sample_batch_data) await provider.create_batch(CreateBatchRequest(**sample_batch_data))
response = await provider.list_batches(after="nonexistent_batch") response = await provider.list_batches(ListBatchesRequest(after="nonexistent_batch"))
# Should return all batches (no filtering when 'after' batch not found) # Should return all batches (no filtering when 'after' batch not found)
assert len(response.data) == 1 assert len(response.data) == 1
async def test_kvstore_persistence(self, provider, sample_batch_data): async def test_kvstore_persistence(self, provider, sample_batch_data):
"""Test that batches are properly persisted in kvstore.""" """Test that batches are properly persisted in kvstore."""
batch = await provider.create_batch(**sample_batch_data) batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
stored_data = await provider.kvstore.get(f"batch:{batch.id}") stored_data = await provider.kvstore.get(f"batch:{batch.id}")
assert stored_data is not None assert stored_data is not None
@ -757,7 +768,7 @@ class TestReferenceBatchesImpl:
for _ in range(3): for _ in range(3):
await provider.create_batch( await provider.create_batch(
input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h" CreateBatchRequest(input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h")
) )
await asyncio.sleep(0.042) # let tasks start await asyncio.sleep(0.042) # let tasks start
@ -767,8 +778,10 @@ class TestReferenceBatchesImpl:
async def test_create_batch_embeddings_endpoint(self, provider): async def test_create_batch_embeddings_endpoint(self, provider):
"""Test that batch creation succeeds with embeddings endpoint.""" """Test that batch creation succeeds with embeddings endpoint."""
batch = await provider.create_batch( batch = await provider.create_batch(
input_file_id="file_123", CreateBatchRequest(
endpoint="/v1/embeddings", input_file_id="file_123",
completion_window="24h", endpoint="/v1/embeddings",
completion_window="24h",
)
) )
assert batch.endpoint == "/v1/embeddings" assert batch.endpoint == "/v1/embeddings"

View file

@ -45,6 +45,7 @@ import asyncio
import pytest import pytest
from llama_stack_api import ConflictError from llama_stack_api import ConflictError
from llama_stack_api.batches.models import CreateBatchRequest, RetrieveBatchRequest
class TestReferenceBatchesIdempotency: class TestReferenceBatchesIdempotency:
@ -56,18 +57,22 @@ class TestReferenceBatchesIdempotency:
del sample_batch_data["metadata"] del sample_batch_data["metadata"]
batch1 = await provider.create_batch( batch1 = await provider.create_batch(
**sample_batch_data, CreateBatchRequest(
metadata={"test": "value1", "other": "value2"}, **sample_batch_data,
idempotency_key="unique-token-1", metadata={"test": "value1", "other": "value2"},
idempotency_key="unique-token-1",
)
) )
# sleep for 1 second to allow created_at timestamps to be different # sleep for 1 second to allow created_at timestamps to be different
await asyncio.sleep(1) await asyncio.sleep(1)
batch2 = await provider.create_batch( batch2 = await provider.create_batch(
**sample_batch_data, CreateBatchRequest(
metadata={"other": "value2", "test": "value1"}, # Different order **sample_batch_data,
idempotency_key="unique-token-1", metadata={"other": "value2", "test": "value1"}, # Different order
idempotency_key="unique-token-1",
)
) )
assert batch1.id == batch2.id assert batch1.id == batch2.id
@ -77,23 +82,17 @@ class TestReferenceBatchesIdempotency:
async def test_different_idempotency_keys_create_different_batches(self, provider, sample_batch_data): async def test_different_idempotency_keys_create_different_batches(self, provider, sample_batch_data):
"""Test that different idempotency keys create different batches even with same params.""" """Test that different idempotency keys create different batches even with same params."""
batch1 = await provider.create_batch( batch1 = await provider.create_batch(CreateBatchRequest(**sample_batch_data, idempotency_key="token-A"))
**sample_batch_data,
idempotency_key="token-A",
)
batch2 = await provider.create_batch( batch2 = await provider.create_batch(CreateBatchRequest(**sample_batch_data, idempotency_key="token-B"))
**sample_batch_data,
idempotency_key="token-B",
)
assert batch1.id != batch2.id assert batch1.id != batch2.id
async def test_non_idempotent_behavior_without_key(self, provider, sample_batch_data): async def test_non_idempotent_behavior_without_key(self, provider, sample_batch_data):
"""Test that batches without idempotency key create unique batches even with identical parameters.""" """Test that batches without idempotency key create unique batches even with identical parameters."""
batch1 = await provider.create_batch(**sample_batch_data) batch1 = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
batch2 = await provider.create_batch(**sample_batch_data) batch2 = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
assert batch1.id != batch2.id assert batch1.id != batch2.id
assert batch1.input_file_id == batch2.input_file_id assert batch1.input_file_id == batch2.input_file_id
@ -117,12 +116,12 @@ class TestReferenceBatchesIdempotency:
sample_batch_data[param_name] = first_value sample_batch_data[param_name] = first_value
batch1 = await provider.create_batch(**sample_batch_data) batch1 = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
with pytest.raises(ConflictError, match="Idempotency key.*was previously used with different parameters"): with pytest.raises(ConflictError, match="Idempotency key.*was previously used with different parameters"):
sample_batch_data[param_name] = second_value sample_batch_data[param_name] = second_value
await provider.create_batch(**sample_batch_data) await provider.create_batch(CreateBatchRequest(**sample_batch_data))
retrieved_batch = await provider.retrieve_batch(batch1.id) retrieved_batch = await provider.retrieve_batch(RetrieveBatchRequest(batch_id=batch1.id))
assert retrieved_batch.id == batch1.id assert retrieved_batch.id == batch1.id
assert getattr(retrieved_batch, param_name) == first_value assert getattr(retrieved_batch, param_name) == first_value