Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-03 09:53:45 +00:00
chore: update unit test to use previously created Class
Signed-off-by: Sébastien Han <seb@redhat.com>
Parent: 9595619b9f
Commit: f62c6044b3
2 changed files with 71 additions and 59 deletions
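The change replaces keyword-argument calls on the reference batches provider with the request models that already exist in llama_stack_api.batches.models (CreateBatchRequest, RetrieveBatchRequest, CancelBatchRequest, ListBatchesRequest), so the arguments are validated when the request object is built. As a rough sketch of the shapes the updated tests construct, inferred only from the fields exercised in this diff (the real definitions live in the library; the types, defaults, and validators below are assumptions):

    # Hypothetical sketch of the request models used by the tests below.
    # Field names come from this diff; everything else is assumed.
    from pydantic import BaseModel

    class CreateBatchRequest(BaseModel):
        input_file_id: str
        endpoint: str
        completion_window: str  # the diff implies construction-time validation of this field
        metadata: dict[str, str] | None = None
        idempotency_key: str | None = None

    class RetrieveBatchRequest(BaseModel):
        batch_id: str

    class CancelBatchRequest(BaseModel):
        batch_id: str

    class ListBatchesRequest(BaseModel):
        after: str | None = None
        limit: int = 20  # assumed default; the tests pass limit explicitly when they need it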
@@ -58,8 +58,15 @@ import json
 from unittest.mock import AsyncMock, MagicMock
 
 import pytest
+from pydantic import ValidationError
 
 from llama_stack_api import BatchObject, ConflictError, ResourceNotFoundError
+from llama_stack_api.batches.models import (
+    CancelBatchRequest,
+    CreateBatchRequest,
+    ListBatchesRequest,
+    RetrieveBatchRequest,
+)
 
 
 class TestReferenceBatchesImpl:
@@ -169,7 +176,7 @@ class TestReferenceBatchesImpl:
 
     async def test_create_and_retrieve_batch_success(self, provider, sample_batch_data):
         """Test successful batch creation and retrieval."""
-        created_batch = await provider.create_batch(**sample_batch_data)
+        created_batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
 
         self._validate_batch_type(created_batch, expected_metadata=sample_batch_data["metadata"])
 
@@ -184,7 +191,7 @@ class TestReferenceBatchesImpl:
         assert isinstance(created_batch.created_at, int)
         assert created_batch.created_at > 0
 
-        retrieved_batch = await provider.retrieve_batch(created_batch.id)
+        retrieved_batch = await provider.retrieve_batch(RetrieveBatchRequest(batch_id=created_batch.id))
 
         self._validate_batch_type(retrieved_batch, expected_metadata=sample_batch_data["metadata"])
 
@@ -197,17 +204,15 @@ class TestReferenceBatchesImpl:
     async def test_create_batch_without_metadata(self, provider):
         """Test batch creation without optional metadata."""
         batch = await provider.create_batch(
-            input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h"
+            CreateBatchRequest(input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h")
         )
 
         assert batch.metadata is None
 
     async def test_create_batch_completion_window(self, provider):
         """Test batch creation with invalid completion window."""
-        with pytest.raises(ValueError, match="Invalid completion_window"):
-            await provider.create_batch(
-                input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now"
-            )
+        with pytest.raises(ValidationError, match="completion_window"):
+            CreateBatchRequest(input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now")
 
     @pytest.mark.parametrize(
         "endpoint",
@@ -219,37 +224,43 @@ class TestReferenceBatchesImpl:
     async def test_create_batch_invalid_endpoints(self, provider, endpoint):
         """Test batch creation with various invalid endpoints."""
         with pytest.raises(ValueError, match="Invalid endpoint"):
-            await provider.create_batch(input_file_id="file_123", endpoint=endpoint, completion_window="24h")
+            await provider.create_batch(
+                CreateBatchRequest(input_file_id="file_123", endpoint=endpoint, completion_window="24h")
+            )
 
     async def test_create_batch_invalid_metadata(self, provider):
         """Test that batch creation fails with invalid metadata."""
         with pytest.raises(ValueError, match="should be a valid string"):
             await provider.create_batch(
+                CreateBatchRequest(
                     input_file_id="file_123",
                     endpoint="/v1/chat/completions",
                     completion_window="24h",
                     metadata={123: "invalid_key"},  # Non-string key
                 )
+            )
 
         with pytest.raises(ValueError, match="should be a valid string"):
             await provider.create_batch(
+                CreateBatchRequest(
                     input_file_id="file_123",
                     endpoint="/v1/chat/completions",
                     completion_window="24h",
                     metadata={"valid_key": 456},  # Non-string value
                 )
+            )
 
     async def test_retrieve_batch_not_found(self, provider):
         """Test error when retrieving non-existent batch."""
         with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
-            await provider.retrieve_batch("nonexistent_batch")
+            await provider.retrieve_batch(RetrieveBatchRequest(batch_id="nonexistent_batch"))
 
     async def test_cancel_batch_success(self, provider, sample_batch_data):
         """Test successful batch cancellation."""
-        created_batch = await provider.create_batch(**sample_batch_data)
+        created_batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
         assert created_batch.status == "validating"
 
-        cancelled_batch = await provider.cancel_batch(created_batch.id)
+        cancelled_batch = await provider.cancel_batch(CancelBatchRequest(batch_id=created_batch.id))
 
         assert cancelled_batch.id == created_batch.id
         assert cancelled_batch.status in ["cancelling", "cancelled"]
@@ -260,22 +271,22 @@ class TestReferenceBatchesImpl:
     async def test_cancel_batch_invalid_statuses(self, provider, sample_batch_data, status):
         """Test error when cancelling batch in final states."""
         provider.process_batches = False
-        created_batch = await provider.create_batch(**sample_batch_data)
+        created_batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
 
         # directly update status in kvstore
         await provider._update_batch(created_batch.id, status=status)
 
         with pytest.raises(ConflictError, match=f"Cannot cancel batch '{created_batch.id}' with status '{status}'"):
-            await provider.cancel_batch(created_batch.id)
+            await provider.cancel_batch(CancelBatchRequest(batch_id=created_batch.id))
 
     async def test_cancel_batch_not_found(self, provider):
         """Test error when cancelling non-existent batch."""
         with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
-            await provider.cancel_batch("nonexistent_batch")
+            await provider.cancel_batch(CancelBatchRequest(batch_id="nonexistent_batch"))
 
     async def test_list_batches_empty(self, provider):
         """Test listing batches when none exist."""
-        response = await provider.list_batches()
+        response = await provider.list_batches(ListBatchesRequest())
 
         assert response.object == "list"
         assert response.data == []
@@ -285,9 +296,9 @@ class TestReferenceBatchesImpl:
 
     async def test_list_batches_single_batch(self, provider, sample_batch_data):
         """Test listing batches with single batch."""
-        created_batch = await provider.create_batch(**sample_batch_data)
+        created_batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
 
-        response = await provider.list_batches()
+        response = await provider.list_batches(ListBatchesRequest())
 
         assert len(response.data) == 1
         self._validate_batch_type(response.data[0], expected_metadata=sample_batch_data["metadata"])
@@ -300,12 +311,12 @@ class TestReferenceBatchesImpl:
         """Test listing multiple batches."""
         batches = [
             await provider.create_batch(
-                input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
+                CreateBatchRequest(input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h")
             )
             for i in range(3)
         ]
 
-        response = await provider.list_batches()
+        response = await provider.list_batches(ListBatchesRequest())
 
         assert len(response.data) == 3
 
@@ -321,12 +332,12 @@ class TestReferenceBatchesImpl:
         """Test listing batches with limit parameter."""
         batches = [
             await provider.create_batch(
-                input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
+                CreateBatchRequest(input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h")
             )
             for i in range(3)
         ]
 
-        response = await provider.list_batches(limit=2)
+        response = await provider.list_batches(ListBatchesRequest(limit=2))
 
         assert len(response.data) == 2
         assert response.has_more is True
@@ -340,36 +351,36 @@ class TestReferenceBatchesImpl:
         """Test listing batches with pagination using 'after' parameter."""
         for i in range(3):
             await provider.create_batch(
-                input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
+                CreateBatchRequest(input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h")
             )
 
         # Get first page
-        first_page = await provider.list_batches(limit=1)
+        first_page = await provider.list_batches(ListBatchesRequest(limit=1))
         assert len(first_page.data) == 1
         assert first_page.has_more is True
 
         # Get second page using 'after'
-        second_page = await provider.list_batches(limit=1, after=first_page.data[0].id)
+        second_page = await provider.list_batches(ListBatchesRequest(limit=1, after=first_page.data[0].id))
         assert len(second_page.data) == 1
         assert second_page.data[0].id != first_page.data[0].id
 
         # Verify we got the next batch in order
-        all_batches = await provider.list_batches()
+        all_batches = await provider.list_batches(ListBatchesRequest())
         expected_second_batch_id = all_batches.data[1].id
         assert second_page.data[0].id == expected_second_batch_id
 
     async def test_list_batches_invalid_after(self, provider, sample_batch_data):
         """Test listing batches with invalid 'after' parameter."""
-        await provider.create_batch(**sample_batch_data)
+        await provider.create_batch(CreateBatchRequest(**sample_batch_data))
 
-        response = await provider.list_batches(after="nonexistent_batch")
+        response = await provider.list_batches(ListBatchesRequest(after="nonexistent_batch"))
 
         # Should return all batches (no filtering when 'after' batch not found)
         assert len(response.data) == 1
 
     async def test_kvstore_persistence(self, provider, sample_batch_data):
         """Test that batches are properly persisted in kvstore."""
-        batch = await provider.create_batch(**sample_batch_data)
+        batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
 
         stored_data = await provider.kvstore.get(f"batch:{batch.id}")
         assert stored_data is not None
@@ -757,7 +768,7 @@ class TestReferenceBatchesImpl:
 
         for _ in range(3):
             await provider.create_batch(
-                input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h"
+                CreateBatchRequest(input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h")
             )
 
         await asyncio.sleep(0.042)  # let tasks start
@@ -767,8 +778,10 @@ class TestReferenceBatchesImpl:
     async def test_create_batch_embeddings_endpoint(self, provider):
         """Test that batch creation succeeds with embeddings endpoint."""
         batch = await provider.create_batch(
+            CreateBatchRequest(
                 input_file_id="file_123",
                 endpoint="/v1/embeddings",
                 completion_window="24h",
             )
+        )
         assert batch.endpoint == "/v1/embeddings"

@@ -45,6 +45,7 @@ import asyncio
 import pytest
 
 from llama_stack_api import ConflictError
+from llama_stack_api.batches.models import CreateBatchRequest, RetrieveBatchRequest
 
 
 class TestReferenceBatchesIdempotency:
@@ -56,19 +57,23 @@ class TestReferenceBatchesIdempotency:
         del sample_batch_data["metadata"]
 
         batch1 = await provider.create_batch(
+            CreateBatchRequest(
                 **sample_batch_data,
                 metadata={"test": "value1", "other": "value2"},
                 idempotency_key="unique-token-1",
             )
+        )
 
         # sleep for 1 second to allow created_at timestamps to be different
         await asyncio.sleep(1)
 
         batch2 = await provider.create_batch(
+            CreateBatchRequest(
                 **sample_batch_data,
                 metadata={"other": "value2", "test": "value1"},  # Different order
                 idempotency_key="unique-token-1",
             )
+        )
 
         assert batch1.id == batch2.id
         assert batch1.input_file_id == batch2.input_file_id
@@ -77,23 +82,17 @@ class TestReferenceBatchesIdempotency:
 
     async def test_different_idempotency_keys_create_different_batches(self, provider, sample_batch_data):
         """Test that different idempotency keys create different batches even with same params."""
-        batch1 = await provider.create_batch(
-            **sample_batch_data,
-            idempotency_key="token-A",
-        )
+        batch1 = await provider.create_batch(CreateBatchRequest(**sample_batch_data, idempotency_key="token-A"))
 
-        batch2 = await provider.create_batch(
-            **sample_batch_data,
-            idempotency_key="token-B",
-        )
+        batch2 = await provider.create_batch(CreateBatchRequest(**sample_batch_data, idempotency_key="token-B"))
 
         assert batch1.id != batch2.id
 
     async def test_non_idempotent_behavior_without_key(self, provider, sample_batch_data):
         """Test that batches without idempotency key create unique batches even with identical parameters."""
-        batch1 = await provider.create_batch(**sample_batch_data)
+        batch1 = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
 
-        batch2 = await provider.create_batch(**sample_batch_data)
+        batch2 = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
 
         assert batch1.id != batch2.id
         assert batch1.input_file_id == batch2.input_file_id
@@ -117,12 +116,12 @@ class TestReferenceBatchesIdempotency:
 
         sample_batch_data[param_name] = first_value
 
-        batch1 = await provider.create_batch(**sample_batch_data)
+        batch1 = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
 
         with pytest.raises(ConflictError, match="Idempotency key.*was previously used with different parameters"):
             sample_batch_data[param_name] = second_value
-            await provider.create_batch(**sample_batch_data)
+            await provider.create_batch(CreateBatchRequest(**sample_batch_data))
 
-        retrieved_batch = await provider.retrieve_batch(batch1.id)
+        retrieved_batch = await provider.retrieve_batch(RetrieveBatchRequest(batch_id=batch1.id))
         assert retrieved_batch.id == batch1.id
         assert getattr(retrieved_batch, param_name) == first_value
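One consequence visible above: test_create_batch_completion_window now expects pydantic.ValidationError with a "completion_window" message instead of the provider's ValueError, because an invalid window is rejected while the request model is being constructed, before provider.create_batch is ever awaited (the metadata tests keep catching ValueError, presumably because pydantic's ValidationError is a ValueError subclass). A minimal sketch of that behaviour, under the same assumptions as the model sketch near the top; the provider fixture and the exact validation rules are taken on faith from the tests:

    import pytest
    from pydantic import ValidationError

    from llama_stack_api.batches.models import CreateBatchRequest

    async def sketch_completion_window(provider):
        # Invalid window: fails at model construction, so the provider is never called
        # (assuming CreateBatchRequest constrains completion_window as the diff implies).
        with pytest.raises(ValidationError, match="completion_window"):
            CreateBatchRequest(
                input_file_id="file_123",
                endpoint="/v1/chat/completions",
                completion_window="now",
            )

        # Valid request: built first, then handed to the provider.
        return await provider.create_batch(
            CreateBatchRequest(
                input_file_id="file_123",
                endpoint="/v1/chat/completions",
                completion_window="24h",
            )
        )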