Merge branch 'main' into chroma

This commit is contained in:
Bwook (Byoungwook) Kim 2025-08-25 09:59:32 +09:00 committed by GitHub
commit d3958fae4f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
192 changed files with 7088 additions and 853 deletions

View file

@ -5,86 +5,121 @@
# the root directory of this source tree.
"""
Unit tests for LlamaStackAsLibraryClient initialization error handling.
Unit tests for LlamaStackAsLibraryClient automatic initialization.
These tests ensure that users get proper error messages when they forget to call
initialize() on the library client, preventing AttributeError regressions.
These tests ensure that the library client is automatically initialized
and ready to use immediately after construction.
"""
import pytest
from llama_stack.core.library_client import (
AsyncLlamaStackAsLibraryClient,
LlamaStackAsLibraryClient,
)
from llama_stack.core.server.routes import RouteImpls
class TestLlamaStackAsLibraryClientInitialization:
"""Test proper error handling for uninitialized library clients."""
class TestLlamaStackAsLibraryClientAutoInitialization:
"""Test automatic initialization of library clients."""
@pytest.mark.parametrize(
"api_call",
[
lambda client: client.models.list(),
lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
lambda client: next(
client.chat.completions.create(
model="test", messages=[{"role": "user", "content": "test"}], stream=True
)
),
],
ids=["models.list", "chat.completions.create", "chat.completions.create_stream"],
)
def test_sync_client_proper_error_without_initialization(self, api_call):
"""Test that sync client raises ValueError with helpful message when not initialized."""
client = LlamaStackAsLibraryClient("nvidia")
def test_sync_client_auto_initialization(self, monkeypatch):
"""Test that sync client is automatically initialized after construction."""
# Mock the stack construction to avoid dependency issues
mock_impls = {}
mock_route_impls = RouteImpls({})
with pytest.raises(ValueError) as exc_info:
api_call(client)
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
error_msg = str(exc_info.value)
assert "Client not initialized" in error_msg
assert "Please call initialize() first" in error_msg
def mock_initialize_route_impls(impls):
return mock_route_impls
@pytest.mark.parametrize(
"api_call",
[
lambda client: client.models.list(),
lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
],
ids=["models.list", "chat.completions.create"],
)
async def test_async_client_proper_error_without_initialization(self, api_call):
"""Test that async client raises ValueError with helpful message when not initialized."""
client = AsyncLlamaStackAsLibraryClient("nvidia")
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
with pytest.raises(ValueError) as exc_info:
await api_call(client)
client = LlamaStackAsLibraryClient("ci-tests")
error_msg = str(exc_info.value)
assert "Client not initialized" in error_msg
assert "Please call initialize() first" in error_msg
assert client.async_client.route_impls is not None
async def test_async_client_streaming_error_without_initialization(self):
"""Test that async client streaming raises ValueError with helpful message when not initialized."""
client = AsyncLlamaStackAsLibraryClient("nvidia")
async def test_async_client_auto_initialization(self, monkeypatch):
"""Test that async client can be initialized and works properly."""
# Mock the stack construction to avoid dependency issues
mock_impls = {}
mock_route_impls = RouteImpls({})
with pytest.raises(ValueError) as exc_info:
stream = await client.chat.completions.create(
model="test", messages=[{"role": "user", "content": "test"}], stream=True
)
await anext(stream)
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
error_msg = str(exc_info.value)
assert "Client not initialized" in error_msg
assert "Please call initialize() first" in error_msg
def mock_initialize_route_impls(impls):
return mock_route_impls
def test_route_impls_initialized_to_none(self):
"""Test that route_impls is initialized to None to prevent AttributeError."""
# Test sync client
sync_client = LlamaStackAsLibraryClient("nvidia")
assert sync_client.async_client.route_impls is None
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
# Test async client directly
async_client = AsyncLlamaStackAsLibraryClient("nvidia")
assert async_client.route_impls is None
client = AsyncLlamaStackAsLibraryClient("ci-tests")
# Initialize the client
result = await client.initialize()
assert result is True
assert client.route_impls is not None
def test_initialize_method_backward_compatibility(self, monkeypatch):
"""Test that initialize() method still works for backward compatibility."""
# Mock the stack construction to avoid dependency issues
mock_impls = {}
mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
def mock_initialize_route_impls(impls):
return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
client = LlamaStackAsLibraryClient("ci-tests")
result = client.initialize()
assert result is None
result2 = client.initialize()
assert result2 is None
async def test_async_initialize_method_idempotent(self, monkeypatch):
"""Test that async initialize() method can be called multiple times safely."""
mock_impls = {}
mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
def mock_initialize_route_impls(impls):
return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
client = AsyncLlamaStackAsLibraryClient("ci-tests")
result1 = await client.initialize()
assert result1 is True
result2 = await client.initialize()
assert result2 is True
def test_route_impls_automatically_set(self, monkeypatch):
"""Test that route_impls is automatically set during construction."""
mock_impls = {}
mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
def mock_initialize_route_impls(impls):
return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
sync_client = LlamaStackAsLibraryClient("ci-tests")
assert sync_client.async_client.route_impls is not None

