mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
chore: update unit test to use previously created Class
Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
9595619b9f
commit
f62c6044b3
2 changed files with 71 additions and 59 deletions
|
|
@ -58,8 +58,15 @@ import json
|
|||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from llama_stack_api import BatchObject, ConflictError, ResourceNotFoundError
|
||||
from llama_stack_api.batches.models import (
|
||||
CancelBatchRequest,
|
||||
CreateBatchRequest,
|
||||
ListBatchesRequest,
|
||||
RetrieveBatchRequest,
|
||||
)
|
||||
|
||||
|
||||
class TestReferenceBatchesImpl:
|
||||
|
|
@ -169,7 +176,7 @@ class TestReferenceBatchesImpl:
|
|||
|
||||
async def test_create_and_retrieve_batch_success(self, provider, sample_batch_data):
|
||||
"""Test successful batch creation and retrieval."""
|
||||
created_batch = await provider.create_batch(**sample_batch_data)
|
||||
created_batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
|
||||
|
||||
self._validate_batch_type(created_batch, expected_metadata=sample_batch_data["metadata"])
|
||||
|
||||
|
|
@ -184,7 +191,7 @@ class TestReferenceBatchesImpl:
|
|||
assert isinstance(created_batch.created_at, int)
|
||||
assert created_batch.created_at > 0
|
||||
|
||||
retrieved_batch = await provider.retrieve_batch(created_batch.id)
|
||||
retrieved_batch = await provider.retrieve_batch(RetrieveBatchRequest(batch_id=created_batch.id))
|
||||
|
||||
self._validate_batch_type(retrieved_batch, expected_metadata=sample_batch_data["metadata"])
|
||||
|
||||
|
|
@ -197,17 +204,15 @@ class TestReferenceBatchesImpl:
|
|||
async def test_create_batch_without_metadata(self, provider):
|
||||
"""Test batch creation without optional metadata."""
|
||||
batch = await provider.create_batch(
|
||||
input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
CreateBatchRequest(input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h")
|
||||
)
|
||||
|
||||
assert batch.metadata is None
|
||||
|
||||
async def test_create_batch_completion_window(self, provider):
|
||||
"""Test batch creation with invalid completion window."""
|
||||
with pytest.raises(ValueError, match="Invalid completion_window"):
|
||||
await provider.create_batch(
|
||||
input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now"
|
||||
)
|
||||
with pytest.raises(ValidationError, match="completion_window"):
|
||||
CreateBatchRequest(input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"endpoint",
|
||||
|
|
@ -219,37 +224,43 @@ class TestReferenceBatchesImpl:
|
|||
async def test_create_batch_invalid_endpoints(self, provider, endpoint):
|
||||
"""Test batch creation with various invalid endpoints."""
|
||||
with pytest.raises(ValueError, match="Invalid endpoint"):
|
||||
await provider.create_batch(input_file_id="file_123", endpoint=endpoint, completion_window="24h")
|
||||
await provider.create_batch(
|
||||
CreateBatchRequest(input_file_id="file_123", endpoint=endpoint, completion_window="24h")
|
||||
)
|
||||
|
||||
async def test_create_batch_invalid_metadata(self, provider):
|
||||
"""Test that batch creation fails with invalid metadata."""
|
||||
with pytest.raises(ValueError, match="should be a valid string"):
|
||||
await provider.create_batch(
|
||||
input_file_id="file_123",
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={123: "invalid_key"}, # Non-string key
|
||||
CreateBatchRequest(
|
||||
input_file_id="file_123",
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={123: "invalid_key"}, # Non-string key
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="should be a valid string"):
|
||||
await provider.create_batch(
|
||||
input_file_id="file_123",
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={"valid_key": 456}, # Non-string value
|
||||
CreateBatchRequest(
|
||||
input_file_id="file_123",
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={"valid_key": 456}, # Non-string value
|
||||
)
|
||||
)
|
||||
|
||||
async def test_retrieve_batch_not_found(self, provider):
|
||||
"""Test error when retrieving non-existent batch."""
|
||||
with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
|
||||
await provider.retrieve_batch("nonexistent_batch")
|
||||
await provider.retrieve_batch(RetrieveBatchRequest(batch_id="nonexistent_batch"))
|
||||
|
||||
async def test_cancel_batch_success(self, provider, sample_batch_data):
|
||||
"""Test successful batch cancellation."""
|
||||
created_batch = await provider.create_batch(**sample_batch_data)
|
||||
created_batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
|
||||
assert created_batch.status == "validating"
|
||||
|
||||
cancelled_batch = await provider.cancel_batch(created_batch.id)
|
||||
cancelled_batch = await provider.cancel_batch(CancelBatchRequest(batch_id=created_batch.id))
|
||||
|
||||
assert cancelled_batch.id == created_batch.id
|
||||
assert cancelled_batch.status in ["cancelling", "cancelled"]
|
||||
|
|
@ -260,22 +271,22 @@ class TestReferenceBatchesImpl:
|
|||
async def test_cancel_batch_invalid_statuses(self, provider, sample_batch_data, status):
|
||||
"""Test error when cancelling batch in final states."""
|
||||
provider.process_batches = False
|
||||
created_batch = await provider.create_batch(**sample_batch_data)
|
||||
created_batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
|
||||
|
||||
# directly update status in kvstore
|
||||
await provider._update_batch(created_batch.id, status=status)
|
||||
|
||||
with pytest.raises(ConflictError, match=f"Cannot cancel batch '{created_batch.id}' with status '{status}'"):
|
||||
await provider.cancel_batch(created_batch.id)
|
||||
await provider.cancel_batch(CancelBatchRequest(batch_id=created_batch.id))
|
||||
|
||||
async def test_cancel_batch_not_found(self, provider):
|
||||
"""Test error when cancelling non-existent batch."""
|
||||
with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
|
||||
await provider.cancel_batch("nonexistent_batch")
|
||||
await provider.cancel_batch(CancelBatchRequest(batch_id="nonexistent_batch"))
|
||||
|
||||
async def test_list_batches_empty(self, provider):
|
||||
"""Test listing batches when none exist."""
|
||||
response = await provider.list_batches()
|
||||
response = await provider.list_batches(ListBatchesRequest())
|
||||
|
||||
assert response.object == "list"
|
||||
assert response.data == []
|
||||
|
|
@ -285,9 +296,9 @@ class TestReferenceBatchesImpl:
|
|||
|
||||
async def test_list_batches_single_batch(self, provider, sample_batch_data):
|
||||
"""Test listing batches with single batch."""
|
||||
created_batch = await provider.create_batch(**sample_batch_data)
|
||||
created_batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
|
||||
|
||||
response = await provider.list_batches()
|
||||
response = await provider.list_batches(ListBatchesRequest())
|
||||
|
||||
assert len(response.data) == 1
|
||||
self._validate_batch_type(response.data[0], expected_metadata=sample_batch_data["metadata"])
|
||||
|
|
@ -300,12 +311,12 @@ class TestReferenceBatchesImpl:
|
|||
"""Test listing multiple batches."""
|
||||
batches = [
|
||||
await provider.create_batch(
|
||||
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
CreateBatchRequest(input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h")
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
response = await provider.list_batches()
|
||||
response = await provider.list_batches(ListBatchesRequest())
|
||||
|
||||
assert len(response.data) == 3
|
||||
|
||||
|
|
@ -321,12 +332,12 @@ class TestReferenceBatchesImpl:
|
|||
"""Test listing batches with limit parameter."""
|
||||
batches = [
|
||||
await provider.create_batch(
|
||||
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
CreateBatchRequest(input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h")
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
response = await provider.list_batches(limit=2)
|
||||
response = await provider.list_batches(ListBatchesRequest(limit=2))
|
||||
|
||||
assert len(response.data) == 2
|
||||
assert response.has_more is True
|
||||
|
|
@ -340,36 +351,36 @@ class TestReferenceBatchesImpl:
|
|||
"""Test listing batches with pagination using 'after' parameter."""
|
||||
for i in range(3):
|
||||
await provider.create_batch(
|
||||
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
CreateBatchRequest(input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h")
|
||||
)
|
||||
|
||||
# Get first page
|
||||
first_page = await provider.list_batches(limit=1)
|
||||
first_page = await provider.list_batches(ListBatchesRequest(limit=1))
|
||||
assert len(first_page.data) == 1
|
||||
assert first_page.has_more is True
|
||||
|
||||
# Get second page using 'after'
|
||||
second_page = await provider.list_batches(limit=1, after=first_page.data[0].id)
|
||||
second_page = await provider.list_batches(ListBatchesRequest(limit=1, after=first_page.data[0].id))
|
||||
assert len(second_page.data) == 1
|
||||
assert second_page.data[0].id != first_page.data[0].id
|
||||
|
||||
# Verify we got the next batch in order
|
||||
all_batches = await provider.list_batches()
|
||||
all_batches = await provider.list_batches(ListBatchesRequest())
|
||||
expected_second_batch_id = all_batches.data[1].id
|
||||
assert second_page.data[0].id == expected_second_batch_id
|
||||
|
||||
async def test_list_batches_invalid_after(self, provider, sample_batch_data):
|
||||
"""Test listing batches with invalid 'after' parameter."""
|
||||
await provider.create_batch(**sample_batch_data)
|
||||
await provider.create_batch(CreateBatchRequest(**sample_batch_data))
|
||||
|
||||
response = await provider.list_batches(after="nonexistent_batch")
|
||||
response = await provider.list_batches(ListBatchesRequest(after="nonexistent_batch"))
|
||||
|
||||
# Should return all batches (no filtering when 'after' batch not found)
|
||||
assert len(response.data) == 1
|
||||
|
||||
async def test_kvstore_persistence(self, provider, sample_batch_data):
|
||||
"""Test that batches are properly persisted in kvstore."""
|
||||
batch = await provider.create_batch(**sample_batch_data)
|
||||
batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
|
||||
|
||||
stored_data = await provider.kvstore.get(f"batch:{batch.id}")
|
||||
assert stored_data is not None
|
||||
|
|
@ -757,7 +768,7 @@ class TestReferenceBatchesImpl:
|
|||
|
||||
for _ in range(3):
|
||||
await provider.create_batch(
|
||||
input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
CreateBatchRequest(input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h")
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.042) # let tasks start
|
||||
|
|
@ -767,8 +778,10 @@ class TestReferenceBatchesImpl:
|
|||
async def test_create_batch_embeddings_endpoint(self, provider):
|
||||
"""Test that batch creation succeeds with embeddings endpoint."""
|
||||
batch = await provider.create_batch(
|
||||
input_file_id="file_123",
|
||||
endpoint="/v1/embeddings",
|
||||
completion_window="24h",
|
||||
CreateBatchRequest(
|
||||
input_file_id="file_123",
|
||||
endpoint="/v1/embeddings",
|
||||
completion_window="24h",
|
||||
)
|
||||
)
|
||||
assert batch.endpoint == "/v1/embeddings"
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ import asyncio
|
|||
import pytest
|
||||
|
||||
from llama_stack_api import ConflictError
|
||||
from llama_stack_api.batches.models import CreateBatchRequest, RetrieveBatchRequest
|
||||
|
||||
|
||||
class TestReferenceBatchesIdempotency:
|
||||
|
|
@ -56,18 +57,22 @@ class TestReferenceBatchesIdempotency:
|
|||
del sample_batch_data["metadata"]
|
||||
|
||||
batch1 = await provider.create_batch(
|
||||
**sample_batch_data,
|
||||
metadata={"test": "value1", "other": "value2"},
|
||||
idempotency_key="unique-token-1",
|
||||
CreateBatchRequest(
|
||||
**sample_batch_data,
|
||||
metadata={"test": "value1", "other": "value2"},
|
||||
idempotency_key="unique-token-1",
|
||||
)
|
||||
)
|
||||
|
||||
# sleep for 1 second to allow created_at timestamps to be different
|
||||
await asyncio.sleep(1)
|
||||
|
||||
batch2 = await provider.create_batch(
|
||||
**sample_batch_data,
|
||||
metadata={"other": "value2", "test": "value1"}, # Different order
|
||||
idempotency_key="unique-token-1",
|
||||
CreateBatchRequest(
|
||||
**sample_batch_data,
|
||||
metadata={"other": "value2", "test": "value1"}, # Different order
|
||||
idempotency_key="unique-token-1",
|
||||
)
|
||||
)
|
||||
|
||||
assert batch1.id == batch2.id
|
||||
|
|
@ -77,23 +82,17 @@ class TestReferenceBatchesIdempotency:
|
|||
|
||||
async def test_different_idempotency_keys_create_different_batches(self, provider, sample_batch_data):
|
||||
"""Test that different idempotency keys create different batches even with same params."""
|
||||
batch1 = await provider.create_batch(
|
||||
**sample_batch_data,
|
||||
idempotency_key="token-A",
|
||||
)
|
||||
batch1 = await provider.create_batch(CreateBatchRequest(**sample_batch_data, idempotency_key="token-A"))
|
||||
|
||||
batch2 = await provider.create_batch(
|
||||
**sample_batch_data,
|
||||
idempotency_key="token-B",
|
||||
)
|
||||
batch2 = await provider.create_batch(CreateBatchRequest(**sample_batch_data, idempotency_key="token-B"))
|
||||
|
||||
assert batch1.id != batch2.id
|
||||
|
||||
async def test_non_idempotent_behavior_without_key(self, provider, sample_batch_data):
|
||||
"""Test that batches without idempotency key create unique batches even with identical parameters."""
|
||||
batch1 = await provider.create_batch(**sample_batch_data)
|
||||
batch1 = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
|
||||
|
||||
batch2 = await provider.create_batch(**sample_batch_data)
|
||||
batch2 = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
|
||||
|
||||
assert batch1.id != batch2.id
|
||||
assert batch1.input_file_id == batch2.input_file_id
|
||||
|
|
@ -117,12 +116,12 @@ class TestReferenceBatchesIdempotency:
|
|||
|
||||
sample_batch_data[param_name] = first_value
|
||||
|
||||
batch1 = await provider.create_batch(**sample_batch_data)
|
||||
batch1 = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
|
||||
|
||||
with pytest.raises(ConflictError, match="Idempotency key.*was previously used with different parameters"):
|
||||
sample_batch_data[param_name] = second_value
|
||||
await provider.create_batch(**sample_batch_data)
|
||||
await provider.create_batch(CreateBatchRequest(**sample_batch_data))
|
||||
|
||||
retrieved_batch = await provider.retrieve_batch(batch1.id)
|
||||
retrieved_batch = await provider.retrieve_batch(RetrieveBatchRequest(batch_id=batch1.id))
|
||||
assert retrieved_batch.id == batch1.id
|
||||
assert getattr(retrieved_batch, param_name) == first_value
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue