Merge branch 'main' into content-extension

Francisco Arceo committed via GitHub on 2025-08-25 14:22:15 -06:00
commit 3e11e1472c
334 changed files with 22841 additions and 8940 deletions


@@ -5,86 +5,121 @@
# the root directory of this source tree.
"""
Unit tests for LlamaStackAsLibraryClient initialization error handling.
Unit tests for LlamaStackAsLibraryClient automatic initialization.
These tests ensure that users get proper error messages when they forget to call
initialize() on the library client, preventing AttributeError regressions.
These tests ensure that the library client is automatically initialized
and ready to use immediately after construction.
"""
import pytest
from llama_stack.core.library_client import (
AsyncLlamaStackAsLibraryClient,
LlamaStackAsLibraryClient,
)
from llama_stack.core.server.routes import RouteImpls
class TestLlamaStackAsLibraryClientInitialization:
"""Test proper error handling for uninitialized library clients."""
class TestLlamaStackAsLibraryClientAutoInitialization:
"""Test automatic initialization of library clients."""
@pytest.mark.parametrize(
"api_call",
[
lambda client: client.models.list(),
lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
lambda client: next(
client.chat.completions.create(
model="test", messages=[{"role": "user", "content": "test"}], stream=True
)
),
],
ids=["models.list", "chat.completions.create", "chat.completions.create_stream"],
)
def test_sync_client_proper_error_without_initialization(self, api_call):
"""Test that sync client raises ValueError with helpful message when not initialized."""
client = LlamaStackAsLibraryClient("nvidia")
def test_sync_client_auto_initialization(self, monkeypatch):
"""Test that sync client is automatically initialized after construction."""
# Mock the stack construction to avoid dependency issues
mock_impls = {}
mock_route_impls = RouteImpls({})
with pytest.raises(ValueError) as exc_info:
api_call(client)
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
error_msg = str(exc_info.value)
assert "Client not initialized" in error_msg
assert "Please call initialize() first" in error_msg
def mock_initialize_route_impls(impls):
return mock_route_impls
@pytest.mark.parametrize(
"api_call",
[
lambda client: client.models.list(),
lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
],
ids=["models.list", "chat.completions.create"],
)
async def test_async_client_proper_error_without_initialization(self, api_call):
"""Test that async client raises ValueError with helpful message when not initialized."""
client = AsyncLlamaStackAsLibraryClient("nvidia")
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
with pytest.raises(ValueError) as exc_info:
await api_call(client)
client = LlamaStackAsLibraryClient("ci-tests")
error_msg = str(exc_info.value)
assert "Client not initialized" in error_msg
assert "Please call initialize() first" in error_msg
assert client.async_client.route_impls is not None
async def test_async_client_streaming_error_without_initialization(self):
"""Test that async client streaming raises ValueError with helpful message when not initialized."""
client = AsyncLlamaStackAsLibraryClient("nvidia")
async def test_async_client_auto_initialization(self, monkeypatch):
"""Test that async client can be initialized and works properly."""
# Mock the stack construction to avoid dependency issues
mock_impls = {}
mock_route_impls = RouteImpls({})
with pytest.raises(ValueError) as exc_info:
stream = await client.chat.completions.create(
model="test", messages=[{"role": "user", "content": "test"}], stream=True
)
await anext(stream)
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
error_msg = str(exc_info.value)
assert "Client not initialized" in error_msg
assert "Please call initialize() first" in error_msg
def mock_initialize_route_impls(impls):
return mock_route_impls
def test_route_impls_initialized_to_none(self):
"""Test that route_impls is initialized to None to prevent AttributeError."""
# Test sync client
sync_client = LlamaStackAsLibraryClient("nvidia")
assert sync_client.async_client.route_impls is None
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
# Test async client directly
async_client = AsyncLlamaStackAsLibraryClient("nvidia")
assert async_client.route_impls is None
client = AsyncLlamaStackAsLibraryClient("ci-tests")
# Initialize the client
result = await client.initialize()
assert result is True
assert client.route_impls is not None
def test_initialize_method_backward_compatibility(self, monkeypatch):
"""Test that initialize() method still works for backward compatibility."""
# Mock the stack construction to avoid dependency issues
mock_impls = {}
mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
def mock_initialize_route_impls(impls):
return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
client = LlamaStackAsLibraryClient("ci-tests")
result = client.initialize()
assert result is None
result2 = client.initialize()
assert result2 is None
async def test_async_initialize_method_idempotent(self, monkeypatch):
"""Test that async initialize() method can be called multiple times safely."""
mock_impls = {}
mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
def mock_initialize_route_impls(impls):
return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
client = AsyncLlamaStackAsLibraryClient("ci-tests")
result1 = await client.initialize()
assert result1 is True
result2 = await client.initialize()
assert result2 is True
def test_route_impls_automatically_set(self, monkeypatch):
"""Test that route_impls is automatically set during construction."""
mock_impls = {}
mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
def mock_initialize_route_impls(impls):
return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
sync_client = LlamaStackAsLibraryClient("ci-tests")
assert sync_client.async_client.route_impls is not None
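Taken together, these tests pin down a simple user-facing contract: constructing the client is enough to use it. A minimal usage sketch of that contract, assuming a locally resolvable distro such as the "ci-tests" one used above (illustrative only, not part of the diff):
from llama_stack.core.library_client import LlamaStackAsLibraryClient
# The stack is constructed and routes are registered during __init__,
# so no explicit initialize() call is required before the first API call.
client = LlamaStackAsLibraryClient("ci-tests")
models = client.models.list()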


@@ -7,6 +7,7 @@
import pytest
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import OpenAIFilePurpose
from llama_stack.core.access_control.access_control import default_policy
@@ -190,7 +191,7 @@ class TestOpenAIFilesAPI:
async def test_retrieve_file_not_found(self, files_provider):
"""Test retrieving a non-existent file."""
with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
with pytest.raises(ResourceNotFoundError, match="not found"):
await files_provider.openai_retrieve_file("file-nonexistent")
async def test_retrieve_file_content_success(self, files_provider, sample_text_file):
@@ -208,7 +209,7 @@ class TestOpenAIFilesAPI:
async def test_retrieve_file_content_not_found(self, files_provider):
"""Test retrieving content of a non-existent file."""
with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
with pytest.raises(ResourceNotFoundError, match="not found"):
await files_provider.openai_retrieve_file_content("file-nonexistent")
async def test_delete_file_success(self, files_provider, sample_text_file):
@@ -229,12 +230,12 @@ class TestOpenAIFilesAPI:
assert delete_response.deleted is True
# Verify file no longer exists
with pytest.raises(ValueError, match=f"File with id {uploaded_file.id} not found"):
with pytest.raises(ResourceNotFoundError, match="not found"):
await files_provider.openai_retrieve_file(uploaded_file.id)
async def test_delete_file_not_found(self, files_provider):
"""Test deleting a non-existent file."""
with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
with pytest.raises(ResourceNotFoundError, match="not found"):
await files_provider.openai_delete_file("file-nonexistent")
async def test_file_persistence_across_operations(self, files_provider, sample_text_file):


@@ -24,6 +24,7 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseMessage,
OpenAIResponseObjectWithInput,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageWebSearchToolCall,
OpenAIResponseText,
OpenAIResponseTextFormat,
@@ -41,7 +42,7 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime
from llama_stack.core.access_control.access_control import default_policy
from llama_stack.providers.inline.agents.meta_reference.openai_responses import (
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
)
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
@@ -136,9 +137,12 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
input=input_text,
model=model,
temperature=0.1,
stream=True, # Enable streaming to test content part events
)
# Verify
# For streaming response, collect all chunks
chunks = [chunk async for chunk in result]
mock_inference_api.openai_chat_completion.assert_called_once_with(
model=model,
messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
@@ -147,11 +151,32 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
stream=True,
temperature=0.1,
)
# Should have content part events for text streaming
# Expected: response.created, content_part.added, output_text.delta, content_part.done, response.completed
assert len(chunks) >= 4
assert chunks[0].type == "response.created"
# Check for content part events
content_part_added_events = [c for c in chunks if c.type == "response.content_part.added"]
content_part_done_events = [c for c in chunks if c.type == "response.content_part.done"]
text_delta_events = [c for c in chunks if c.type == "response.output_text.delta"]
assert len(content_part_added_events) >= 1, "Should have content_part.added event for text"
assert len(content_part_done_events) >= 1, "Should have content_part.done event for text"
assert len(text_delta_events) >= 1, "Should have text delta events"
# Verify final event is completion
assert chunks[-1].type == "response.completed"
# When streaming, the final response is in the last chunk
final_response = chunks[-1].response
assert final_response.model == model
assert len(final_response.output) == 1
assert isinstance(final_response.output[0], OpenAIResponseMessage)
openai_responses_impl.responses_store.store_response_object.assert_called_once()
assert result.model == model
assert len(result.output) == 1
assert isinstance(result.output[0], OpenAIResponseMessage)
assert result.output[0].content[0].text == "Dublin"
assert final_response.output[0].content[0].text == "Dublin"
async def test_create_openai_response_with_string_input_with_tools(openai_responses_impl, mock_inference_api):
@@ -272,6 +297,8 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
# Check that we got the content from our mocked tool execution result
chunks = [chunk async for chunk in result]
# Verify event types
# Should have: response.created, output_item.added, function_call_arguments.delta,
# function_call_arguments.done, output_item.done, response.completed
assert len(chunks) == 6
@@ -435,6 +462,53 @@ async def test_prepend_previous_response_web_search(openai_responses_impl, mock_
assert input[3].content == "fake_input"
async def test_prepend_previous_response_mcp_tool_call(openai_responses_impl, mock_responses_store):
"""Test prepending a previous response which included an mcp tool call to a new response."""
input_item_message = OpenAIResponseMessage(
id="123",
content=[OpenAIResponseInputMessageContentText(text="fake_previous_input")],
role="user",
)
output_tool_call = OpenAIResponseOutputMessageMCPCall(
id="ws_123",
name="fake-tool",
arguments="fake-arguments",
server_label="fake-label",
)
output_message = OpenAIResponseMessage(
id="123",
content=[OpenAIResponseOutputMessageContentOutputText(text="fake_tool_call_response")],
status="completed",
role="assistant",
)
response = OpenAIResponseObjectWithInput(
created_at=1,
id="resp_123",
model="fake_model",
output=[output_tool_call, output_message],
status="completed",
text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
input=[input_item_message],
)
mock_responses_store.get_response_object.return_value = response
input_messages = [OpenAIResponseMessage(content="fake_input", role="user")]
input = await openai_responses_impl._prepend_previous_response(input_messages, "resp_123")
assert len(input) == 4
# Check for previous input
assert isinstance(input[0], OpenAIResponseMessage)
assert input[0].content[0].text == "fake_previous_input"
# Check for previous output MCP tool call
assert isinstance(input[1], OpenAIResponseOutputMessageMCPCall)
# Check for previous output web search response
assert isinstance(input[2], OpenAIResponseMessage)
assert input[2].content[0].text == "fake_tool_call_response"
# Check for new input
assert isinstance(input[3], OpenAIResponseMessage)
assert input[3].content == "fake_input"
async def test_create_openai_response_with_instructions(openai_responses_impl, mock_inference_api):
# Setup
input_text = "What is the capital of Ireland?"


@@ -0,0 +1,342 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolWebSearch,
OpenAIResponseMessage,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseText,
OpenAIResponseTextFormat,
)
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoice,
OpenAIDeveloperMessageParam,
OpenAIResponseFormatJSONObject,
OpenAIResponseFormatJSONSchema,
OpenAIResponseFormatText,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
convert_chat_choice_to_response_message,
convert_response_content_to_chat_content,
convert_response_input_to_chat_messages,
convert_response_text_to_chat_response_format,
get_message_type_by_role,
is_function_tool_call,
)
class TestConvertChatChoiceToResponseMessage:
async def test_convert_string_content(self):
choice = OpenAIChoice(
message=OpenAIAssistantMessageParam(content="Test message"),
finish_reason="stop",
index=0,
)
result = await convert_chat_choice_to_response_message(choice)
assert result.role == "assistant"
assert result.status == "completed"
assert len(result.content) == 1
assert isinstance(result.content[0], OpenAIResponseOutputMessageContentOutputText)
assert result.content[0].text == "Test message"
async def test_convert_text_param_content(self):
choice = OpenAIChoice(
message=OpenAIAssistantMessageParam(
content=[OpenAIChatCompletionContentPartTextParam(text="Test text param")]
),
finish_reason="stop",
index=0,
)
with pytest.raises(ValueError) as exc_info:
await convert_chat_choice_to_response_message(choice)
assert "does not yet support output content type" in str(exc_info.value)
class TestConvertResponseContentToChatContent:
async def test_convert_string_content(self):
result = await convert_response_content_to_chat_content("Simple string")
assert result == "Simple string"
async def test_convert_text_content_parts(self):
content = [
OpenAIResponseInputMessageContentText(text="First part"),
OpenAIResponseOutputMessageContentOutputText(text="Second part"),
]
result = await convert_response_content_to_chat_content(content)
assert len(result) == 2
assert isinstance(result[0], OpenAIChatCompletionContentPartTextParam)
assert result[0].text == "First part"
assert isinstance(result[1], OpenAIChatCompletionContentPartTextParam)
assert result[1].text == "Second part"
async def test_convert_image_content(self):
content = [OpenAIResponseInputMessageContentImage(image_url="https://example.com/image.jpg", detail="high")]
result = await convert_response_content_to_chat_content(content)
assert len(result) == 1
assert isinstance(result[0], OpenAIChatCompletionContentPartImageParam)
assert result[0].image_url.url == "https://example.com/image.jpg"
assert result[0].image_url.detail == "high"
class TestConvertResponseInputToChatMessages:
async def test_convert_string_input(self):
result = await convert_response_input_to_chat_messages("User message")
assert len(result) == 1
assert isinstance(result[0], OpenAIUserMessageParam)
assert result[0].content == "User message"
async def test_convert_function_tool_call_output(self):
input_items = [
OpenAIResponseOutputMessageFunctionToolCall(
call_id="call_123",
name="test_function",
arguments='{"param": "value"}',
),
OpenAIResponseInputFunctionToolCallOutput(
output="Tool output",
call_id="call_123",
),
]
result = await convert_response_input_to_chat_messages(input_items)
assert len(result) == 2
assert isinstance(result[0], OpenAIAssistantMessageParam)
assert result[0].tool_calls[0].id == "call_123"
assert result[0].tool_calls[0].function.name == "test_function"
assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
assert isinstance(result[1], OpenAIToolMessageParam)
assert result[1].content == "Tool output"
assert result[1].tool_call_id == "call_123"
async def test_convert_function_tool_call(self):
input_items = [
OpenAIResponseOutputMessageFunctionToolCall(
call_id="call_456",
name="test_function",
arguments='{"param": "value"}',
)
]
result = await convert_response_input_to_chat_messages(input_items)
assert len(result) == 1
assert isinstance(result[0], OpenAIAssistantMessageParam)
assert len(result[0].tool_calls) == 1
assert result[0].tool_calls[0].id == "call_456"
assert result[0].tool_calls[0].function.name == "test_function"
assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
async def test_convert_function_call_ordering(self):
input_items = [
OpenAIResponseOutputMessageFunctionToolCall(
call_id="call_123",
name="test_function_a",
arguments='{"param": "value"}',
),
OpenAIResponseOutputMessageFunctionToolCall(
call_id="call_456",
name="test_function_b",
arguments='{"param": "value"}',
),
OpenAIResponseInputFunctionToolCallOutput(
output="AAA",
call_id="call_123",
),
OpenAIResponseInputFunctionToolCallOutput(
output="BBB",
call_id="call_456",
),
]
result = await convert_response_input_to_chat_messages(input_items)
assert len(result) == 4
assert isinstance(result[0], OpenAIAssistantMessageParam)
assert len(result[0].tool_calls) == 1
assert result[0].tool_calls[0].id == "call_123"
assert result[0].tool_calls[0].function.name == "test_function_a"
assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
assert isinstance(result[1], OpenAIToolMessageParam)
assert result[1].content == "AAA"
assert result[1].tool_call_id == "call_123"
assert isinstance(result[2], OpenAIAssistantMessageParam)
assert len(result[2].tool_calls) == 1
assert result[2].tool_calls[0].id == "call_456"
assert result[2].tool_calls[0].function.name == "test_function_b"
assert result[2].tool_calls[0].function.arguments == '{"param": "value"}'
assert isinstance(result[3], OpenAIToolMessageParam)
assert result[3].content == "BBB"
assert result[3].tool_call_id == "call_456"
async def test_convert_response_message(self):
input_items = [
OpenAIResponseMessage(
role="user",
content=[OpenAIResponseInputMessageContentText(text="User text")],
)
]
result = await convert_response_input_to_chat_messages(input_items)
assert len(result) == 1
assert isinstance(result[0], OpenAIUserMessageParam)
# Content should be converted to chat content format
assert len(result[0].content) == 1
assert result[0].content[0].text == "User text"
class TestConvertResponseTextToChatResponseFormat:
async def test_convert_text_format(self):
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
result = await convert_response_text_to_chat_response_format(text)
assert isinstance(result, OpenAIResponseFormatText)
assert result.type == "text"
async def test_convert_json_object_format(self):
text = OpenAIResponseText(format={"type": "json_object"})
result = await convert_response_text_to_chat_response_format(text)
assert isinstance(result, OpenAIResponseFormatJSONObject)
async def test_convert_json_schema_format(self):
schema_def = {"type": "object", "properties": {"test": {"type": "string"}}}
text = OpenAIResponseText(
format={
"type": "json_schema",
"name": "test_schema",
"schema": schema_def,
}
)
result = await convert_response_text_to_chat_response_format(text)
assert isinstance(result, OpenAIResponseFormatJSONSchema)
assert result.json_schema["name"] == "test_schema"
assert result.json_schema["schema"] == schema_def
async def test_default_text_format(self):
text = OpenAIResponseText()
result = await convert_response_text_to_chat_response_format(text)
assert isinstance(result, OpenAIResponseFormatText)
assert result.type == "text"
class TestGetMessageTypeByRole:
async def test_user_role(self):
result = await get_message_type_by_role("user")
assert result == OpenAIUserMessageParam
async def test_system_role(self):
result = await get_message_type_by_role("system")
assert result == OpenAISystemMessageParam
async def test_assistant_role(self):
result = await get_message_type_by_role("assistant")
assert result == OpenAIAssistantMessageParam
async def test_developer_role(self):
result = await get_message_type_by_role("developer")
assert result == OpenAIDeveloperMessageParam
async def test_unknown_role(self):
result = await get_message_type_by_role("unknown")
assert result is None
class TestIsFunctionToolCall:
def test_is_function_tool_call_true(self):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id="call_123",
function=OpenAIChatCompletionToolCallFunction(
name="test_function",
arguments="{}",
),
)
tools = [
OpenAIResponseInputToolFunction(
type="function", name="test_function", parameters={"type": "object", "properties": {}}
),
OpenAIResponseInputToolWebSearch(type="web_search"),
]
result = is_function_tool_call(tool_call, tools)
assert result is True
def test_is_function_tool_call_false_different_name(self):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id="call_123",
function=OpenAIChatCompletionToolCallFunction(
name="other_function",
arguments="{}",
),
)
tools = [
OpenAIResponseInputToolFunction(
type="function", name="test_function", parameters={"type": "object", "properties": {}}
),
]
result = is_function_tool_call(tool_call, tools)
assert result is False
def test_is_function_tool_call_false_no_function(self):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id="call_123",
function=None,
)
tools = [
OpenAIResponseInputToolFunction(
type="function", name="test_function", parameters={"type": "object", "properties": {}}
),
]
result = is_function_tool_call(tool_call, tools)
assert result is False
def test_is_function_tool_call_false_wrong_type(self):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id="call_123",
function=OpenAIChatCompletionToolCallFunction(
name="web_search",
arguments="{}",
),
)
tools = [
OpenAIResponseInputToolWebSearch(type="web_search"),
]
result = is_function_tool_call(tool_call, tools)
assert result is False


@@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""Shared fixtures for batches provider unit tests."""
import tempfile
from pathlib import Path
from unittest.mock import AsyncMock
import pytest
from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
@pytest.fixture
async def provider():
"""Create a test provider instance with temporary database."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test_batches.db"
kvstore_config = SqliteKVStoreConfig(db_path=str(db_path))
config = ReferenceBatchesImplConfig(kvstore=kvstore_config)
# Create kvstore and mock APIs
kvstore = await kvstore_impl(config.kvstore)
mock_inference = AsyncMock()
mock_files = AsyncMock()
mock_models = AsyncMock()
provider = ReferenceBatchesImpl(config, mock_inference, mock_files, mock_models, kvstore)
await provider.initialize()
# unit tests should not require background processing
provider.process_batches = False
yield provider
await provider.shutdown()
@pytest.fixture
def sample_batch_data():
"""Sample batch data for testing."""
return {
"input_file_id": "file_abc123",
"endpoint": "/v1/chat/completions",
"completion_window": "24h",
"metadata": {"test": "true", "priority": "high"},
}


@@ -0,0 +1,710 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Test suite for the reference implementation of the Batches API.
The tests are categorized and outlined below; keep this list updated:
- Batch creation with various parameters and validation:
* test_create_and_retrieve_batch_success (positive)
* test_create_batch_without_metadata (positive)
* test_create_batch_completion_window (negative)
* test_create_batch_invalid_endpoints (negative)
* test_create_batch_invalid_metadata (negative)
- Batch retrieval and error handling for non-existent batches:
* test_retrieve_batch_not_found (negative)
- Batch cancellation with proper status transitions:
* test_cancel_batch_success (positive)
* test_cancel_batch_invalid_statuses (negative)
* test_cancel_batch_not_found (negative)
- Batch listing with pagination and filtering:
* test_list_batches_empty (positive)
* test_list_batches_single_batch (positive)
* test_list_batches_multiple_batches (positive)
* test_list_batches_with_limit (positive)
* test_list_batches_with_pagination (positive)
* test_list_batches_invalid_after (negative)
- Data persistence in the underlying key-value store:
* test_kvstore_persistence (positive)
- Batch processing concurrency control:
* test_max_concurrent_batches (positive)
- Input validation testing (direct _validate_input method tests):
* test_validate_input_file_not_found (negative)
* test_validate_input_file_exists_empty_content (positive)
* test_validate_input_file_mixed_valid_invalid_json (mixed)
* test_validate_input_invalid_model (negative)
* test_validate_input_url_mismatch (negative)
* test_validate_input_multiple_errors_per_request (negative)
* test_validate_input_invalid_request_format (negative)
* test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation)
* test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation)
The tests use temporary SQLite databases for isolation and mock external
dependencies like inference, files, and models APIs.
"""
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from llama_stack.apis.batches import BatchObject
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
class TestReferenceBatchesImpl:
"""Test the reference implementation of the Batches API."""
def _validate_batch_type(self, batch, expected_metadata=None):
"""
Helper function to validate batch object structure and field types.
Note: This validates the direct BatchObject from the provider, not the
client library response which has a different structure.
Args:
batch: The BatchObject instance to validate.
expected_metadata: Optional expected metadata dictionary to validate against.
"""
assert isinstance(batch.id, str)
assert isinstance(batch.completion_window, str)
assert isinstance(batch.created_at, int)
assert isinstance(batch.endpoint, str)
assert isinstance(batch.input_file_id, str)
assert batch.object == "batch"
assert batch.status in [
"validating",
"failed",
"in_progress",
"finalizing",
"completed",
"expired",
"cancelling",
"cancelled",
]
if expected_metadata is not None:
assert batch.metadata == expected_metadata
timestamp_fields = [
"cancelled_at",
"cancelling_at",
"completed_at",
"expired_at",
"expires_at",
"failed_at",
"finalizing_at",
"in_progress_at",
]
for field in timestamp_fields:
field_value = getattr(batch, field, None)
if field_value is not None:
assert isinstance(field_value, int), f"{field} should be int or None, got {type(field_value)}"
file_id_fields = ["error_file_id", "output_file_id"]
for field in file_id_fields:
field_value = getattr(batch, field, None)
if field_value is not None:
assert isinstance(field_value, str), f"{field} should be str or None, got {type(field_value)}"
if hasattr(batch, "request_counts") and batch.request_counts is not None:
assert isinstance(batch.request_counts.completed, int), (
f"request_counts.completed should be int, got {type(batch.request_counts.completed)}"
)
assert isinstance(batch.request_counts.failed, int), (
f"request_counts.failed should be int, got {type(batch.request_counts.failed)}"
)
assert isinstance(batch.request_counts.total, int), (
f"request_counts.total should be int, got {type(batch.request_counts.total)}"
)
if hasattr(batch, "errors") and batch.errors is not None:
assert isinstance(batch.errors, dict), f"errors should be object or dict, got {type(batch.errors)}"
if hasattr(batch.errors, "data") and batch.errors.data is not None:
assert isinstance(batch.errors.data, list), (
f"errors.data should be list or None, got {type(batch.errors.data)}"
)
for i, error_item in enumerate(batch.errors.data):
assert isinstance(error_item, dict), (
f"errors.data[{i}] should be object or dict, got {type(error_item)}"
)
if hasattr(error_item, "code") and error_item.code is not None:
assert isinstance(error_item.code, str), (
f"errors.data[{i}].code should be str or None, got {type(error_item.code)}"
)
if hasattr(error_item, "line") and error_item.line is not None:
assert isinstance(error_item.line, int), (
f"errors.data[{i}].line should be int or None, got {type(error_item.line)}"
)
if hasattr(error_item, "message") and error_item.message is not None:
assert isinstance(error_item.message, str), (
f"errors.data[{i}].message should be str or None, got {type(error_item.message)}"
)
if hasattr(error_item, "param") and error_item.param is not None:
assert isinstance(error_item.param, str), (
f"errors.data[{i}].param should be str or None, got {type(error_item.param)}"
)
if hasattr(batch.errors, "object") and batch.errors.object is not None:
assert isinstance(batch.errors.object, str), (
f"errors.object should be str or None, got {type(batch.errors.object)}"
)
assert batch.errors.object == "list", f"errors.object should be 'list', got {batch.errors.object}"
async def test_create_and_retrieve_batch_success(self, provider, sample_batch_data):
"""Test successful batch creation and retrieval."""
created_batch = await provider.create_batch(**sample_batch_data)
self._validate_batch_type(created_batch, expected_metadata=sample_batch_data["metadata"])
assert created_batch.id.startswith("batch_")
assert len(created_batch.id) > 13
assert created_batch.object == "batch"
assert created_batch.endpoint == sample_batch_data["endpoint"]
assert created_batch.input_file_id == sample_batch_data["input_file_id"]
assert created_batch.completion_window == sample_batch_data["completion_window"]
assert created_batch.status == "validating"
assert created_batch.metadata == sample_batch_data["metadata"]
assert isinstance(created_batch.created_at, int)
assert created_batch.created_at > 0
retrieved_batch = await provider.retrieve_batch(created_batch.id)
self._validate_batch_type(retrieved_batch, expected_metadata=sample_batch_data["metadata"])
assert retrieved_batch.id == created_batch.id
assert retrieved_batch.input_file_id == created_batch.input_file_id
assert retrieved_batch.endpoint == created_batch.endpoint
assert retrieved_batch.status == created_batch.status
assert retrieved_batch.metadata == created_batch.metadata
async def test_create_batch_without_metadata(self, provider):
"""Test batch creation without optional metadata."""
batch = await provider.create_batch(
input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h"
)
assert batch.metadata is None
async def test_create_batch_completion_window(self, provider):
"""Test batch creation with invalid completion window."""
with pytest.raises(ValueError, match="Invalid completion_window"):
await provider.create_batch(
input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now"
)
@pytest.mark.parametrize(
"endpoint",
[
"/v1/embeddings",
"/v1/completions",
"/v1/invalid/endpoint",
"",
],
)
async def test_create_batch_invalid_endpoints(self, provider, endpoint):
"""Test batch creation with various invalid endpoints."""
with pytest.raises(ValueError, match="Invalid endpoint"):
await provider.create_batch(input_file_id="file_123", endpoint=endpoint, completion_window="24h")
async def test_create_batch_invalid_metadata(self, provider):
"""Test that batch creation fails with invalid metadata."""
with pytest.raises(ValueError, match="should be a valid string"):
await provider.create_batch(
input_file_id="file_123",
endpoint="/v1/chat/completions",
completion_window="24h",
metadata={123: "invalid_key"}, # Non-string key
)
with pytest.raises(ValueError, match="should be a valid string"):
await provider.create_batch(
input_file_id="file_123",
endpoint="/v1/chat/completions",
completion_window="24h",
metadata={"valid_key": 456}, # Non-string value
)
async def test_retrieve_batch_not_found(self, provider):
"""Test error when retrieving non-existent batch."""
with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
await provider.retrieve_batch("nonexistent_batch")
async def test_cancel_batch_success(self, provider, sample_batch_data):
"""Test successful batch cancellation."""
created_batch = await provider.create_batch(**sample_batch_data)
assert created_batch.status == "validating"
cancelled_batch = await provider.cancel_batch(created_batch.id)
assert cancelled_batch.id == created_batch.id
assert cancelled_batch.status in ["cancelling", "cancelled"]
assert isinstance(cancelled_batch.cancelling_at, int)
assert cancelled_batch.cancelling_at >= created_batch.created_at
@pytest.mark.parametrize("status", ["failed", "expired", "completed"])
async def test_cancel_batch_invalid_statuses(self, provider, sample_batch_data, status):
"""Test error when cancelling batch in final states."""
provider.process_batches = False
created_batch = await provider.create_batch(**sample_batch_data)
# directly update status in kvstore
await provider._update_batch(created_batch.id, status=status)
with pytest.raises(ConflictError, match=f"Cannot cancel batch '{created_batch.id}' with status '{status}'"):
await provider.cancel_batch(created_batch.id)
async def test_cancel_batch_not_found(self, provider):
"""Test error when cancelling non-existent batch."""
with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
await provider.cancel_batch("nonexistent_batch")
async def test_list_batches_empty(self, provider):
"""Test listing batches when none exist."""
response = await provider.list_batches()
assert response.object == "list"
assert response.data == []
assert response.first_id is None
assert response.last_id is None
assert response.has_more is False
async def test_list_batches_single_batch(self, provider, sample_batch_data):
"""Test listing batches with single batch."""
created_batch = await provider.create_batch(**sample_batch_data)
response = await provider.list_batches()
assert len(response.data) == 1
self._validate_batch_type(response.data[0], expected_metadata=sample_batch_data["metadata"])
assert response.data[0].id == created_batch.id
assert response.first_id == created_batch.id
assert response.last_id == created_batch.id
assert response.has_more is False
async def test_list_batches_multiple_batches(self, provider):
"""Test listing multiple batches."""
batches = [
await provider.create_batch(
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
)
for i in range(3)
]
response = await provider.list_batches()
assert len(response.data) == 3
batch_ids = {batch.id for batch in response.data}
expected_ids = {batch.id for batch in batches}
assert batch_ids == expected_ids
assert response.has_more is False
assert response.first_id in expected_ids
assert response.last_id in expected_ids
async def test_list_batches_with_limit(self, provider):
"""Test listing batches with limit parameter."""
batches = [
await provider.create_batch(
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
)
for i in range(3)
]
response = await provider.list_batches(limit=2)
assert len(response.data) == 2
assert response.has_more is True
assert response.first_id == response.data[0].id
assert response.last_id == response.data[1].id
batch_ids = {batch.id for batch in response.data}
expected_ids = {batch.id for batch in batches}
assert batch_ids.issubset(expected_ids)
async def test_list_batches_with_pagination(self, provider):
"""Test listing batches with pagination using 'after' parameter."""
for i in range(3):
await provider.create_batch(
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
)
# Get first page
first_page = await provider.list_batches(limit=1)
assert len(first_page.data) == 1
assert first_page.has_more is True
# Get second page using 'after'
second_page = await provider.list_batches(limit=1, after=first_page.data[0].id)
assert len(second_page.data) == 1
assert second_page.data[0].id != first_page.data[0].id
# Verify we got the next batch in order
all_batches = await provider.list_batches()
expected_second_batch_id = all_batches.data[1].id
assert second_page.data[0].id == expected_second_batch_id
async def test_list_batches_invalid_after(self, provider, sample_batch_data):
"""Test listing batches with invalid 'after' parameter."""
await provider.create_batch(**sample_batch_data)
response = await provider.list_batches(after="nonexistent_batch")
# Should return all batches (no filtering when 'after' batch not found)
assert len(response.data) == 1
async def test_kvstore_persistence(self, provider, sample_batch_data):
"""Test that batches are properly persisted in kvstore."""
batch = await provider.create_batch(**sample_batch_data)
stored_data = await provider.kvstore.get(f"batch:{batch.id}")
assert stored_data is not None
stored_batch_dict = json.loads(stored_data)
assert stored_batch_dict["id"] == batch.id
assert stored_batch_dict["input_file_id"] == sample_batch_data["input_file_id"]
async def test_validate_input_file_not_found(self, provider):
"""Test _validate_input when input file does not exist."""
provider.files_api.openai_retrieve_file = AsyncMock(side_effect=Exception("File not found"))
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions",
input_file_id="nonexistent_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == "invalid_request"
assert errors[0].message == "Cannot find file nonexistent_file."
assert errors[0].param == "input_file_id"
assert errors[0].line is None
async def test_validate_input_file_exists_empty_content(self, provider):
"""Test _validate_input when file exists but is empty."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
mock_response.body = b""
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions",
input_file_id="empty_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 0
assert len(requests) == 0
async def test_validate_input_file_mixed_valid_invalid_json(self, provider):
"""Test _validate_input when file contains valid and invalid JSON lines."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
# Line 1: valid JSON with proper body args, Line 2: invalid JSON
mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}\n{invalid json'
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions",
input_file_id="mixed_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
# Should have 1 JSON parsing error from line 2, and 1 valid request from line 1
assert len(errors) == 1
assert len(requests) == 1
assert errors[0].code == "invalid_json_line"
assert errors[0].line == 2
assert errors[0].message == "This line is not parseable as valid JSON."
assert requests[0].custom_id == "req-1"
assert requests[0].method == "POST"
assert requests[0].url == "/v1/chat/completions"
assert requests[0].body["model"] == "test-model"
assert requests[0].body["messages"] == [{"role": "user", "content": "Hello"}]
async def test_validate_input_invalid_model(self, provider):
"""Test _validate_input when file contains request with non-existent model."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "nonexistent-model", "messages": [{"role": "user", "content": "Hello"}]}}'
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
provider.models_api.get_model = AsyncMock(side_effect=Exception("Model not found"))
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions",
input_file_id="invalid_model_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == "model_not_found"
assert errors[0].line == 1
assert errors[0].message == "Model 'nonexistent-model' does not exist or is not supported"
assert errors[0].param == "body.model"
@pytest.mark.parametrize(
"param_name,param_path,error_code,error_message",
[
("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"),
("method", "method", "missing_required_parameter", "Missing required parameter: method"),
("url", "url", "missing_required_parameter", "Missing required parameter: url"),
("body", "body", "missing_required_parameter", "Missing required parameter: body"),
("model", "body.model", "invalid_request", "Model parameter is required"),
("messages", "body.messages", "invalid_request", "Messages parameter is required"),
],
)
async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message):
"""Test _validate_input when file contains request with missing required parameters."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
base_request = {
"custom_id": "req-1",
"method": "POST",
"url": "/v1/chat/completions",
"body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]},
}
# Remove the specific parameter being tested
if "." in param_path:
top_level, nested_param = param_path.split(".", 1)
del base_request[top_level][nested_param]
else:
del base_request[param_name]
mock_response.body = json.dumps(base_request).encode()
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions",
input_file_id=f"missing_{param_name}_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == error_code
assert errors[0].line == 1
assert errors[0].message == error_message
assert errors[0].param == param_path
async def test_validate_input_url_mismatch(self, provider):
"""Test _validate_input when file contains request with URL that doesn't match batch endpoint."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}'
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions", # This doesn't match the URL in the request
input_file_id="url_mismatch_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == "invalid_url"
assert errors[0].line == 1
assert errors[0].message == "URL provided for this request does not match the batch endpoint"
assert errors[0].param == "url"
async def test_validate_input_multiple_errors_per_request(self, provider):
"""Test _validate_input when a single request has multiple validation errors."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
# Request missing custom_id, has invalid URL, and missing model in body
mock_response.body = (
b'{"method": "POST", "url": "/v1/embeddings", "body": {"messages": [{"role": "user", "content": "Hello"}]}}'
)
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions", # Doesn't match /v1/embeddings in request
input_file_id="multiple_errors_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) >= 2 # At least missing custom_id and URL mismatch
assert len(requests) == 0
for error in errors:
assert error.line == 1
error_codes = {error.code for error in errors}
assert "missing_required_parameter" in error_codes # missing custom_id
assert "invalid_url" in error_codes # URL mismatch
async def test_validate_input_invalid_request_format(self, provider):
"""Test _validate_input when file contains non-object JSON (array, string, number)."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
mock_response.body = b'["not", "a", "request", "object"]'
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions",
input_file_id="invalid_format_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == "invalid_request"
assert errors[0].line == 1
assert errors[0].message == "Each line must be a JSON dictionary object"
@pytest.mark.parametrize(
"param_name,param_path,invalid_value,error_message",
[
("custom_id", "custom_id", 12345, "Custom_id must be a string"),
("url", "url", 123, "URL must be a string"),
("method", "method", ["POST"], "Method must be a string"),
("body", "body", ["not", "valid"], "Body must be a JSON dictionary object"),
("model", "body.model", 123, "Model must be a string"),
("messages", "body.messages", "invalid messages format", "Messages must be an array"),
],
)
async def test_validate_input_invalid_parameter_types(
self, provider, param_name, param_path, invalid_value, error_message
):
"""Test _validate_input when file contains request with parameters that have invalid types."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
base_request = {
"custom_id": "req-1",
"method": "POST",
"url": "/v1/chat/completions",
"body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]},
}
# Override the specific parameter with invalid value
if "." in param_path:
top_level, nested_param = param_path.split(".", 1)
base_request[top_level][nested_param] = invalid_value
else:
base_request[param_name] = invalid_value
mock_response.body = json.dumps(base_request).encode()
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions",
input_file_id=f"invalid_{param_name}_type_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == "invalid_request"
assert errors[0].line == 1
assert errors[0].message == error_message
assert errors[0].param == param_path
async def test_max_concurrent_batches(self, provider):
"""Test max_concurrent_batches configuration and concurrency control."""
import asyncio
provider._batch_semaphore = asyncio.Semaphore(2)
provider.process_batches = True # enable because we're testing background processing
active_batches = 0
async def add_and_wait(batch_id: str):
nonlocal active_batches
active_batches += 1
await asyncio.sleep(float("inf"))
# the first thing done in _process_batch is to acquire the semaphore, then call _process_batch_impl,
# so we can replace _process_batch_impl with our mock to control concurrency
provider._process_batch_impl = add_and_wait
for _ in range(3):
await provider.create_batch(
input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h"
)
await asyncio.sleep(0.042) # let tasks start
assert active_batches == 2, f"Expected 2 active batches, got {active_batches}"
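For orientation, the input-file format these validation tests exercise is JSON Lines: one request object per line with custom_id, method, url, and body, which are exactly the fields asserted above. A small sketch of producing such a file (the file name and model are placeholders, not part of the diff):
import json
request = {
    "custom_id": "req-1",
    "method": "POST",
    "url": "/v1/chat/completions",
    "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]},
}
# Each line of the batch input file is one such JSON object.
with open("batch_input.jsonl", "w") as f:
    f.write(json.dumps(request) + "\n")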


@@ -0,0 +1,128 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Tests for idempotency functionality in the reference batches provider.
This module tests the optional idempotency feature that allows clients to provide
an idempotency key (idempotency_key) to ensure that repeated requests with the same key
and parameters return the same batch, while requests with the same key but different
parameters result in a conflict error.
Test Categories:
1. Core Idempotency: Same parameters with same key return same batch
2. Parameter Independence: Different parameters without keys create different batches
3. Conflict Detection: Same key with different parameters raises ConflictError
Tests by Category:
1. Core Idempotency:
- test_idempotent_batch_creation_same_params
- test_idempotent_batch_creation_metadata_order_independence
2. Parameter Independence:
- test_non_idempotent_behavior_without_key
- test_different_idempotency_keys_create_different_batches
3. Conflict Detection:
- test_same_idempotency_key_different_params_conflict (parametrized: input_file_id, metadata values, metadata None vs {})
Key Behaviors Tested:
- Idempotent batch creation when idempotency_key provided with identical parameters
- Metadata order independence for consistent batch ID generation
- Non-idempotent behavior when no idempotency_key provided (random UUIDs)
- Conflict detection for parameter mismatches with same idempotency key
- Deterministic ID generation based solely on idempotency key
- Proper error handling with detailed conflict messages including key and error codes
- Protection against idempotency key reuse with different request parameters
"""
import asyncio
import pytest
from llama_stack.apis.common.errors import ConflictError
class TestReferenceBatchesIdempotency:
"""Test suite for idempotency functionality in the reference implementation."""
async def test_idempotent_batch_creation_same_params(self, provider, sample_batch_data):
"""Test that creating batches with identical parameters returns the same batch when idempotency_key is provided."""
del sample_batch_data["metadata"]
batch1 = await provider.create_batch(
**sample_batch_data,
metadata={"test": "value1", "other": "value2"},
idempotency_key="unique-token-1",
)
# sleep for 1 second to allow created_at timestamps to be different
await asyncio.sleep(1)
batch2 = await provider.create_batch(
**sample_batch_data,
metadata={"other": "value2", "test": "value1"}, # Different order
idempotency_key="unique-token-1",
)
assert batch1.id == batch2.id
assert batch1.input_file_id == batch2.input_file_id
assert batch1.metadata == batch2.metadata
assert batch1.created_at == batch2.created_at
async def test_different_idempotency_keys_create_different_batches(self, provider, sample_batch_data):
"""Test that different idempotency keys create different batches even with same params."""
batch1 = await provider.create_batch(
**sample_batch_data,
idempotency_key="token-A",
)
batch2 = await provider.create_batch(
**sample_batch_data,
idempotency_key="token-B",
)
assert batch1.id != batch2.id
async def test_non_idempotent_behavior_without_key(self, provider, sample_batch_data):
"""Test that batches without idempotency key create unique batches even with identical parameters."""
batch1 = await provider.create_batch(**sample_batch_data)
batch2 = await provider.create_batch(**sample_batch_data)
assert batch1.id != batch2.id
assert batch1.input_file_id == batch2.input_file_id
assert batch1.endpoint == batch2.endpoint
assert batch1.completion_window == batch2.completion_window
assert batch1.metadata == batch2.metadata
@pytest.mark.parametrize(
"param_name,first_value,second_value",
[
("input_file_id", "file_001", "file_002"),
("metadata", {"test": "value1"}, {"test": "value2"}),
("metadata", None, {}),
],
)
async def test_same_idempotency_key_different_params_conflict(
self, provider, sample_batch_data, param_name, first_value, second_value
):
"""Test that same idempotency_key with different parameters raises conflict error."""
sample_batch_data["idempotency_key"] = "same-token"
sample_batch_data[param_name] = first_value
batch1 = await provider.create_batch(**sample_batch_data)
with pytest.raises(ConflictError, match="Idempotency key.*was previously used with different parameters"):
sample_batch_data[param_name] = second_value
await provider.create_batch(**sample_batch_data)
retrieved_batch = await provider.retrieve_batch(batch1.id)
assert retrieved_batch.id == batch1.id
assert getattr(retrieved_batch, param_name) == first_value
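The contract described in the module docstring can be restated as a short usage sketch; the helper name demo_idempotency and the literal arguments are illustrative only, mirroring the fixtures in this file:
async def demo_idempotency(provider):
    # Same idempotency_key with identical parameters: the original batch is returned.
    batch1 = await provider.create_batch(
        input_file_id="file_abc123",
        endpoint="/v1/chat/completions",
        completion_window="24h",
        idempotency_key="unique-token-1",
    )
    batch2 = await provider.create_batch(
        input_file_id="file_abc123",
        endpoint="/v1/chat/completions",
        completion_window="24h",
        idempotency_key="unique-token-1",
    )
    assert batch1.id == batch2.id
    # The same key with different parameters would raise ConflictError instead.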


@@ -0,0 +1,251 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import patch
import boto3
import pytest
from botocore.exceptions import ClientError
from moto import mock_aws
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.files import OpenAIFilePurpose
from llama_stack.providers.remote.files.s3 import (
S3FilesImplConfig,
get_adapter_impl,
)
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
class MockUploadFile:
def __init__(self, content: bytes, filename: str, content_type: str = "text/plain"):
self.content = content
self.filename = filename
self.content_type = content_type
async def read(self):
return self.content
@pytest.fixture
def s3_config(tmp_path):
db_path = tmp_path / "s3_files_metadata.db"
return S3FilesImplConfig(
bucket_name="test-bucket",
region="not-a-region",
auto_create_bucket=True,
metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()),
)
@pytest.fixture
def s3_client():
"""Create a mocked S3 client for testing."""
    # use `with mock_aws()` here because the @mock_aws decorator does not work with generator fixtures
with mock_aws():
# must yield or the mock will be reset before it is used
yield boto3.client("s3")
@pytest.fixture
async def s3_provider(s3_config, s3_client):
"""Create an S3 files provider with mocked S3 for testing."""
provider = await get_adapter_impl(s3_config, {})
yield provider
await provider.shutdown()
@pytest.fixture
def sample_text_file():
content = b"Hello, this is a test file for the S3 Files API!"
return MockUploadFile(content, "sample_text_file.txt")
class TestS3FilesImpl:
"""Test suite for S3 Files implementation."""
async def test_upload_file(self, s3_provider, sample_text_file, s3_client, s3_config):
"""Test successful file upload."""
sample_text_file.filename = "test_upload_file"
result = await s3_provider.openai_upload_file(
file=sample_text_file,
purpose=OpenAIFilePurpose.ASSISTANTS,
)
assert result.filename == sample_text_file.filename
assert result.purpose == OpenAIFilePurpose.ASSISTANTS
assert result.bytes == len(sample_text_file.content)
assert result.id.startswith("file-")
# Verify file exists in S3 backend
response = s3_client.head_object(Bucket=s3_config.bucket_name, Key=result.id)
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
async def test_list_files_empty(self, s3_provider):
"""Test listing files when no files exist."""
result = await s3_provider.openai_list_files()
assert len(result.data) == 0
assert not result.has_more
assert result.first_id == ""
assert result.last_id == ""
async def test_retrieve_file(self, s3_provider, sample_text_file):
"""Test retrieving file metadata."""
sample_text_file.filename = "test_retrieve_file"
uploaded = await s3_provider.openai_upload_file(
file=sample_text_file,
purpose=OpenAIFilePurpose.ASSISTANTS,
)
retrieved = await s3_provider.openai_retrieve_file(uploaded.id)
assert retrieved.id == uploaded.id
assert retrieved.filename == uploaded.filename
assert retrieved.purpose == uploaded.purpose
assert retrieved.bytes == uploaded.bytes
async def test_retrieve_file_content(self, s3_provider, sample_text_file):
"""Test retrieving file content."""
sample_text_file.filename = "test_retrieve_file_content"
uploaded = await s3_provider.openai_upload_file(
file=sample_text_file,
purpose=OpenAIFilePurpose.ASSISTANTS,
)
response = await s3_provider.openai_retrieve_file_content(uploaded.id)
assert response.body == sample_text_file.content
assert response.headers["Content-Disposition"] == f'attachment; filename="{sample_text_file.filename}"'
async def test_delete_file(self, s3_provider, sample_text_file, s3_config, s3_client):
"""Test deleting a file."""
sample_text_file.filename = "test_delete_file"
uploaded = await s3_provider.openai_upload_file(
file=sample_text_file,
purpose=OpenAIFilePurpose.ASSISTANTS,
)
delete_response = await s3_provider.openai_delete_file(uploaded.id)
assert delete_response.id == uploaded.id
assert delete_response.deleted is True
with pytest.raises(ResourceNotFoundError, match="not found"):
await s3_provider.openai_retrieve_file(uploaded.id)
# Verify file is gone from S3 backend
with pytest.raises(ClientError) as exc_info:
s3_client.head_object(Bucket=s3_config.bucket_name, Key=uploaded.id)
assert exc_info.value.response["Error"]["Code"] == "404"
async def test_list_files(self, s3_provider, sample_text_file):
"""Test listing files after uploading some."""
sample_text_file.filename = "test_list_files_with_content_file1"
file1 = await s3_provider.openai_upload_file(
file=sample_text_file,
purpose=OpenAIFilePurpose.ASSISTANTS,
)
file2_content = MockUploadFile(b"Second file content", "test_list_files_with_content_file2")
file2 = await s3_provider.openai_upload_file(
file=file2_content,
purpose=OpenAIFilePurpose.BATCH,
)
result = await s3_provider.openai_list_files()
assert len(result.data) == 2
file_ids = {f.id for f in result.data}
assert file1.id in file_ids
assert file2.id in file_ids
async def test_list_files_with_purpose_filter(self, s3_provider, sample_text_file):
"""Test listing files with purpose filter."""
sample_text_file.filename = "test_list_files_with_purpose_filter_file1"
file1 = await s3_provider.openai_upload_file(
file=sample_text_file,
purpose=OpenAIFilePurpose.ASSISTANTS,
)
file2_content = MockUploadFile(b"Batch file content", "test_list_files_with_purpose_filter_file2")
await s3_provider.openai_upload_file(
file=file2_content,
purpose=OpenAIFilePurpose.BATCH,
)
result = await s3_provider.openai_list_files(purpose=OpenAIFilePurpose.ASSISTANTS)
assert len(result.data) == 1
assert result.data[0].id == file1.id
assert result.data[0].purpose == OpenAIFilePurpose.ASSISTANTS
async def test_nonexistent_file_retrieval(self, s3_provider):
"""Test retrieving a non-existent file raises error."""
with pytest.raises(ResourceNotFoundError, match="not found"):
await s3_provider.openai_retrieve_file("file-nonexistent")
async def test_nonexistent_file_content_retrieval(self, s3_provider):
"""Test retrieving content of a non-existent file raises error."""
with pytest.raises(ResourceNotFoundError, match="not found"):
await s3_provider.openai_retrieve_file_content("file-nonexistent")
async def test_nonexistent_file_deletion(self, s3_provider):
"""Test deleting a non-existent file raises error."""
with pytest.raises(ResourceNotFoundError, match="not found"):
await s3_provider.openai_delete_file("file-nonexistent")
async def test_upload_file_without_filename(self, s3_provider, sample_text_file):
"""Test uploading a file without a filename uses the fallback."""
del sample_text_file.filename
result = await s3_provider.openai_upload_file(
file=sample_text_file,
purpose=OpenAIFilePurpose.ASSISTANTS,
)
assert result.purpose == OpenAIFilePurpose.ASSISTANTS
assert result.bytes == len(sample_text_file.content)
retrieved = await s3_provider.openai_retrieve_file(result.id)
assert retrieved.filename == result.filename
async def test_file_operations_when_s3_object_deleted(self, s3_provider, sample_text_file, s3_config, s3_client):
"""Test file operations when S3 object is deleted but metadata exists (negative test)."""
sample_text_file.filename = "test_orphaned_metadata"
uploaded = await s3_provider.openai_upload_file(
file=sample_text_file,
purpose=OpenAIFilePurpose.ASSISTANTS,
)
# Directly delete the S3 object from the backend
s3_client.delete_object(Bucket=s3_config.bucket_name, Key=uploaded.id)
with pytest.raises(ResourceNotFoundError, match="not found") as exc_info:
await s3_provider.openai_retrieve_file_content(uploaded.id)
        assert uploaded.id in str(exc_info.value).lower()
listed_files = await s3_provider.openai_list_files()
assert uploaded.id not in [file.id for file in listed_files.data]
async def test_upload_file_s3_put_object_failure(self, s3_provider, sample_text_file, s3_config, s3_client):
"""Test that put_object failure results in exception and no orphaned metadata."""
sample_text_file.filename = "test_s3_put_object_failure"
def failing_put_object(*args, **kwargs):
raise ClientError(
error_response={"Error": {"Code": "SolarRadiation", "Message": "Bloop"}}, operation_name="PutObject"
)
with patch.object(s3_provider.client, "put_object", side_effect=failing_put_object):
with pytest.raises(RuntimeError, match="Failed to upload file to S3"):
await s3_provider.openai_upload_file(
file=sample_text_file,
purpose=OpenAIFilePurpose.ASSISTANTS,
)
files_list = await s3_provider.openai_list_files()
assert len(files_list.data) == 0, "No file metadata should remain after failed upload"
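The final test encodes an invariant worth stating explicitly: a failed put_object must not leave an orphaned metadata row behind. One way to honor it is to persist metadata only after the S3 write succeeds. The sketch below is illustrative only; store_metadata stands in for whatever sqlstore call the provider actually makes and is not a real helper.

from botocore.exceptions import ClientError


async def _upload_sketch(client, bucket: str, file_id: str, body: bytes, store_metadata) -> None:
    """Illustrative ordering: S3 write first, metadata second, so failures leave no orphan row."""
    try:
        # standard boto3 S3 call; this is what the test patches to simulate a failure
        client.put_object(Bucket=bucket, Key=file_id, Body=body)
    except ClientError as e:
        # nothing has touched the metadata store yet, so there is nothing to roll back
        raise RuntimeError("Failed to upload file to S3") from e
    await store_metadata(file_id)  # hypothetical persistence callback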

View file

@ -6,7 +6,7 @@
import asyncio
import json
import logging
import logging # allow-direct-logging
import threading
import time
from http.server import BaseHTTPRequestHandler, HTTPServer

View file

@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
convert_message_to_openai_dict_new,
openai_messages_to_messages,
)
@ -182,3 +183,42 @@ def test_user_message_accepts_images():
assert len(msg.content) == 2
assert msg.content[0].text == "Describe this image:"
assert msg.content[1].image_url.url == "http://example.com/image.jpg"
async def test_convert_message_to_openai_dict_new_user_message():
"""Test convert_message_to_openai_dict_new with UserMessage."""
message = UserMessage(content="Hello, world!", role="user")
result = await convert_message_to_openai_dict_new(message)
assert result["role"] == "user"
assert result["content"] == "Hello, world!"
async def test_convert_message_to_openai_dict_new_completion_message_with_tool_calls():
"""Test convert_message_to_openai_dict_new with CompletionMessage containing tool calls."""
message = CompletionMessage(
content="I'll help you find the weather.",
tool_calls=[
ToolCall(
call_id="call_123",
tool_name="get_weather",
arguments={"city": "Sligo"},
arguments_json='{"city": "Sligo"}',
)
],
stop_reason=StopReason.end_of_turn,
)
result = await convert_message_to_openai_dict_new(message)
# This would have failed with "Cannot instantiate typing.Union" before the fix
assert result["role"] == "assistant"
assert result["content"] == "I'll help you find the weather."
assert "tool_calls" in result
assert result["tool_calls"] is not None
assert len(result["tool_calls"]) == 1
tool_call = result["tool_calls"][0]
assert tool_call.id == "call_123"
assert tool_call.type == "function"
assert tool_call.function.name == "get_weather"
assert tool_call.function.arguments == '{"city": "Sligo"}'
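Read together, the assertions above describe the familiar OpenAI assistant-message shape. As plain JSON (the converter itself returns typed tool-call objects, hence the attribute-style assertions), the expected result is roughly:

expected_shape = {
    "role": "assistant",
    "content": "I'll help you find the weather.",
    "tool_calls": [
        {
            "id": "call_123",
            "type": "function",
            "function": {
                "name": "get_weather",
                "arguments": '{"city": "Sligo"}',
            },
        }
    ],
}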

View file

@ -0,0 +1,105 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.core.datatypes import CORSConfig, process_cors_config
def test_cors_config_defaults():
config = CORSConfig()
assert config.allow_origins == []
assert config.allow_origin_regex is None
assert config.allow_methods == ["OPTIONS"]
assert config.allow_headers == []
assert config.allow_credentials is False
assert config.expose_headers == []
assert config.max_age == 600
def test_cors_config_explicit_config():
config = CORSConfig(
allow_origins=["https://example.com"], allow_credentials=True, max_age=3600, allow_methods=["GET", "POST"]
)
assert config.allow_origins == ["https://example.com"]
assert config.allow_credentials is True
assert config.max_age == 3600
assert config.allow_methods == ["GET", "POST"]
def test_cors_config_regex():
config = CORSConfig(allow_origins=[], allow_origin_regex=r"https?://localhost:\d+")
assert config.allow_origins == []
assert config.allow_origin_regex == r"https?://localhost:\d+"
def test_cors_config_wildcard_credentials_error():
with pytest.raises(ValueError, match="Cannot use wildcard origins with credentials enabled"):
CORSConfig(allow_origins=["*"], allow_credentials=True)
with pytest.raises(ValueError, match="Cannot use wildcard origins with credentials enabled"):
CORSConfig(allow_origins=["https://example.com", "*"], allow_credentials=True)
def test_process_cors_config_false():
result = process_cors_config(False)
assert result is None
def test_process_cors_config_true():
result = process_cors_config(True)
assert isinstance(result, CORSConfig)
assert result.allow_origins == []
assert result.allow_origin_regex == r"https?://localhost:\d+"
assert result.allow_credentials is False
expected_methods = ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
for method in expected_methods:
assert method in result.allow_methods
def test_process_cors_config_passthrough():
original = CORSConfig(allow_origins=["https://example.com"], allow_methods=["GET"])
result = process_cors_config(original)
assert result is original
def test_process_cors_config_invalid_type():
with pytest.raises(ValueError, match="Expected bool or CORSConfig, got str"):
process_cors_config("invalid")
def test_cors_config_model_dump():
cors_config = CORSConfig(
allow_origins=["https://example.com"],
allow_methods=["GET", "POST"],
allow_headers=["Content-Type"],
allow_credentials=True,
max_age=3600,
)
config_dict = cors_config.model_dump()
assert config_dict["allow_origins"] == ["https://example.com"]
assert config_dict["allow_methods"] == ["GET", "POST"]
assert config_dict["allow_headers"] == ["Content-Type"]
assert config_dict["allow_credentials"] is True
assert config_dict["max_age"] == 3600
expected_keys = {
"allow_origins",
"allow_origin_regex",
"allow_methods",
"allow_headers",
"allow_credentials",
"expose_headers",
"max_age",
}
assert set(config_dict.keys()) == expected_keys
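The model_dump keys checked above line up one-to-one with Starlette's CORSMiddleware keyword arguments, so the validated config can be applied directly. A minimal sketch, assuming the server wires CORS through FastAPI/Starlette middleware:

from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware

from llama_stack.core.datatypes import process_cors_config

app = FastAPI()
cors = process_cors_config(True)  # or pass a CORSConfig / False from server config
if cors is not None:
    app.add_middleware(CORSMiddleware, **cors.model_dump())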