View file

@ -7,6 +7,7 @@
import pytest
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import OpenAIFilePurpose
from llama_stack.core.access_control.access_control import default_policy
@ -190,7 +191,7 @@ class TestOpenAIFilesAPI:
async def test_retrieve_file_not_found(self, files_provider):
"""Test retrieving a non-existent file."""
with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
with pytest.raises(ResourceNotFoundError, match="not found"):
await files_provider.openai_retrieve_file("file-nonexistent")
async def test_retrieve_file_content_success(self, files_provider, sample_text_file):
@ -208,7 +209,7 @@ class TestOpenAIFilesAPI:
async def test_retrieve_file_content_not_found(self, files_provider):
"""Test retrieving content of a non-existent file."""
with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
with pytest.raises(ResourceNotFoundError, match="not found"):
await files_provider.openai_retrieve_file_content("file-nonexistent")
async def test_delete_file_success(self, files_provider, sample_text_file):
@ -229,12 +230,12 @@ class TestOpenAIFilesAPI:
assert delete_response.deleted is True
# Verify file no longer exists
with pytest.raises(ValueError, match=f"File with id {uploaded_file.id} not found"):
with pytest.raises(ResourceNotFoundError, match="not found"):
await files_provider.openai_retrieve_file(uploaded_file.id)
async def test_delete_file_not_found(self, files_provider):
"""Test deleting a non-existent file."""
with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
with pytest.raises(ResourceNotFoundError, match="not found"):
await files_provider.openai_delete_file("file-nonexistent")
async def test_file_persistence_across_operations(self, files_provider, sample_text_file):

View file

@ -24,6 +24,7 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseMessage,
OpenAIResponseObjectWithInput,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageWebSearchToolCall,
OpenAIResponseText,
OpenAIResponseTextFormat,
@ -461,6 +462,53 @@ async def test_prepend_previous_response_web_search(openai_responses_impl, mock_
assert input[3].content == "fake_input"
async def test_prepend_previous_response_mcp_tool_call(openai_responses_impl, mock_responses_store):
"""Test prepending a previous response which included an mcp tool call to a new response."""
input_item_message = OpenAIResponseMessage(
id="123",
content=[OpenAIResponseInputMessageContentText(text="fake_previous_input")],
role="user",
)
output_tool_call = OpenAIResponseOutputMessageMCPCall(
id="ws_123",
name="fake-tool",
arguments="fake-arguments",
server_label="fake-label",
)
output_message = OpenAIResponseMessage(
id="123",
content=[OpenAIResponseOutputMessageContentOutputText(text="fake_tool_call_response")],
status="completed",
role="assistant",
)
response = OpenAIResponseObjectWithInput(
created_at=1,
id="resp_123",
model="fake_model",
output=[output_tool_call, output_message],
status="completed",
text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
input=[input_item_message],
)
mock_responses_store.get_response_object.return_value = response
input_messages = [OpenAIResponseMessage(content="fake_input", role="user")]
input = await openai_responses_impl._prepend_previous_response(input_messages, "resp_123")
assert len(input) == 4
# Check for previous input
assert isinstance(input[0], OpenAIResponseMessage)
assert input[0].content[0].text == "fake_previous_input"
# Check for previous output MCP tool call
assert isinstance(input[1], OpenAIResponseOutputMessageMCPCall)
# Check for previous output web search response
assert isinstance(input[2], OpenAIResponseMessage)
assert input[2].content[0].text == "fake_tool_call_response"
# Check for new input
assert isinstance(input[3], OpenAIResponseMessage)
assert input[3].content == "fake_input"
async def test_create_openai_response_with_instructions(openai_responses_impl, mock_inference_api):
# Setup
input_text = "What is the capital of Ireland?"

View file

@ -115,18 +115,27 @@ class TestConvertResponseInputToChatMessages:
async def test_convert_function_tool_call_output(self):
input_items = [
OpenAIResponseOutputMessageFunctionToolCall(
call_id="call_123",
name="test_function",
arguments='{"param": "value"}',
),
OpenAIResponseInputFunctionToolCallOutput(
output="Tool output",
call_id="call_123",
)
),
]
result = await convert_response_input_to_chat_messages(input_items)
assert len(result) == 1
assert isinstance(result[0], OpenAIToolMessageParam)
assert result[0].content == "Tool output"
assert result[0].tool_call_id == "call_123"
assert len(result) == 2
assert isinstance(result[0], OpenAIAssistantMessageParam)
assert result[0].tool_calls[0].id == "call_123"
assert result[0].tool_calls[0].function.name == "test_function"
assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
assert isinstance(result[1], OpenAIToolMessageParam)
assert result[1].content == "Tool output"
assert result[1].tool_call_id == "call_123"
async def test_convert_function_tool_call(self):
input_items = [
@ -146,6 +155,47 @@ class TestConvertResponseInputToChatMessages:
assert result[0].tool_calls[0].function.name == "test_function"
assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
async def test_convert_function_call_ordering(self):
input_items = [
OpenAIResponseOutputMessageFunctionToolCall(
call_id="call_123",
name="test_function_a",
arguments='{"param": "value"}',
),
OpenAIResponseOutputMessageFunctionToolCall(
call_id="call_456",
name="test_function_b",
arguments='{"param": "value"}',
),
OpenAIResponseInputFunctionToolCallOutput(
output="AAA",
call_id="call_123",
),
OpenAIResponseInputFunctionToolCallOutput(
output="BBB",
call_id="call_456",
),
]
result = await convert_response_input_to_chat_messages(input_items)
assert len(result) == 4
assert isinstance(result[0], OpenAIAssistantMessageParam)
assert len(result[0].tool_calls) == 1
assert result[0].tool_calls[0].id == "call_123"
assert result[0].tool_calls[0].function.name == "test_function_a"
assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
assert isinstance(result[1], OpenAIToolMessageParam)
assert result[1].content == "AAA"
assert result[1].tool_call_id == "call_123"
assert isinstance(result[2], OpenAIAssistantMessageParam)
assert len(result[2].tool_calls) == 1
assert result[2].tool_calls[0].id == "call_456"
assert result[2].tool_calls[0].function.name == "test_function_b"
assert result[2].tool_calls[0].function.arguments == '{"param": "value"}'
assert isinstance(result[3], OpenAIToolMessageParam)
assert result[3].content == "BBB"
assert result[3].tool_call_id == "call_456"
async def test_convert_response_message(self):
input_items = [
OpenAIResponseMessage(

View file

@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""Shared fixtures for batches provider unit tests."""
import tempfile
from pathlib import Path
from unittest.mock import AsyncMock
import pytest
from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
@pytest.fixture
async def provider():
"""Create a test provider instance with temporary database."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test_batches.db"
kvstore_config = SqliteKVStoreConfig(db_path=str(db_path))
config = ReferenceBatchesImplConfig(kvstore=kvstore_config)
# Create kvstore and mock APIs
kvstore = await kvstore_impl(config.kvstore)
mock_inference = AsyncMock()
mock_files = AsyncMock()
mock_models = AsyncMock()
provider = ReferenceBatchesImpl(config, mock_inference, mock_files, mock_models, kvstore)
await provider.initialize()
# unit tests should not require background processing
provider.process_batches = False
yield provider
await provider.shutdown()
@pytest.fixture
def sample_batch_data():
"""Sample batch data for testing."""
return {
"input_file_id": "file_abc123",
"endpoint": "/v1/chat/completions",
"completion_window": "24h",
"metadata": {"test": "true", "priority": "high"},
}

View file

@ -54,60 +54,17 @@ dependencies like inference, files, and models APIs.
"""
import json
import tempfile
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import pytest
from llama_stack.apis.batches import BatchObject
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class TestReferenceBatchesImpl:
"""Test the reference implementation of the Batches API."""
@pytest.fixture
async def provider(self):
"""Create a test provider instance with temporary database."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test_batches.db"
kvstore_config = SqliteKVStoreConfig(db_path=str(db_path))
config = ReferenceBatchesImplConfig(kvstore=kvstore_config)
# Create kvstore and mock APIs
from unittest.mock import AsyncMock
from llama_stack.providers.utils.kvstore import kvstore_impl
kvstore = await kvstore_impl(config.kvstore)
mock_inference = AsyncMock()
mock_files = AsyncMock()
mock_models = AsyncMock()
provider = ReferenceBatchesImpl(config, mock_inference, mock_files, mock_models, kvstore)
await provider.initialize()
# unit tests should not require background processing
provider.process_batches = False
yield provider
await provider.shutdown()
@pytest.fixture
def sample_batch_data(self):
"""Sample batch data for testing."""
return {
"input_file_id": "file_abc123",
"endpoint": "/v1/chat/completions",
"completion_window": "24h",
"metadata": {"test": "true", "priority": "high"},
}
def _validate_batch_type(self, batch, expected_metadata=None):
"""
Helper function to validate batch object structure and field types.

View file

@ -0,0 +1,128 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Tests for idempotency functionality in the reference batches provider.
This module tests the optional idempotency feature that allows clients to provide
an idempotency key (idempotency_key) to ensure that repeated requests with the same key
and parameters return the same batch, while requests with the same key but different
parameters result in a conflict error.
Test Categories:
1. Core Idempotency: Same parameters with same key return same batch
2. Parameter Independence: Different parameters without keys create different batches
3. Conflict Detection: Same key with different parameters raises ConflictError
Tests by Category:
1. Core Idempotency:
- test_idempotent_batch_creation_same_params
- test_idempotent_batch_creation_metadata_order_independence
2. Parameter Independence:
- test_non_idempotent_behavior_without_key
- test_different_idempotency_keys_create_different_batches
3. Conflict Detection:
- test_same_idempotency_key_different_params_conflict (parametrized: input_file_id, metadata values, metadata None vs {})
Key Behaviors Tested:
- Idempotent batch creation when idempotency_key provided with identical parameters
- Metadata order independence for consistent batch ID generation
- Non-idempotent behavior when no idempotency_key provided (random UUIDs)
- Conflict detection for parameter mismatches with same idempotency key
- Deterministic ID generation based solely on idempotency key
- Proper error handling with detailed conflict messages including key and error codes
- Protection against idempotency key reuse with different request parameters
"""
import asyncio
import pytest
from llama_stack.apis.common.errors import ConflictError
class TestReferenceBatchesIdempotency:
"""Test suite for idempotency functionality in the reference implementation."""
async def test_idempotent_batch_creation_same_params(self, provider, sample_batch_data):
"""Test that creating batches with identical parameters returns the same batch when idempotency_key is provided."""
del sample_batch_data["metadata"]
batch1 = await provider.create_batch(
**sample_batch_data,
metadata={"test": "value1", "other": "value2"},
idempotency_key="unique-token-1",
)
# sleep for 1 second to allow created_at timestamps to be different
await asyncio.sleep(1)
batch2 = await provider.create_batch(
**sample_batch_data,
metadata={"other": "value2", "test": "value1"}, # Different order
idempotency_key="unique-token-1",
)
assert batch1.id == batch2.id
assert batch1.input_file_id == batch2.input_file_id
assert batch1.metadata == batch2.metadata
assert batch1.created_at == batch2.created_at
async def test_different_idempotency_keys_create_different_batches(self, provider, sample_batch_data):
"""Test that different idempotency keys create different batches even with same params."""
batch1 = await provider.create_batch(
**sample_batch_data,
idempotency_key="token-A",
)
batch2 = await provider.create_batch(
**sample_batch_data,
idempotency_key="token-B",
)
assert batch1.id != batch2.id
async def test_non_idempotent_behavior_without_key(self, provider, sample_batch_data):
"""Test that batches without idempotency key create unique batches even with identical parameters."""
batch1 = await provider.create_batch(**sample_batch_data)
batch2 = await provider.create_batch(**sample_batch_data)
assert batch1.id != batch2.id
assert batch1.input_file_id == batch2.input_file_id
assert batch1.endpoint == batch2.endpoint
assert batch1.completion_window == batch2.completion_window
assert batch1.metadata == batch2.metadata
@pytest.mark.parametrize(
"param_name,first_value,second_value",
[
("input_file_id", "file_001", "file_002"),
("metadata", {"test": "value1"}, {"test": "value2"}),
("metadata", None, {}),
],
)
async def test_same_idempotency_key_different_params_conflict(
self, provider, sample_batch_data, param_name, first_value, second_value
):
"""Test that same idempotency_key with different parameters raises conflict error."""
sample_batch_data["idempotency_key"] = "same-token"
sample_batch_data[param_name] = first_value
batch1 = await provider.create_batch(**sample_batch_data)
with pytest.raises(ConflictError, match="Idempotency key.*was previously used with different parameters"):
sample_batch_data[param_name] = second_value
await provider.create_batch(**sample_batch_data)
retrieved_batch = await provider.retrieve_batch(batch1.id)
assert retrieved_batch.id == batch1.id
assert getattr(retrieved_batch, param_name) == first_value

View file

@ -0,0 +1,251 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import patch
import boto3
import pytest
from botocore.exceptions import ClientError
from moto import mock_aws
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.files import OpenAIFilePurpose
from llama_stack.providers.remote.files.s3 import (
S3FilesImplConfig,
get_adapter_impl,
)
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
class MockUploadFile:
def __init__(self, content: bytes, filename: str, content_type: str = "text/plain"):
self.content = content
self.filename = filename
self.content_type = content_type
async def read(self):
return self.content
@pytest.fixture
def s3_config(tmp_path):
db_path = tmp_path / "s3_files_metadata.db"
return S3FilesImplConfig(
bucket_name="test-bucket",
region="not-a-region",
auto_create_bucket=True,
metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()),
)
@pytest.fixture
def s3_client():
"""Create a mocked S3 client for testing."""
# we use `with mock_aws()` because @mock_aws decorator does not support being a generator
with mock_aws():
# must yield or the mock will be reset before it is used
yield boto3.client("s3")
@pytest.fixture
async def s3_provider(s3_config, s3_client):
"""Create an S3 files provider with mocked S3 for testing."""
provider = await get_adapter_impl(s3_config, {})
yield provider
await provider.shutdown()
@pytest.fixture
def sample_text_file():
content = b"Hello, this is a test file for the S3 Files API!"
return MockUploadFile(content, "sample_text_file.txt")
class TestS3FilesImpl:
"""Test suite for S3 Files implementation."""
async def test_upload_file(self, s3_provider, sample_text_file, s3_client, s3_config):
"""Test successful file upload."""
sample_text_file.filename = "test_upload_file"
result = await s3_provider.openai_upload_file(
file=sample_text_file,
purpose=OpenAIFilePurpose.ASSISTANTS,
)
assert result.filename == sample_text_file.filename
assert result.purpose == OpenAIFilePurpose.ASSISTANTS
assert result.bytes == len(sample_text_file.content)
assert result.id.startswith("file-")
# Verify file exists in S3 backend
response = s3_client.head_object(Bucket=s3_config.bucket_name, Key=result.id)
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
async def test_list_files_empty(self, s3_provider):
"""Test listing files when no files exist."""
result = await s3_provider.openai_list_files()
assert len(result.data) == 0
assert not result.has_more
assert result.first_id == ""
assert result.last_id == ""
async def test_retrieve_file(self, s3_provider, sample_text_file):
"""Test retrieving file metadata."""
sample_text_file.filename = "test_retrieve_file"
uploaded = await s3_provider.openai_upload_file(
file=sample_text_file,
purpose=OpenAIFilePurpose.ASSISTANTS,
)
retrieved = await s3_provider.openai_retrieve_file(uploaded.id)
assert retrieved.id == uploaded.id
assert retrieved.filename == uploaded.filename
assert retrieved.purpose == uploaded.purpose
assert retrieved.bytes == uploaded.bytes
async def test_retrieve_file_content(self, s3_provider, sample_text_file):
"""Test retrieving file content."""
sample_text_file.filename = "test_retrieve_file_content"
uploaded = await s3_provider.openai_upload_file(
file=sample_text_file,
purpose=OpenAIFilePurpose.ASSISTANTS,
)
response = await s3_provider.openai_retrieve_file_content(uploaded.id)
assert response.body == sample_text_file.content
assert response.headers["Content-Disposition"] == f'attachment; filename="{sample_text_file.filename}"'
async def test_delete_file(self, s3_provider, sample_text_file, s3_config, s3_client):
"""Test deleting a file."""
sample_text_file.filename = "test_delete_file"
uploaded = await s3_provider.openai_upload_file(
file=sample_text_file,
purpose=OpenAIFilePurpose.ASSISTANTS,
)
delete_response = await s3_provider.openai_delete_file(uploaded.id)
assert delete_response.id == uploaded.id
assert delete_response.deleted is True
with pytest.raises(ResourceNotFoundError, match="not found"):
await s3_provider.openai_retrieve_file(uploaded.id)
# Verify file is gone from S3 backend
with pytest.raises(ClientError) as exc_info:
s3_client.head_object(Bucket=s3_config.bucket_name, Key=uploaded.id)
assert exc_info.value.response["Error"]["Code"] == "404"
async def test_list_files(self, s3_provider, sample_text_file):
"""Test listing files after uploading some."""
sample_text_file.filename = "test_list_files_with_content_file1"
file1 = await s3_provider.openai_upload_file(
file=sample_text_file,
purpose=OpenAIFilePurpose.ASSISTANTS,
)
file2_content = MockUploadFile(b"Second file content", "test_list_files_with_content_file2")
file2 = await s3_provider.openai_upload_file(
file=file2_content,
purpose=OpenAIFilePurpose.BATCH,
)
result = await s3_provider.openai_list_files()
assert len(result.data) == 2
file_ids = {f.id for f in result.data}
assert file1.id in file_ids
assert file2.id in file_ids
async def test_list_files_with_purpose_filter(self, s3_provider, sample_text_file):
"""Test listing files with purpose filter."""
sample_text_file.filename = "test_list_files_with_purpose_filter_file1"
file1 = await s3_provider.openai_upload_file(
file=sample_text_file,
purpose=OpenAIFilePurpose.ASSISTANTS,
)
file2_content = MockUploadFile(b"Batch file content", "test_list_files_with_purpose_filter_file2")
await s3_provider.openai_upload_file(
file=file2_content,
purpose=OpenAIFilePurpose.BATCH,
)
result = await s3_provider.openai_list_files(purpose=OpenAIFilePurpose.ASSISTANTS)
assert len(result.data) == 1
assert result.data[0].id == file1.id
assert result.data[0].purpose == OpenAIFilePurpose.ASSISTANTS
async def test_nonexistent_file_retrieval(self, s3_provider):
"""Test retrieving a non-existent file raises error."""
with pytest.raises(ResourceNotFoundError, match="not found"):
await s3_provider.openai_retrieve_file("file-nonexistent")
async def test_nonexistent_file_content_retrieval(self, s3_provider):
"""Test retrieving content of a non-existent file raises error."""
with pytest.raises(ResourceNotFoundError, match="not found"):
await s3_provider.openai_retrieve_file_content("file-nonexistent")
async def test_nonexistent_file_deletion(self, s3_provider):
"""Test deleting a non-existent file raises error."""
with pytest.raises(ResourceNotFoundError, match="not found"):
await s3_provider.openai_delete_file("file-nonexistent")
async def test_upload_file_without_filename(self, s3_provider, sample_text_file):
"""Test uploading a file without a filename uses the fallback."""
del sample_text_file.filename
result = await s3_provider.openai_upload_file(
file=sample_text_file,
purpose=OpenAIFilePurpose.ASSISTANTS,
)
assert result.purpose == OpenAIFilePurpose.ASSISTANTS
assert result.bytes == len(sample_text_file.content)
retrieved = await s3_provider.openai_retrieve_file(result.id)
assert retrieved.filename == result.filename
async def test_file_operations_when_s3_object_deleted(self, s3_provider, sample_text_file, s3_config, s3_client):
"""Test file operations when S3 object is deleted but metadata exists (negative test)."""
sample_text_file.filename = "test_orphaned_metadata"
uploaded = await s3_provider.openai_upload_file(
file=sample_text_file,
purpose=OpenAIFilePurpose.ASSISTANTS,
)
# Directly delete the S3 object from the backend
s3_client.delete_object(Bucket=s3_config.bucket_name, Key=uploaded.id)
with pytest.raises(ResourceNotFoundError, match="not found") as exc_info:
await s3_provider.openai_retrieve_file_content(uploaded.id)
assert uploaded.id in str(exc_info).lower()
listed_files = await s3_provider.openai_list_files()
assert uploaded.id not in [file.id for file in listed_files.data]
async def test_upload_file_s3_put_object_failure(self, s3_provider, sample_text_file, s3_config, s3_client):
"""Test that put_object failure results in exception and no orphaned metadata."""
sample_text_file.filename = "test_s3_put_object_failure"
def failing_put_object(*args, **kwargs):
raise ClientError(
error_response={"Error": {"Code": "SolarRadiation", "Message": "Bloop"}}, operation_name="PutObject"
)
with patch.object(s3_provider.client, "put_object", side_effect=failing_put_object):
with pytest.raises(RuntimeError, match="Failed to upload file to S3"):
await s3_provider.openai_upload_file(
file=sample_text_file,
purpose=OpenAIFilePurpose.ASSISTANTS,
)
files_list = await s3_provider.openai_list_files()
assert len(files_list.data) == 0, "No file metadata should remain after failed upload"

View file

@ -6,7 +6,7 @@
import asyncio
import json
import logging
import logging # allow-direct-logging
import threading
import time
from http.server import BaseHTTPRequestHandler, HTTPServer

View file

@ -0,0 +1,105 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.core.datatypes import CORSConfig, process_cors_config
def test_cors_config_defaults():
config = CORSConfig()
assert config.allow_origins == []
assert config.allow_origin_regex is None
assert config.allow_methods == ["OPTIONS"]
assert config.allow_headers == []
assert config.allow_credentials is False
assert config.expose_headers == []
assert config.max_age == 600
def test_cors_config_explicit_config():
config = CORSConfig(
allow_origins=["https://example.com"], allow_credentials=True, max_age=3600, allow_methods=["GET", "POST"]
)
assert config.allow_origins == ["https://example.com"]
assert config.allow_credentials is True
assert config.max_age == 3600
assert config.allow_methods == ["GET", "POST"]
def test_cors_config_regex():
config = CORSConfig(allow_origins=[], allow_origin_regex=r"https?://localhost:\d+")
assert config.allow_origins == []
assert config.allow_origin_regex == r"https?://localhost:\d+"
def test_cors_config_wildcard_credentials_error():
with pytest.raises(ValueError, match="Cannot use wildcard origins with credentials enabled"):
CORSConfig(allow_origins=["*"], allow_credentials=True)
with pytest.raises(ValueError, match="Cannot use wildcard origins with credentials enabled"):
CORSConfig(allow_origins=["https://example.com", "*"], allow_credentials=True)
def test_process_cors_config_false():
result = process_cors_config(False)
assert result is None
def test_process_cors_config_true():
result = process_cors_config(True)
assert isinstance(result, CORSConfig)
assert result.allow_origins == []
assert result.allow_origin_regex == r"https?://localhost:\d+"
assert result.allow_credentials is False
expected_methods = ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
for method in expected_methods:
assert method in result.allow_methods
def test_process_cors_config_passthrough():
original = CORSConfig(allow_origins=["https://example.com"], allow_methods=["GET"])
result = process_cors_config(original)
assert result is original
def test_process_cors_config_invalid_type():
with pytest.raises(ValueError, match="Expected bool or CORSConfig, got str"):
process_cors_config("invalid")
def test_cors_config_model_dump():
cors_config = CORSConfig(
allow_origins=["https://example.com"],
allow_methods=["GET", "POST"],
allow_headers=["Content-Type"],
allow_credentials=True,
max_age=3600,
)
config_dict = cors_config.model_dump()
assert config_dict["allow_origins"] == ["https://example.com"]
assert config_dict["allow_methods"] == ["GET", "POST"]
assert config_dict["allow_headers"] == ["Content-Type"]
assert config_dict["allow_credentials"] is True
assert config_dict["max_age"] == 3600
expected_keys = {
"allow_origins",
"allow_origin_regex",
"allow_methods",
"allow_headers",
"allow_credentials",
"expose_headers",
"max_age",
}
assert set(config_dict.keys()) == expected_keys