mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-11 21:48:36 +00:00
Merge branch 'main' into content-extension
This commit is contained in:
commit
3e11e1472c
334 changed files with 22841 additions and 8940 deletions
|
@ -5,86 +5,121 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Unit tests for LlamaStackAsLibraryClient initialization error handling.
|
||||
Unit tests for LlamaStackAsLibraryClient automatic initialization.
|
||||
|
||||
These tests ensure that users get proper error messages when they forget to call
|
||||
initialize() on the library client, preventing AttributeError regressions.
|
||||
These tests ensure that the library client is automatically initialized
|
||||
and ready to use immediately after construction.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.core.library_client import (
|
||||
AsyncLlamaStackAsLibraryClient,
|
||||
LlamaStackAsLibraryClient,
|
||||
)
|
||||
from llama_stack.core.server.routes import RouteImpls
|
||||
|
||||
|
||||
class TestLlamaStackAsLibraryClientInitialization:
|
||||
"""Test proper error handling for uninitialized library clients."""
|
||||
class TestLlamaStackAsLibraryClientAutoInitialization:
|
||||
"""Test automatic initialization of library clients."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"api_call",
|
||||
[
|
||||
lambda client: client.models.list(),
|
||||
lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
|
||||
lambda client: next(
|
||||
client.chat.completions.create(
|
||||
model="test", messages=[{"role": "user", "content": "test"}], stream=True
|
||||
)
|
||||
),
|
||||
],
|
||||
ids=["models.list", "chat.completions.create", "chat.completions.create_stream"],
|
||||
)
|
||||
def test_sync_client_proper_error_without_initialization(self, api_call):
|
||||
"""Test that sync client raises ValueError with helpful message when not initialized."""
|
||||
client = LlamaStackAsLibraryClient("nvidia")
|
||||
def test_sync_client_auto_initialization(self, monkeypatch):
|
||||
"""Test that sync client is automatically initialized after construction."""
|
||||
# Mock the stack construction to avoid dependency issues
|
||||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
api_call(client)
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "Client not initialized" in error_msg
|
||||
assert "Please call initialize() first" in error_msg
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"api_call",
|
||||
[
|
||||
lambda client: client.models.list(),
|
||||
lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
|
||||
],
|
||||
ids=["models.list", "chat.completions.create"],
|
||||
)
|
||||
async def test_async_client_proper_error_without_initialization(self, api_call):
|
||||
"""Test that async client raises ValueError with helpful message when not initialized."""
|
||||
client = AsyncLlamaStackAsLibraryClient("nvidia")
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await api_call(client)
|
||||
client = LlamaStackAsLibraryClient("ci-tests")
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "Client not initialized" in error_msg
|
||||
assert "Please call initialize() first" in error_msg
|
||||
assert client.async_client.route_impls is not None
|
||||
|
||||
async def test_async_client_streaming_error_without_initialization(self):
|
||||
"""Test that async client streaming raises ValueError with helpful message when not initialized."""
|
||||
client = AsyncLlamaStackAsLibraryClient("nvidia")
|
||||
async def test_async_client_auto_initialization(self, monkeypatch):
|
||||
"""Test that async client can be initialized and works properly."""
|
||||
# Mock the stack construction to avoid dependency issues
|
||||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
stream = await client.chat.completions.create(
|
||||
model="test", messages=[{"role": "user", "content": "test"}], stream=True
|
||||
)
|
||||
await anext(stream)
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "Client not initialized" in error_msg
|
||||
assert "Please call initialize() first" in error_msg
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
def test_route_impls_initialized_to_none(self):
|
||||
"""Test that route_impls is initialized to None to prevent AttributeError."""
|
||||
# Test sync client
|
||||
sync_client = LlamaStackAsLibraryClient("nvidia")
|
||||
assert sync_client.async_client.route_impls is None
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
# Test async client directly
|
||||
async_client = AsyncLlamaStackAsLibraryClient("nvidia")
|
||||
assert async_client.route_impls is None
|
||||
client = AsyncLlamaStackAsLibraryClient("ci-tests")
|
||||
|
||||
# Initialize the client
|
||||
result = await client.initialize()
|
||||
assert result is True
|
||||
assert client.route_impls is not None
|
||||
|
||||
def test_initialize_method_backward_compatibility(self, monkeypatch):
|
||||
"""Test that initialize() method still works for backward compatibility."""
|
||||
# Mock the stack construction to avoid dependency issues
|
||||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
client = LlamaStackAsLibraryClient("ci-tests")
|
||||
|
||||
result = client.initialize()
|
||||
assert result is None
|
||||
|
||||
result2 = client.initialize()
|
||||
assert result2 is None
|
||||
|
||||
async def test_async_initialize_method_idempotent(self, monkeypatch):
|
||||
"""Test that async initialize() method can be called multiple times safely."""
|
||||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
client = AsyncLlamaStackAsLibraryClient("ci-tests")
|
||||
|
||||
result1 = await client.initialize()
|
||||
assert result1 is True
|
||||
|
||||
result2 = await client.initialize()
|
||||
assert result2 is True
|
||||
|
||||
def test_route_impls_automatically_set(self, monkeypatch):
|
||||
"""Test that route_impls is automatically set during construction."""
|
||||
mock_impls = {}
|
||||
mock_route_impls = RouteImpls({})
|
||||
|
||||
async def mock_construct_stack(config, custom_provider_registry):
|
||||
return mock_impls
|
||||
|
||||
def mock_initialize_route_impls(impls):
|
||||
return mock_route_impls
|
||||
|
||||
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
|
||||
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
|
||||
|
||||
sync_client = LlamaStackAsLibraryClient("ci-tests")
|
||||
assert sync_client.async_client.route_impls is not None
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.common.errors import ResourceNotFoundError
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.files import OpenAIFilePurpose
|
||||
from llama_stack.core.access_control.access_control import default_policy
|
||||
|
@ -190,7 +191,7 @@ class TestOpenAIFilesAPI:
|
|||
|
||||
async def test_retrieve_file_not_found(self, files_provider):
|
||||
"""Test retrieving a non-existent file."""
|
||||
with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
|
||||
with pytest.raises(ResourceNotFoundError, match="not found"):
|
||||
await files_provider.openai_retrieve_file("file-nonexistent")
|
||||
|
||||
async def test_retrieve_file_content_success(self, files_provider, sample_text_file):
|
||||
|
@ -208,7 +209,7 @@ class TestOpenAIFilesAPI:
|
|||
|
||||
async def test_retrieve_file_content_not_found(self, files_provider):
|
||||
"""Test retrieving content of a non-existent file."""
|
||||
with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
|
||||
with pytest.raises(ResourceNotFoundError, match="not found"):
|
||||
await files_provider.openai_retrieve_file_content("file-nonexistent")
|
||||
|
||||
async def test_delete_file_success(self, files_provider, sample_text_file):
|
||||
|
@ -229,12 +230,12 @@ class TestOpenAIFilesAPI:
|
|||
assert delete_response.deleted is True
|
||||
|
||||
# Verify file no longer exists
|
||||
with pytest.raises(ValueError, match=f"File with id {uploaded_file.id} not found"):
|
||||
with pytest.raises(ResourceNotFoundError, match="not found"):
|
||||
await files_provider.openai_retrieve_file(uploaded_file.id)
|
||||
|
||||
async def test_delete_file_not_found(self, files_provider):
|
||||
"""Test deleting a non-existent file."""
|
||||
with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
|
||||
with pytest.raises(ResourceNotFoundError, match="not found"):
|
||||
await files_provider.openai_delete_file("file-nonexistent")
|
||||
|
||||
async def test_file_persistence_across_operations(self, files_provider, sample_text_file):
|
||||
|
|
|
@ -24,6 +24,7 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseMessage,
|
||||
OpenAIResponseObjectWithInput,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
OpenAIResponseText,
|
||||
OpenAIResponseTextFormat,
|
||||
|
@ -41,7 +42,7 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime
|
||||
from llama_stack.core.access_control.access_control import default_policy
|
||||
from llama_stack.providers.inline.agents.meta_reference.openai_responses import (
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
|
@ -136,9 +137,12 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
|
|||
input=input_text,
|
||||
model=model,
|
||||
temperature=0.1,
|
||||
stream=True, # Enable streaming to test content part events
|
||||
)
|
||||
|
||||
# Verify
|
||||
# For streaming response, collect all chunks
|
||||
chunks = [chunk async for chunk in result]
|
||||
|
||||
mock_inference_api.openai_chat_completion.assert_called_once_with(
|
||||
model=model,
|
||||
messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
|
||||
|
@ -147,11 +151,32 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
|
|||
stream=True,
|
||||
temperature=0.1,
|
||||
)
|
||||
|
||||
# Should have content part events for text streaming
|
||||
# Expected: response.created, content_part.added, output_text.delta, content_part.done, response.completed
|
||||
assert len(chunks) >= 4
|
||||
assert chunks[0].type == "response.created"
|
||||
|
||||
# Check for content part events
|
||||
content_part_added_events = [c for c in chunks if c.type == "response.content_part.added"]
|
||||
content_part_done_events = [c for c in chunks if c.type == "response.content_part.done"]
|
||||
text_delta_events = [c for c in chunks if c.type == "response.output_text.delta"]
|
||||
|
||||
assert len(content_part_added_events) >= 1, "Should have content_part.added event for text"
|
||||
assert len(content_part_done_events) >= 1, "Should have content_part.done event for text"
|
||||
assert len(text_delta_events) >= 1, "Should have text delta events"
|
||||
|
||||
# Verify final event is completion
|
||||
assert chunks[-1].type == "response.completed"
|
||||
|
||||
# When streaming, the final response is in the last chunk
|
||||
final_response = chunks[-1].response
|
||||
assert final_response.model == model
|
||||
assert len(final_response.output) == 1
|
||||
assert isinstance(final_response.output[0], OpenAIResponseMessage)
|
||||
|
||||
openai_responses_impl.responses_store.store_response_object.assert_called_once()
|
||||
assert result.model == model
|
||||
assert len(result.output) == 1
|
||||
assert isinstance(result.output[0], OpenAIResponseMessage)
|
||||
assert result.output[0].content[0].text == "Dublin"
|
||||
assert final_response.output[0].content[0].text == "Dublin"
|
||||
|
||||
|
||||
async def test_create_openai_response_with_string_input_with_tools(openai_responses_impl, mock_inference_api):
|
||||
|
@ -272,6 +297,8 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
|
|||
|
||||
# Check that we got the content from our mocked tool execution result
|
||||
chunks = [chunk async for chunk in result]
|
||||
|
||||
# Verify event types
|
||||
# Should have: response.created, output_item.added, function_call_arguments.delta,
|
||||
# function_call_arguments.done, output_item.done, response.completed
|
||||
assert len(chunks) == 6
|
||||
|
@ -435,6 +462,53 @@ async def test_prepend_previous_response_web_search(openai_responses_impl, mock_
|
|||
assert input[3].content == "fake_input"
|
||||
|
||||
|
||||
async def test_prepend_previous_response_mcp_tool_call(openai_responses_impl, mock_responses_store):
|
||||
"""Test prepending a previous response which included an mcp tool call to a new response."""
|
||||
input_item_message = OpenAIResponseMessage(
|
||||
id="123",
|
||||
content=[OpenAIResponseInputMessageContentText(text="fake_previous_input")],
|
||||
role="user",
|
||||
)
|
||||
output_tool_call = OpenAIResponseOutputMessageMCPCall(
|
||||
id="ws_123",
|
||||
name="fake-tool",
|
||||
arguments="fake-arguments",
|
||||
server_label="fake-label",
|
||||
)
|
||||
output_message = OpenAIResponseMessage(
|
||||
id="123",
|
||||
content=[OpenAIResponseOutputMessageContentOutputText(text="fake_tool_call_response")],
|
||||
status="completed",
|
||||
role="assistant",
|
||||
)
|
||||
response = OpenAIResponseObjectWithInput(
|
||||
created_at=1,
|
||||
id="resp_123",
|
||||
model="fake_model",
|
||||
output=[output_tool_call, output_message],
|
||||
status="completed",
|
||||
text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
|
||||
input=[input_item_message],
|
||||
)
|
||||
mock_responses_store.get_response_object.return_value = response
|
||||
|
||||
input_messages = [OpenAIResponseMessage(content="fake_input", role="user")]
|
||||
input = await openai_responses_impl._prepend_previous_response(input_messages, "resp_123")
|
||||
|
||||
assert len(input) == 4
|
||||
# Check for previous input
|
||||
assert isinstance(input[0], OpenAIResponseMessage)
|
||||
assert input[0].content[0].text == "fake_previous_input"
|
||||
# Check for previous output MCP tool call
|
||||
assert isinstance(input[1], OpenAIResponseOutputMessageMCPCall)
|
||||
# Check for previous output web search response
|
||||
assert isinstance(input[2], OpenAIResponseMessage)
|
||||
assert input[2].content[0].text == "fake_tool_call_response"
|
||||
# Check for new input
|
||||
assert isinstance(input[3], OpenAIResponseMessage)
|
||||
assert input[3].content == "fake_input"
|
||||
|
||||
|
||||
async def test_create_openai_response_with_instructions(openai_responses_impl, mock_inference_api):
|
||||
# Setup
|
||||
input_text = "What is the capital of Ireland?"
|
||||
|
|
|
@ -0,0 +1,342 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInputFunctionToolCallOutput,
|
||||
OpenAIResponseInputMessageContentImage,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolWebSearch,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseText,
|
||||
OpenAIResponseTextFormat,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChoice,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIResponseFormatJSONObject,
|
||||
OpenAIResponseFormatJSONSchema,
|
||||
OpenAIResponseFormatText,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
|
||||
convert_chat_choice_to_response_message,
|
||||
convert_response_content_to_chat_content,
|
||||
convert_response_input_to_chat_messages,
|
||||
convert_response_text_to_chat_response_format,
|
||||
get_message_type_by_role,
|
||||
is_function_tool_call,
|
||||
)
|
||||
|
||||
|
||||
class TestConvertChatChoiceToResponseMessage:
|
||||
async def test_convert_string_content(self):
|
||||
choice = OpenAIChoice(
|
||||
message=OpenAIAssistantMessageParam(content="Test message"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
|
||||
result = await convert_chat_choice_to_response_message(choice)
|
||||
|
||||
assert result.role == "assistant"
|
||||
assert result.status == "completed"
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], OpenAIResponseOutputMessageContentOutputText)
|
||||
assert result.content[0].text == "Test message"
|
||||
|
||||
async def test_convert_text_param_content(self):
|
||||
choice = OpenAIChoice(
|
||||
message=OpenAIAssistantMessageParam(
|
||||
content=[OpenAIChatCompletionContentPartTextParam(text="Test text param")]
|
||||
),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await convert_chat_choice_to_response_message(choice)
|
||||
|
||||
assert "does not yet support output content type" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestConvertResponseContentToChatContent:
|
||||
async def test_convert_string_content(self):
|
||||
result = await convert_response_content_to_chat_content("Simple string")
|
||||
assert result == "Simple string"
|
||||
|
||||
async def test_convert_text_content_parts(self):
|
||||
content = [
|
||||
OpenAIResponseInputMessageContentText(text="First part"),
|
||||
OpenAIResponseOutputMessageContentOutputText(text="Second part"),
|
||||
]
|
||||
|
||||
result = await convert_response_content_to_chat_content(content)
|
||||
|
||||
assert len(result) == 2
|
||||
assert isinstance(result[0], OpenAIChatCompletionContentPartTextParam)
|
||||
assert result[0].text == "First part"
|
||||
assert isinstance(result[1], OpenAIChatCompletionContentPartTextParam)
|
||||
assert result[1].text == "Second part"
|
||||
|
||||
async def test_convert_image_content(self):
|
||||
content = [OpenAIResponseInputMessageContentImage(image_url="https://example.com/image.jpg", detail="high")]
|
||||
|
||||
result = await convert_response_content_to_chat_content(content)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], OpenAIChatCompletionContentPartImageParam)
|
||||
assert result[0].image_url.url == "https://example.com/image.jpg"
|
||||
assert result[0].image_url.detail == "high"
|
||||
|
||||
|
||||
class TestConvertResponseInputToChatMessages:
|
||||
async def test_convert_string_input(self):
|
||||
result = await convert_response_input_to_chat_messages("User message")
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], OpenAIUserMessageParam)
|
||||
assert result[0].content == "User message"
|
||||
|
||||
async def test_convert_function_tool_call_output(self):
|
||||
input_items = [
|
||||
OpenAIResponseOutputMessageFunctionToolCall(
|
||||
call_id="call_123",
|
||||
name="test_function",
|
||||
arguments='{"param": "value"}',
|
||||
),
|
||||
OpenAIResponseInputFunctionToolCallOutput(
|
||||
output="Tool output",
|
||||
call_id="call_123",
|
||||
),
|
||||
]
|
||||
|
||||
result = await convert_response_input_to_chat_messages(input_items)
|
||||
|
||||
assert len(result) == 2
|
||||
assert isinstance(result[0], OpenAIAssistantMessageParam)
|
||||
assert result[0].tool_calls[0].id == "call_123"
|
||||
assert result[0].tool_calls[0].function.name == "test_function"
|
||||
assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
|
||||
assert isinstance(result[1], OpenAIToolMessageParam)
|
||||
assert result[1].content == "Tool output"
|
||||
assert result[1].tool_call_id == "call_123"
|
||||
|
||||
async def test_convert_function_tool_call(self):
|
||||
input_items = [
|
||||
OpenAIResponseOutputMessageFunctionToolCall(
|
||||
call_id="call_456",
|
||||
name="test_function",
|
||||
arguments='{"param": "value"}',
|
||||
)
|
||||
]
|
||||
|
||||
result = await convert_response_input_to_chat_messages(input_items)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], OpenAIAssistantMessageParam)
|
||||
assert len(result[0].tool_calls) == 1
|
||||
assert result[0].tool_calls[0].id == "call_456"
|
||||
assert result[0].tool_calls[0].function.name == "test_function"
|
||||
assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
|
||||
|
||||
async def test_convert_function_call_ordering(self):
|
||||
input_items = [
|
||||
OpenAIResponseOutputMessageFunctionToolCall(
|
||||
call_id="call_123",
|
||||
name="test_function_a",
|
||||
arguments='{"param": "value"}',
|
||||
),
|
||||
OpenAIResponseOutputMessageFunctionToolCall(
|
||||
call_id="call_456",
|
||||
name="test_function_b",
|
||||
arguments='{"param": "value"}',
|
||||
),
|
||||
OpenAIResponseInputFunctionToolCallOutput(
|
||||
output="AAA",
|
||||
call_id="call_123",
|
||||
),
|
||||
OpenAIResponseInputFunctionToolCallOutput(
|
||||
output="BBB",
|
||||
call_id="call_456",
|
||||
),
|
||||
]
|
||||
|
||||
result = await convert_response_input_to_chat_messages(input_items)
|
||||
assert len(result) == 4
|
||||
assert isinstance(result[0], OpenAIAssistantMessageParam)
|
||||
assert len(result[0].tool_calls) == 1
|
||||
assert result[0].tool_calls[0].id == "call_123"
|
||||
assert result[0].tool_calls[0].function.name == "test_function_a"
|
||||
assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
|
||||
assert isinstance(result[1], OpenAIToolMessageParam)
|
||||
assert result[1].content == "AAA"
|
||||
assert result[1].tool_call_id == "call_123"
|
||||
assert isinstance(result[2], OpenAIAssistantMessageParam)
|
||||
assert len(result[2].tool_calls) == 1
|
||||
assert result[2].tool_calls[0].id == "call_456"
|
||||
assert result[2].tool_calls[0].function.name == "test_function_b"
|
||||
assert result[2].tool_calls[0].function.arguments == '{"param": "value"}'
|
||||
assert isinstance(result[3], OpenAIToolMessageParam)
|
||||
assert result[3].content == "BBB"
|
||||
assert result[3].tool_call_id == "call_456"
|
||||
|
||||
async def test_convert_response_message(self):
|
||||
input_items = [
|
||||
OpenAIResponseMessage(
|
||||
role="user",
|
||||
content=[OpenAIResponseInputMessageContentText(text="User text")],
|
||||
)
|
||||
]
|
||||
|
||||
result = await convert_response_input_to_chat_messages(input_items)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], OpenAIUserMessageParam)
|
||||
# Content should be converted to chat content format
|
||||
assert len(result[0].content) == 1
|
||||
assert result[0].content[0].text == "User text"
|
||||
|
||||
|
||||
class TestConvertResponseTextToChatResponseFormat:
|
||||
async def test_convert_text_format(self):
|
||||
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
|
||||
result = await convert_response_text_to_chat_response_format(text)
|
||||
|
||||
assert isinstance(result, OpenAIResponseFormatText)
|
||||
assert result.type == "text"
|
||||
|
||||
async def test_convert_json_object_format(self):
|
||||
text = OpenAIResponseText(format={"type": "json_object"})
|
||||
result = await convert_response_text_to_chat_response_format(text)
|
||||
|
||||
assert isinstance(result, OpenAIResponseFormatJSONObject)
|
||||
|
||||
async def test_convert_json_schema_format(self):
|
||||
schema_def = {"type": "object", "properties": {"test": {"type": "string"}}}
|
||||
text = OpenAIResponseText(
|
||||
format={
|
||||
"type": "json_schema",
|
||||
"name": "test_schema",
|
||||
"schema": schema_def,
|
||||
}
|
||||
)
|
||||
result = await convert_response_text_to_chat_response_format(text)
|
||||
|
||||
assert isinstance(result, OpenAIResponseFormatJSONSchema)
|
||||
assert result.json_schema["name"] == "test_schema"
|
||||
assert result.json_schema["schema"] == schema_def
|
||||
|
||||
async def test_default_text_format(self):
|
||||
text = OpenAIResponseText()
|
||||
result = await convert_response_text_to_chat_response_format(text)
|
||||
|
||||
assert isinstance(result, OpenAIResponseFormatText)
|
||||
assert result.type == "text"
|
||||
|
||||
|
||||
class TestGetMessageTypeByRole:
|
||||
async def test_user_role(self):
|
||||
result = await get_message_type_by_role("user")
|
||||
assert result == OpenAIUserMessageParam
|
||||
|
||||
async def test_system_role(self):
|
||||
result = await get_message_type_by_role("system")
|
||||
assert result == OpenAISystemMessageParam
|
||||
|
||||
async def test_assistant_role(self):
|
||||
result = await get_message_type_by_role("assistant")
|
||||
assert result == OpenAIAssistantMessageParam
|
||||
|
||||
async def test_developer_role(self):
|
||||
result = await get_message_type_by_role("developer")
|
||||
assert result == OpenAIDeveloperMessageParam
|
||||
|
||||
async def test_unknown_role(self):
|
||||
result = await get_message_type_by_role("unknown")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestIsFunctionToolCall:
|
||||
def test_is_function_tool_call_true(self):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
id="call_123",
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name="test_function",
|
||||
arguments="{}",
|
||||
),
|
||||
)
|
||||
tools = [
|
||||
OpenAIResponseInputToolFunction(
|
||||
type="function", name="test_function", parameters={"type": "object", "properties": {}}
|
||||
),
|
||||
OpenAIResponseInputToolWebSearch(type="web_search"),
|
||||
]
|
||||
|
||||
result = is_function_tool_call(tool_call, tools)
|
||||
assert result is True
|
||||
|
||||
def test_is_function_tool_call_false_different_name(self):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
id="call_123",
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name="other_function",
|
||||
arguments="{}",
|
||||
),
|
||||
)
|
||||
tools = [
|
||||
OpenAIResponseInputToolFunction(
|
||||
type="function", name="test_function", parameters={"type": "object", "properties": {}}
|
||||
),
|
||||
]
|
||||
|
||||
result = is_function_tool_call(tool_call, tools)
|
||||
assert result is False
|
||||
|
||||
def test_is_function_tool_call_false_no_function(self):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
id="call_123",
|
||||
function=None,
|
||||
)
|
||||
tools = [
|
||||
OpenAIResponseInputToolFunction(
|
||||
type="function", name="test_function", parameters={"type": "object", "properties": {}}
|
||||
),
|
||||
]
|
||||
|
||||
result = is_function_tool_call(tool_call, tools)
|
||||
assert result is False
|
||||
|
||||
def test_is_function_tool_call_false_wrong_type(self):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
id="call_123",
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name="web_search",
|
||||
arguments="{}",
|
||||
),
|
||||
)
|
||||
tools = [
|
||||
OpenAIResponseInputToolWebSearch(type="web_search"),
|
||||
]
|
||||
|
||||
result = is_function_tool_call(tool_call, tools)
|
||||
assert result is False
|
54
tests/unit/providers/batches/conftest.py
Normal file
54
tests/unit/providers/batches/conftest.py
Normal file
|
@ -0,0 +1,54 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Shared fixtures for batches provider unit tests."""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
|
||||
from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def provider():
|
||||
"""Create a test provider instance with temporary database."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = Path(tmpdir) / "test_batches.db"
|
||||
kvstore_config = SqliteKVStoreConfig(db_path=str(db_path))
|
||||
config = ReferenceBatchesImplConfig(kvstore=kvstore_config)
|
||||
|
||||
# Create kvstore and mock APIs
|
||||
kvstore = await kvstore_impl(config.kvstore)
|
||||
mock_inference = AsyncMock()
|
||||
mock_files = AsyncMock()
|
||||
mock_models = AsyncMock()
|
||||
|
||||
provider = ReferenceBatchesImpl(config, mock_inference, mock_files, mock_models, kvstore)
|
||||
await provider.initialize()
|
||||
|
||||
# unit tests should not require background processing
|
||||
provider.process_batches = False
|
||||
|
||||
yield provider
|
||||
|
||||
await provider.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_batch_data():
|
||||
"""Sample batch data for testing."""
|
||||
return {
|
||||
"input_file_id": "file_abc123",
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"completion_window": "24h",
|
||||
"metadata": {"test": "true", "priority": "high"},
|
||||
}
|
710
tests/unit/providers/batches/test_reference.py
Normal file
710
tests/unit/providers/batches/test_reference.py
Normal file
|
@ -0,0 +1,710 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Test suite for the reference implementation of the Batches API.
|
||||
|
||||
The tests are categorized and outlined below, keep this updated:
|
||||
|
||||
- Batch creation with various parameters and validation:
|
||||
* test_create_and_retrieve_batch_success (positive)
|
||||
* test_create_batch_without_metadata (positive)
|
||||
* test_create_batch_completion_window (negative)
|
||||
* test_create_batch_invalid_endpoints (negative)
|
||||
* test_create_batch_invalid_metadata (negative)
|
||||
|
||||
- Batch retrieval and error handling for non-existent batches:
|
||||
* test_retrieve_batch_not_found (negative)
|
||||
|
||||
- Batch cancellation with proper status transitions:
|
||||
* test_cancel_batch_success (positive)
|
||||
* test_cancel_batch_invalid_statuses (negative)
|
||||
* test_cancel_batch_not_found (negative)
|
||||
|
||||
- Batch listing with pagination and filtering:
|
||||
* test_list_batches_empty (positive)
|
||||
* test_list_batches_single_batch (positive)
|
||||
* test_list_batches_multiple_batches (positive)
|
||||
* test_list_batches_with_limit (positive)
|
||||
* test_list_batches_with_pagination (positive)
|
||||
* test_list_batches_invalid_after (negative)
|
||||
|
||||
- Data persistence in the underlying key-value store:
|
||||
* test_kvstore_persistence (positive)
|
||||
|
||||
- Batch processing concurrency control:
|
||||
* test_max_concurrent_batches (positive)
|
||||
|
||||
- Input validation testing (direct _validate_input method tests):
|
||||
* test_validate_input_file_not_found (negative)
|
||||
* test_validate_input_file_exists_empty_content (positive)
|
||||
* test_validate_input_file_mixed_valid_invalid_json (mixed)
|
||||
* test_validate_input_invalid_model (negative)
|
||||
* test_validate_input_url_mismatch (negative)
|
||||
* test_validate_input_multiple_errors_per_request (negative)
|
||||
* test_validate_input_invalid_request_format (negative)
|
||||
* test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation)
|
||||
* test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation)
|
||||
|
||||
The tests use temporary SQLite databases for isolation and mock external
|
||||
dependencies like inference, files, and models APIs.
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.batches import BatchObject
|
||||
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
|
||||
|
||||
|
||||
class TestReferenceBatchesImpl:
|
||||
"""Test the reference implementation of the Batches API."""
|
||||
|
||||
def _validate_batch_type(self, batch, expected_metadata=None):
|
||||
"""
|
||||
Helper function to validate batch object structure and field types.
|
||||
|
||||
Note: This validates the direct BatchObject from the provider, not the
|
||||
client library response which has a different structure.
|
||||
|
||||
Args:
|
||||
batch: The BatchObject instance to validate.
|
||||
expected_metadata: Optional expected metadata dictionary to validate against.
|
||||
"""
|
||||
assert isinstance(batch.id, str)
|
||||
assert isinstance(batch.completion_window, str)
|
||||
assert isinstance(batch.created_at, int)
|
||||
assert isinstance(batch.endpoint, str)
|
||||
assert isinstance(batch.input_file_id, str)
|
||||
assert batch.object == "batch"
|
||||
assert batch.status in [
|
||||
"validating",
|
||||
"failed",
|
||||
"in_progress",
|
||||
"finalizing",
|
||||
"completed",
|
||||
"expired",
|
||||
"cancelling",
|
||||
"cancelled",
|
||||
]
|
||||
|
||||
if expected_metadata is not None:
|
||||
assert batch.metadata == expected_metadata
|
||||
|
||||
timestamp_fields = [
|
||||
"cancelled_at",
|
||||
"cancelling_at",
|
||||
"completed_at",
|
||||
"expired_at",
|
||||
"expires_at",
|
||||
"failed_at",
|
||||
"finalizing_at",
|
||||
"in_progress_at",
|
||||
]
|
||||
for field in timestamp_fields:
|
||||
field_value = getattr(batch, field, None)
|
||||
if field_value is not None:
|
||||
assert isinstance(field_value, int), f"{field} should be int or None, got {type(field_value)}"
|
||||
|
||||
file_id_fields = ["error_file_id", "output_file_id"]
|
||||
for field in file_id_fields:
|
||||
field_value = getattr(batch, field, None)
|
||||
if field_value is not None:
|
||||
assert isinstance(field_value, str), f"{field} should be str or None, got {type(field_value)}"
|
||||
|
||||
if hasattr(batch, "request_counts") and batch.request_counts is not None:
|
||||
assert isinstance(batch.request_counts.completed, int), (
|
||||
f"request_counts.completed should be int, got {type(batch.request_counts.completed)}"
|
||||
)
|
||||
assert isinstance(batch.request_counts.failed, int), (
|
||||
f"request_counts.failed should be int, got {type(batch.request_counts.failed)}"
|
||||
)
|
||||
assert isinstance(batch.request_counts.total, int), (
|
||||
f"request_counts.total should be int, got {type(batch.request_counts.total)}"
|
||||
)
|
||||
|
||||
if hasattr(batch, "errors") and batch.errors is not None:
|
||||
assert isinstance(batch.errors, dict), f"errors should be object or dict, got {type(batch.errors)}"
|
||||
|
||||
if hasattr(batch.errors, "data") and batch.errors.data is not None:
|
||||
assert isinstance(batch.errors.data, list), (
|
||||
f"errors.data should be list or None, got {type(batch.errors.data)}"
|
||||
)
|
||||
|
||||
for i, error_item in enumerate(batch.errors.data):
|
||||
assert isinstance(error_item, dict), (
|
||||
f"errors.data[{i}] should be object or dict, got {type(error_item)}"
|
||||
)
|
||||
|
||||
if hasattr(error_item, "code") and error_item.code is not None:
|
||||
assert isinstance(error_item.code, str), (
|
||||
f"errors.data[{i}].code should be str or None, got {type(error_item.code)}"
|
||||
)
|
||||
|
||||
if hasattr(error_item, "line") and error_item.line is not None:
|
||||
assert isinstance(error_item.line, int), (
|
||||
f"errors.data[{i}].line should be int or None, got {type(error_item.line)}"
|
||||
)
|
||||
|
||||
if hasattr(error_item, "message") and error_item.message is not None:
|
||||
assert isinstance(error_item.message, str), (
|
||||
f"errors.data[{i}].message should be str or None, got {type(error_item.message)}"
|
||||
)
|
||||
|
||||
if hasattr(error_item, "param") and error_item.param is not None:
|
||||
assert isinstance(error_item.param, str), (
|
||||
f"errors.data[{i}].param should be str or None, got {type(error_item.param)}"
|
||||
)
|
||||
|
||||
if hasattr(batch.errors, "object") and batch.errors.object is not None:
|
||||
assert isinstance(batch.errors.object, str), (
|
||||
f"errors.object should be str or None, got {type(batch.errors.object)}"
|
||||
)
|
||||
assert batch.errors.object == "list", f"errors.object should be 'list', got {batch.errors.object}"
|
||||
|
||||
async def test_create_and_retrieve_batch_success(self, provider, sample_batch_data):
|
||||
"""Test successful batch creation and retrieval."""
|
||||
created_batch = await provider.create_batch(**sample_batch_data)
|
||||
|
||||
self._validate_batch_type(created_batch, expected_metadata=sample_batch_data["metadata"])
|
||||
|
||||
assert created_batch.id.startswith("batch_")
|
||||
assert len(created_batch.id) > 13
|
||||
assert created_batch.object == "batch"
|
||||
assert created_batch.endpoint == sample_batch_data["endpoint"]
|
||||
assert created_batch.input_file_id == sample_batch_data["input_file_id"]
|
||||
assert created_batch.completion_window == sample_batch_data["completion_window"]
|
||||
assert created_batch.status == "validating"
|
||||
assert created_batch.metadata == sample_batch_data["metadata"]
|
||||
assert isinstance(created_batch.created_at, int)
|
||||
assert created_batch.created_at > 0
|
||||
|
||||
retrieved_batch = await provider.retrieve_batch(created_batch.id)
|
||||
|
||||
self._validate_batch_type(retrieved_batch, expected_metadata=sample_batch_data["metadata"])
|
||||
|
||||
assert retrieved_batch.id == created_batch.id
|
||||
assert retrieved_batch.input_file_id == created_batch.input_file_id
|
||||
assert retrieved_batch.endpoint == created_batch.endpoint
|
||||
assert retrieved_batch.status == created_batch.status
|
||||
assert retrieved_batch.metadata == created_batch.metadata
|
||||
|
||||
async def test_create_batch_without_metadata(self, provider):
|
||||
"""Test batch creation without optional metadata."""
|
||||
batch = await provider.create_batch(
|
||||
input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
)
|
||||
|
||||
assert batch.metadata is None
|
||||
|
||||
async def test_create_batch_completion_window(self, provider):
|
||||
"""Test batch creation with invalid completion window."""
|
||||
with pytest.raises(ValueError, match="Invalid completion_window"):
|
||||
await provider.create_batch(
|
||||
input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"endpoint",
|
||||
[
|
||||
"/v1/embeddings",
|
||||
"/v1/completions",
|
||||
"/v1/invalid/endpoint",
|
||||
"",
|
||||
],
|
||||
)
|
||||
async def test_create_batch_invalid_endpoints(self, provider, endpoint):
|
||||
"""Test batch creation with various invalid endpoints."""
|
||||
with pytest.raises(ValueError, match="Invalid endpoint"):
|
||||
await provider.create_batch(input_file_id="file_123", endpoint=endpoint, completion_window="24h")
|
||||
|
||||
async def test_create_batch_invalid_metadata(self, provider):
|
||||
"""Test that batch creation fails with invalid metadata."""
|
||||
with pytest.raises(ValueError, match="should be a valid string"):
|
||||
await provider.create_batch(
|
||||
input_file_id="file_123",
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={123: "invalid_key"}, # Non-string key
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="should be a valid string"):
|
||||
await provider.create_batch(
|
||||
input_file_id="file_123",
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={"valid_key": 456}, # Non-string value
|
||||
)
|
||||
|
||||
async def test_retrieve_batch_not_found(self, provider):
|
||||
"""Test error when retrieving non-existent batch."""
|
||||
with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
|
||||
await provider.retrieve_batch("nonexistent_batch")
|
||||
|
||||
async def test_cancel_batch_success(self, provider, sample_batch_data):
|
||||
"""Test successful batch cancellation."""
|
||||
created_batch = await provider.create_batch(**sample_batch_data)
|
||||
assert created_batch.status == "validating"
|
||||
|
||||
cancelled_batch = await provider.cancel_batch(created_batch.id)
|
||||
|
||||
assert cancelled_batch.id == created_batch.id
|
||||
assert cancelled_batch.status in ["cancelling", "cancelled"]
|
||||
assert isinstance(cancelled_batch.cancelling_at, int)
|
||||
assert cancelled_batch.cancelling_at >= created_batch.created_at
|
||||
|
||||
@pytest.mark.parametrize("status", ["failed", "expired", "completed"])
|
||||
async def test_cancel_batch_invalid_statuses(self, provider, sample_batch_data, status):
|
||||
"""Test error when cancelling batch in final states."""
|
||||
provider.process_batches = False
|
||||
created_batch = await provider.create_batch(**sample_batch_data)
|
||||
|
||||
# directly update status in kvstore
|
||||
await provider._update_batch(created_batch.id, status=status)
|
||||
|
||||
with pytest.raises(ConflictError, match=f"Cannot cancel batch '{created_batch.id}' with status '{status}'"):
|
||||
await provider.cancel_batch(created_batch.id)
|
||||
|
||||
async def test_cancel_batch_not_found(self, provider):
|
||||
"""Test error when cancelling non-existent batch."""
|
||||
with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
|
||||
await provider.cancel_batch("nonexistent_batch")
|
||||
|
||||
async def test_list_batches_empty(self, provider):
|
||||
"""Test listing batches when none exist."""
|
||||
response = await provider.list_batches()
|
||||
|
||||
assert response.object == "list"
|
||||
assert response.data == []
|
||||
assert response.first_id is None
|
||||
assert response.last_id is None
|
||||
assert response.has_more is False
|
||||
|
||||
async def test_list_batches_single_batch(self, provider, sample_batch_data):
|
||||
"""Test listing batches with single batch."""
|
||||
created_batch = await provider.create_batch(**sample_batch_data)
|
||||
|
||||
response = await provider.list_batches()
|
||||
|
||||
assert len(response.data) == 1
|
||||
self._validate_batch_type(response.data[0], expected_metadata=sample_batch_data["metadata"])
|
||||
assert response.data[0].id == created_batch.id
|
||||
assert response.first_id == created_batch.id
|
||||
assert response.last_id == created_batch.id
|
||||
assert response.has_more is False
|
||||
|
||||
async def test_list_batches_multiple_batches(self, provider):
|
||||
"""Test listing multiple batches."""
|
||||
batches = [
|
||||
await provider.create_batch(
|
||||
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
response = await provider.list_batches()
|
||||
|
||||
assert len(response.data) == 3
|
||||
|
||||
batch_ids = {batch.id for batch in response.data}
|
||||
expected_ids = {batch.id for batch in batches}
|
||||
assert batch_ids == expected_ids
|
||||
assert response.has_more is False
|
||||
|
||||
assert response.first_id in expected_ids
|
||||
assert response.last_id in expected_ids
|
||||
|
||||
async def test_list_batches_with_limit(self, provider):
|
||||
"""Test listing batches with limit parameter."""
|
||||
batches = [
|
||||
await provider.create_batch(
|
||||
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
response = await provider.list_batches(limit=2)
|
||||
|
||||
assert len(response.data) == 2
|
||||
assert response.has_more is True
|
||||
assert response.first_id == response.data[0].id
|
||||
assert response.last_id == response.data[1].id
|
||||
batch_ids = {batch.id for batch in response.data}
|
||||
expected_ids = {batch.id for batch in batches}
|
||||
assert batch_ids.issubset(expected_ids)
|
||||
|
||||
async def test_list_batches_with_pagination(self, provider):
|
||||
"""Test listing batches with pagination using 'after' parameter."""
|
||||
for i in range(3):
|
||||
await provider.create_batch(
|
||||
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
)
|
||||
|
||||
# Get first page
|
||||
first_page = await provider.list_batches(limit=1)
|
||||
assert len(first_page.data) == 1
|
||||
assert first_page.has_more is True
|
||||
|
||||
# Get second page using 'after'
|
||||
second_page = await provider.list_batches(limit=1, after=first_page.data[0].id)
|
||||
assert len(second_page.data) == 1
|
||||
assert second_page.data[0].id != first_page.data[0].id
|
||||
|
||||
# Verify we got the next batch in order
|
||||
all_batches = await provider.list_batches()
|
||||
expected_second_batch_id = all_batches.data[1].id
|
||||
assert second_page.data[0].id == expected_second_batch_id
|
||||
|
||||
async def test_list_batches_invalid_after(self, provider, sample_batch_data):
|
||||
"""Test listing batches with invalid 'after' parameter."""
|
||||
await provider.create_batch(**sample_batch_data)
|
||||
|
||||
response = await provider.list_batches(after="nonexistent_batch")
|
||||
|
||||
# Should return all batches (no filtering when 'after' batch not found)
|
||||
assert len(response.data) == 1
|
||||
|
||||
async def test_kvstore_persistence(self, provider, sample_batch_data):
|
||||
"""Test that batches are properly persisted in kvstore."""
|
||||
batch = await provider.create_batch(**sample_batch_data)
|
||||
|
||||
stored_data = await provider.kvstore.get(f"batch:{batch.id}")
|
||||
assert stored_data is not None
|
||||
|
||||
stored_batch_dict = json.loads(stored_data)
|
||||
assert stored_batch_dict["id"] == batch.id
|
||||
assert stored_batch_dict["input_file_id"] == sample_batch_data["input_file_id"]
|
||||
|
||||
async def test_validate_input_file_not_found(self, provider):
|
||||
"""Test _validate_input when input file does not exist."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock(side_effect=Exception("File not found"))
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id="nonexistent_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 0
|
||||
assert errors[0].code == "invalid_request"
|
||||
assert errors[0].message == "Cannot find file nonexistent_file."
|
||||
assert errors[0].param == "input_file_id"
|
||||
assert errors[0].line is None
|
||||
|
||||
async def test_validate_input_file_exists_empty_content(self, provider):
|
||||
"""Test _validate_input when file exists but is empty."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.body = b""
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id="empty_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 0
|
||||
assert len(requests) == 0
|
||||
|
||||
async def test_validate_input_file_mixed_valid_invalid_json(self, provider):
|
||||
"""Test _validate_input when file contains valid and invalid JSON lines."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
# Line 1: valid JSON with proper body args, Line 2: invalid JSON
|
||||
mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}\n{invalid json'
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id="mixed_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
# Should have 1 JSON parsing error from line 2, and 1 valid request from line 1
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 1
|
||||
|
||||
assert errors[0].code == "invalid_json_line"
|
||||
assert errors[0].line == 2
|
||||
assert errors[0].message == "This line is not parseable as valid JSON."
|
||||
|
||||
assert requests[0].custom_id == "req-1"
|
||||
assert requests[0].method == "POST"
|
||||
assert requests[0].url == "/v1/chat/completions"
|
||||
assert requests[0].body["model"] == "test-model"
|
||||
assert requests[0].body["messages"] == [{"role": "user", "content": "Hello"}]
|
||||
|
||||
async def test_validate_input_invalid_model(self, provider):
|
||||
"""Test _validate_input when file contains request with non-existent model."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "nonexistent-model", "messages": [{"role": "user", "content": "Hello"}]}}'
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
provider.models_api.get_model = AsyncMock(side_effect=Exception("Model not found"))
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id="invalid_model_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 0
|
||||
|
||||
assert errors[0].code == "model_not_found"
|
||||
assert errors[0].line == 1
|
||||
assert errors[0].message == "Model 'nonexistent-model' does not exist or is not supported"
|
||||
assert errors[0].param == "body.model"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"param_name,param_path,error_code,error_message",
|
||||
[
|
||||
("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"),
|
||||
("method", "method", "missing_required_parameter", "Missing required parameter: method"),
|
||||
("url", "url", "missing_required_parameter", "Missing required parameter: url"),
|
||||
("body", "body", "missing_required_parameter", "Missing required parameter: body"),
|
||||
("model", "body.model", "invalid_request", "Model parameter is required"),
|
||||
("messages", "body.messages", "invalid_request", "Messages parameter is required"),
|
||||
],
|
||||
)
|
||||
async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message):
|
||||
"""Test _validate_input when file contains request with missing required parameters."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
|
||||
base_request = {
|
||||
"custom_id": "req-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]},
|
||||
}
|
||||
|
||||
# Remove the specific parameter being tested
|
||||
if "." in param_path:
|
||||
top_level, nested_param = param_path.split(".", 1)
|
||||
del base_request[top_level][nested_param]
|
||||
else:
|
||||
del base_request[param_name]
|
||||
|
||||
mock_response.body = json.dumps(base_request).encode()
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id=f"missing_{param_name}_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 0
|
||||
|
||||
assert errors[0].code == error_code
|
||||
assert errors[0].line == 1
|
||||
assert errors[0].message == error_message
|
||||
assert errors[0].param == param_path
|
||||
|
||||
async def test_validate_input_url_mismatch(self, provider):
|
||||
"""Test _validate_input when file contains request with URL that doesn't match batch endpoint."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}'
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions", # This doesn't match the URL in the request
|
||||
input_file_id="url_mismatch_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 0
|
||||
|
||||
assert errors[0].code == "invalid_url"
|
||||
assert errors[0].line == 1
|
||||
assert errors[0].message == "URL provided for this request does not match the batch endpoint"
|
||||
assert errors[0].param == "url"
|
||||
|
||||
async def test_validate_input_multiple_errors_per_request(self, provider):
|
||||
"""Test _validate_input when a single request has multiple validation errors."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
# Request missing custom_id, has invalid URL, and missing model in body
|
||||
mock_response.body = (
|
||||
b'{"method": "POST", "url": "/v1/embeddings", "body": {"messages": [{"role": "user", "content": "Hello"}]}}'
|
||||
)
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions", # Doesn't match /v1/embeddings in request
|
||||
input_file_id="multiple_errors_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) >= 2 # At least missing custom_id and URL mismatch
|
||||
assert len(requests) == 0
|
||||
|
||||
for error in errors:
|
||||
assert error.line == 1
|
||||
|
||||
error_codes = {error.code for error in errors}
|
||||
assert "missing_required_parameter" in error_codes # missing custom_id
|
||||
assert "invalid_url" in error_codes # URL mismatch
|
||||
|
||||
async def test_validate_input_invalid_request_format(self, provider):
|
||||
"""Test _validate_input when file contains non-object JSON (array, string, number)."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.body = b'["not", "a", "request", "object"]'
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id="invalid_format_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 0
|
||||
|
||||
assert errors[0].code == "invalid_request"
|
||||
assert errors[0].line == 1
|
||||
assert errors[0].message == "Each line must be a JSON dictionary object"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"param_name,param_path,invalid_value,error_message",
|
||||
[
|
||||
("custom_id", "custom_id", 12345, "Custom_id must be a string"),
|
||||
("url", "url", 123, "URL must be a string"),
|
||||
("method", "method", ["POST"], "Method must be a string"),
|
||||
("body", "body", ["not", "valid"], "Body must be a JSON dictionary object"),
|
||||
("model", "body.model", 123, "Model must be a string"),
|
||||
("messages", "body.messages", "invalid messages format", "Messages must be an array"),
|
||||
],
|
||||
)
|
||||
async def test_validate_input_invalid_parameter_types(
|
||||
self, provider, param_name, param_path, invalid_value, error_message
|
||||
):
|
||||
"""Test _validate_input when file contains request with parameters that have invalid types."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
|
||||
base_request = {
|
||||
"custom_id": "req-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]},
|
||||
}
|
||||
|
||||
# Override the specific parameter with invalid value
|
||||
if "." in param_path:
|
||||
top_level, nested_param = param_path.split(".", 1)
|
||||
base_request[top_level][nested_param] = invalid_value
|
||||
else:
|
||||
base_request[param_name] = invalid_value
|
||||
|
||||
mock_response.body = json.dumps(base_request).encode()
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id=f"invalid_{param_name}_type_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 0
|
||||
|
||||
assert errors[0].code == "invalid_request"
|
||||
assert errors[0].line == 1
|
||||
assert errors[0].message == error_message
|
||||
assert errors[0].param == param_path
|
||||
|
||||
async def test_max_concurrent_batches(self, provider):
|
||||
"""Test max_concurrent_batches configuration and concurrency control."""
|
||||
import asyncio
|
||||
|
||||
provider._batch_semaphore = asyncio.Semaphore(2)
|
||||
|
||||
provider.process_batches = True # enable because we're testing background processing
|
||||
|
||||
active_batches = 0
|
||||
|
||||
async def add_and_wait(batch_id: str):
|
||||
nonlocal active_batches
|
||||
active_batches += 1
|
||||
await asyncio.sleep(float("inf"))
|
||||
|
||||
# the first thing done in _process_batch is to acquire the semaphore, then call _process_batch_impl,
|
||||
# so we can replace _process_batch_impl with our mock to control concurrency
|
||||
provider._process_batch_impl = add_and_wait
|
||||
|
||||
for _ in range(3):
|
||||
await provider.create_batch(
|
||||
input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.042) # let tasks start
|
||||
|
||||
assert active_batches == 2, f"Expected 2 active batches, got {active_batches}"
|
128
tests/unit/providers/batches/test_reference_idempotency.py
Normal file
128
tests/unit/providers/batches/test_reference_idempotency.py
Normal file
|
@ -0,0 +1,128 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Tests for idempotency functionality in the reference batches provider.
|
||||
|
||||
This module tests the optional idempotency feature that allows clients to provide
|
||||
an idempotency key (idempotency_key) to ensure that repeated requests with the same key
|
||||
and parameters return the same batch, while requests with the same key but different
|
||||
parameters result in a conflict error.
|
||||
|
||||
Test Categories:
|
||||
1. Core Idempotency: Same parameters with same key return same batch
|
||||
2. Parameter Independence: Different parameters without keys create different batches
|
||||
3. Conflict Detection: Same key with different parameters raises ConflictError
|
||||
|
||||
Tests by Category:
|
||||
|
||||
1. Core Idempotency:
|
||||
- test_idempotent_batch_creation_same_params
|
||||
- test_idempotent_batch_creation_metadata_order_independence
|
||||
|
||||
2. Parameter Independence:
|
||||
- test_non_idempotent_behavior_without_key
|
||||
- test_different_idempotency_keys_create_different_batches
|
||||
|
||||
3. Conflict Detection:
|
||||
- test_same_idempotency_key_different_params_conflict (parametrized: input_file_id, metadata values, metadata None vs {})
|
||||
|
||||
Key Behaviors Tested:
|
||||
- Idempotent batch creation when idempotency_key provided with identical parameters
|
||||
- Metadata order independence for consistent batch ID generation
|
||||
- Non-idempotent behavior when no idempotency_key provided (random UUIDs)
|
||||
- Conflict detection for parameter mismatches with same idempotency key
|
||||
- Deterministic ID generation based solely on idempotency key
|
||||
- Proper error handling with detailed conflict messages including key and error codes
|
||||
- Protection against idempotency key reuse with different request parameters
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.common.errors import ConflictError
|
||||
|
||||
|
||||
class TestReferenceBatchesIdempotency:
|
||||
"""Test suite for idempotency functionality in the reference implementation."""
|
||||
|
||||
async def test_idempotent_batch_creation_same_params(self, provider, sample_batch_data):
|
||||
"""Test that creating batches with identical parameters returns the same batch when idempotency_key is provided."""
|
||||
|
||||
del sample_batch_data["metadata"]
|
||||
|
||||
batch1 = await provider.create_batch(
|
||||
**sample_batch_data,
|
||||
metadata={"test": "value1", "other": "value2"},
|
||||
idempotency_key="unique-token-1",
|
||||
)
|
||||
|
||||
# sleep for 1 second to allow created_at timestamps to be different
|
||||
await asyncio.sleep(1)
|
||||
|
||||
batch2 = await provider.create_batch(
|
||||
**sample_batch_data,
|
||||
metadata={"other": "value2", "test": "value1"}, # Different order
|
||||
idempotency_key="unique-token-1",
|
||||
)
|
||||
|
||||
assert batch1.id == batch2.id
|
||||
assert batch1.input_file_id == batch2.input_file_id
|
||||
assert batch1.metadata == batch2.metadata
|
||||
assert batch1.created_at == batch2.created_at
|
||||
|
||||
async def test_different_idempotency_keys_create_different_batches(self, provider, sample_batch_data):
|
||||
"""Test that different idempotency keys create different batches even with same params."""
|
||||
batch1 = await provider.create_batch(
|
||||
**sample_batch_data,
|
||||
idempotency_key="token-A",
|
||||
)
|
||||
|
||||
batch2 = await provider.create_batch(
|
||||
**sample_batch_data,
|
||||
idempotency_key="token-B",
|
||||
)
|
||||
|
||||
assert batch1.id != batch2.id
|
||||
|
||||
async def test_non_idempotent_behavior_without_key(self, provider, sample_batch_data):
|
||||
"""Test that batches without idempotency key create unique batches even with identical parameters."""
|
||||
batch1 = await provider.create_batch(**sample_batch_data)
|
||||
|
||||
batch2 = await provider.create_batch(**sample_batch_data)
|
||||
|
||||
assert batch1.id != batch2.id
|
||||
assert batch1.input_file_id == batch2.input_file_id
|
||||
assert batch1.endpoint == batch2.endpoint
|
||||
assert batch1.completion_window == batch2.completion_window
|
||||
assert batch1.metadata == batch2.metadata
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"param_name,first_value,second_value",
|
||||
[
|
||||
("input_file_id", "file_001", "file_002"),
|
||||
("metadata", {"test": "value1"}, {"test": "value2"}),
|
||||
("metadata", None, {}),
|
||||
],
|
||||
)
|
||||
async def test_same_idempotency_key_different_params_conflict(
|
||||
self, provider, sample_batch_data, param_name, first_value, second_value
|
||||
):
|
||||
"""Test that same idempotency_key with different parameters raises conflict error."""
|
||||
sample_batch_data["idempotency_key"] = "same-token"
|
||||
|
||||
sample_batch_data[param_name] = first_value
|
||||
|
||||
batch1 = await provider.create_batch(**sample_batch_data)
|
||||
|
||||
with pytest.raises(ConflictError, match="Idempotency key.*was previously used with different parameters"):
|
||||
sample_batch_data[param_name] = second_value
|
||||
await provider.create_batch(**sample_batch_data)
|
||||
|
||||
retrieved_batch = await provider.retrieve_batch(batch1.id)
|
||||
assert retrieved_batch.id == batch1.id
|
||||
assert getattr(retrieved_batch, param_name) == first_value
|
251
tests/unit/providers/files/test_s3_files.py
Normal file
251
tests/unit/providers/files/test_s3_files.py
Normal file
|
@ -0,0 +1,251 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import boto3
|
||||
import pytest
|
||||
from botocore.exceptions import ClientError
|
||||
from moto import mock_aws
|
||||
|
||||
from llama_stack.apis.common.errors import ResourceNotFoundError
|
||||
from llama_stack.apis.files import OpenAIFilePurpose
|
||||
from llama_stack.providers.remote.files.s3 import (
|
||||
S3FilesImplConfig,
|
||||
get_adapter_impl,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
|
||||
|
||||
class MockUploadFile:
|
||||
def __init__(self, content: bytes, filename: str, content_type: str = "text/plain"):
|
||||
self.content = content
|
||||
self.filename = filename
|
||||
self.content_type = content_type
|
||||
|
||||
async def read(self):
|
||||
return self.content
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def s3_config(tmp_path):
|
||||
db_path = tmp_path / "s3_files_metadata.db"
|
||||
|
||||
return S3FilesImplConfig(
|
||||
bucket_name="test-bucket",
|
||||
region="not-a-region",
|
||||
auto_create_bucket=True,
|
||||
metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def s3_client():
|
||||
"""Create a mocked S3 client for testing."""
|
||||
# we use `with mock_aws()` because @mock_aws decorator does not support being a generator
|
||||
with mock_aws():
|
||||
# must yield or the mock will be reset before it is used
|
||||
yield boto3.client("s3")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def s3_provider(s3_config, s3_client):
|
||||
"""Create an S3 files provider with mocked S3 for testing."""
|
||||
provider = await get_adapter_impl(s3_config, {})
|
||||
yield provider
|
||||
await provider.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_text_file():
|
||||
content = b"Hello, this is a test file for the S3 Files API!"
|
||||
return MockUploadFile(content, "sample_text_file.txt")
|
||||
|
||||
|
||||
class TestS3FilesImpl:
|
||||
"""Test suite for S3 Files implementation."""
|
||||
|
||||
async def test_upload_file(self, s3_provider, sample_text_file, s3_client, s3_config):
|
||||
"""Test successful file upload."""
|
||||
sample_text_file.filename = "test_upload_file"
|
||||
result = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
assert result.filename == sample_text_file.filename
|
||||
assert result.purpose == OpenAIFilePurpose.ASSISTANTS
|
||||
assert result.bytes == len(sample_text_file.content)
|
||||
assert result.id.startswith("file-")
|
||||
|
||||
# Verify file exists in S3 backend
|
||||
response = s3_client.head_object(Bucket=s3_config.bucket_name, Key=result.id)
|
||||
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
async def test_list_files_empty(self, s3_provider):
|
||||
"""Test listing files when no files exist."""
|
||||
result = await s3_provider.openai_list_files()
|
||||
|
||||
assert len(result.data) == 0
|
||||
assert not result.has_more
|
||||
assert result.first_id == ""
|
||||
assert result.last_id == ""
|
||||
|
||||
async def test_retrieve_file(self, s3_provider, sample_text_file):
|
||||
"""Test retrieving file metadata."""
|
||||
sample_text_file.filename = "test_retrieve_file"
|
||||
uploaded = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
retrieved = await s3_provider.openai_retrieve_file(uploaded.id)
|
||||
|
||||
assert retrieved.id == uploaded.id
|
||||
assert retrieved.filename == uploaded.filename
|
||||
assert retrieved.purpose == uploaded.purpose
|
||||
assert retrieved.bytes == uploaded.bytes
|
||||
|
||||
async def test_retrieve_file_content(self, s3_provider, sample_text_file):
|
||||
"""Test retrieving file content."""
|
||||
sample_text_file.filename = "test_retrieve_file_content"
|
||||
uploaded = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
response = await s3_provider.openai_retrieve_file_content(uploaded.id)
|
||||
|
||||
assert response.body == sample_text_file.content
|
||||
assert response.headers["Content-Disposition"] == f'attachment; filename="{sample_text_file.filename}"'
|
||||
|
||||
async def test_delete_file(self, s3_provider, sample_text_file, s3_config, s3_client):
|
||||
"""Test deleting a file."""
|
||||
sample_text_file.filename = "test_delete_file"
|
||||
uploaded = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
delete_response = await s3_provider.openai_delete_file(uploaded.id)
|
||||
|
||||
assert delete_response.id == uploaded.id
|
||||
assert delete_response.deleted is True
|
||||
|
||||
with pytest.raises(ResourceNotFoundError, match="not found"):
|
||||
await s3_provider.openai_retrieve_file(uploaded.id)
|
||||
|
||||
# Verify file is gone from S3 backend
|
||||
with pytest.raises(ClientError) as exc_info:
|
||||
s3_client.head_object(Bucket=s3_config.bucket_name, Key=uploaded.id)
|
||||
assert exc_info.value.response["Error"]["Code"] == "404"
|
||||
|
||||
async def test_list_files(self, s3_provider, sample_text_file):
|
||||
"""Test listing files after uploading some."""
|
||||
sample_text_file.filename = "test_list_files_with_content_file1"
|
||||
file1 = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
file2_content = MockUploadFile(b"Second file content", "test_list_files_with_content_file2")
|
||||
file2 = await s3_provider.openai_upload_file(
|
||||
file=file2_content,
|
||||
purpose=OpenAIFilePurpose.BATCH,
|
||||
)
|
||||
|
||||
result = await s3_provider.openai_list_files()
|
||||
|
||||
assert len(result.data) == 2
|
||||
file_ids = {f.id for f in result.data}
|
||||
assert file1.id in file_ids
|
||||
assert file2.id in file_ids
|
||||
|
||||
async def test_list_files_with_purpose_filter(self, s3_provider, sample_text_file):
|
||||
"""Test listing files with purpose filter."""
|
||||
sample_text_file.filename = "test_list_files_with_purpose_filter_file1"
|
||||
file1 = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
file2_content = MockUploadFile(b"Batch file content", "test_list_files_with_purpose_filter_file2")
|
||||
await s3_provider.openai_upload_file(
|
||||
file=file2_content,
|
||||
purpose=OpenAIFilePurpose.BATCH,
|
||||
)
|
||||
|
||||
result = await s3_provider.openai_list_files(purpose=OpenAIFilePurpose.ASSISTANTS)
|
||||
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].id == file1.id
|
||||
assert result.data[0].purpose == OpenAIFilePurpose.ASSISTANTS
|
||||
|
||||
async def test_nonexistent_file_retrieval(self, s3_provider):
|
||||
"""Test retrieving a non-existent file raises error."""
|
||||
with pytest.raises(ResourceNotFoundError, match="not found"):
|
||||
await s3_provider.openai_retrieve_file("file-nonexistent")
|
||||
|
||||
async def test_nonexistent_file_content_retrieval(self, s3_provider):
|
||||
"""Test retrieving content of a non-existent file raises error."""
|
||||
with pytest.raises(ResourceNotFoundError, match="not found"):
|
||||
await s3_provider.openai_retrieve_file_content("file-nonexistent")
|
||||
|
||||
async def test_nonexistent_file_deletion(self, s3_provider):
|
||||
"""Test deleting a non-existent file raises error."""
|
||||
with pytest.raises(ResourceNotFoundError, match="not found"):
|
||||
await s3_provider.openai_delete_file("file-nonexistent")
|
||||
|
||||
async def test_upload_file_without_filename(self, s3_provider, sample_text_file):
|
||||
"""Test uploading a file without a filename uses the fallback."""
|
||||
del sample_text_file.filename
|
||||
result = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
assert result.purpose == OpenAIFilePurpose.ASSISTANTS
|
||||
assert result.bytes == len(sample_text_file.content)
|
||||
|
||||
retrieved = await s3_provider.openai_retrieve_file(result.id)
|
||||
assert retrieved.filename == result.filename
|
||||
|
||||
async def test_file_operations_when_s3_object_deleted(self, s3_provider, sample_text_file, s3_config, s3_client):
|
||||
"""Test file operations when S3 object is deleted but metadata exists (negative test)."""
|
||||
sample_text_file.filename = "test_orphaned_metadata"
|
||||
uploaded = await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
# Directly delete the S3 object from the backend
|
||||
s3_client.delete_object(Bucket=s3_config.bucket_name, Key=uploaded.id)
|
||||
|
||||
with pytest.raises(ResourceNotFoundError, match="not found") as exc_info:
|
||||
await s3_provider.openai_retrieve_file_content(uploaded.id)
|
||||
assert uploaded.id in str(exc_info).lower()
|
||||
|
||||
listed_files = await s3_provider.openai_list_files()
|
||||
assert uploaded.id not in [file.id for file in listed_files.data]
|
||||
|
||||
async def test_upload_file_s3_put_object_failure(self, s3_provider, sample_text_file, s3_config, s3_client):
|
||||
"""Test that put_object failure results in exception and no orphaned metadata."""
|
||||
sample_text_file.filename = "test_s3_put_object_failure"
|
||||
|
||||
def failing_put_object(*args, **kwargs):
|
||||
raise ClientError(
|
||||
error_response={"Error": {"Code": "SolarRadiation", "Message": "Bloop"}}, operation_name="PutObject"
|
||||
)
|
||||
|
||||
with patch.object(s3_provider.client, "put_object", side_effect=failing_put_object):
|
||||
with pytest.raises(RuntimeError, match="Failed to upload file to S3"):
|
||||
await s3_provider.openai_upload_file(
|
||||
file=sample_text_file,
|
||||
purpose=OpenAIFilePurpose.ASSISTANTS,
|
||||
)
|
||||
|
||||
files_list = await s3_provider.openai_list_files()
|
||||
assert len(files_list.data) == 0, "No file metadata should remain after failed upload"
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import logging # allow-direct-logging
|
||||
import threading
|
||||
import time
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
|
|
|
@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
|
|||
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_message_to_openai_dict,
|
||||
convert_message_to_openai_dict_new,
|
||||
openai_messages_to_messages,
|
||||
)
|
||||
|
||||
|
@ -182,3 +183,42 @@ def test_user_message_accepts_images():
|
|||
assert len(msg.content) == 2
|
||||
assert msg.content[0].text == "Describe this image:"
|
||||
assert msg.content[1].image_url.url == "http://example.com/image.jpg"
|
||||
|
||||
|
||||
async def test_convert_message_to_openai_dict_new_user_message():
|
||||
"""Test convert_message_to_openai_dict_new with UserMessage."""
|
||||
message = UserMessage(content="Hello, world!", role="user")
|
||||
result = await convert_message_to_openai_dict_new(message)
|
||||
|
||||
assert result["role"] == "user"
|
||||
assert result["content"] == "Hello, world!"
|
||||
|
||||
|
||||
async def test_convert_message_to_openai_dict_new_completion_message_with_tool_calls():
|
||||
"""Test convert_message_to_openai_dict_new with CompletionMessage containing tool calls."""
|
||||
message = CompletionMessage(
|
||||
content="I'll help you find the weather.",
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
call_id="call_123",
|
||||
tool_name="get_weather",
|
||||
arguments={"city": "Sligo"},
|
||||
arguments_json='{"city": "Sligo"}',
|
||||
)
|
||||
],
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
result = await convert_message_to_openai_dict_new(message)
|
||||
|
||||
# This would have failed with "Cannot instantiate typing.Union" before the fix
|
||||
assert result["role"] == "assistant"
|
||||
assert result["content"] == "I'll help you find the weather."
|
||||
assert "tool_calls" in result
|
||||
assert result["tool_calls"] is not None
|
||||
assert len(result["tool_calls"]) == 1
|
||||
|
||||
tool_call = result["tool_calls"][0]
|
||||
assert tool_call.id == "call_123"
|
||||
assert tool_call.type == "function"
|
||||
assert tool_call.function.name == "get_weather"
|
||||
assert tool_call.function.arguments == '{"city": "Sligo"}'
|
||||
|
|
105
tests/unit/server/test_cors.py
Normal file
105
tests/unit/server/test_cors.py
Normal file
|
@ -0,0 +1,105 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.core.datatypes import CORSConfig, process_cors_config
|
||||
|
||||
|
||||
def test_cors_config_defaults():
|
||||
config = CORSConfig()
|
||||
|
||||
assert config.allow_origins == []
|
||||
assert config.allow_origin_regex is None
|
||||
assert config.allow_methods == ["OPTIONS"]
|
||||
assert config.allow_headers == []
|
||||
assert config.allow_credentials is False
|
||||
assert config.expose_headers == []
|
||||
assert config.max_age == 600
|
||||
|
||||
|
||||
def test_cors_config_explicit_config():
|
||||
config = CORSConfig(
|
||||
allow_origins=["https://example.com"], allow_credentials=True, max_age=3600, allow_methods=["GET", "POST"]
|
||||
)
|
||||
|
||||
assert config.allow_origins == ["https://example.com"]
|
||||
assert config.allow_credentials is True
|
||||
assert config.max_age == 3600
|
||||
assert config.allow_methods == ["GET", "POST"]
|
||||
|
||||
|
||||
def test_cors_config_regex():
|
||||
config = CORSConfig(allow_origins=[], allow_origin_regex=r"https?://localhost:\d+")
|
||||
|
||||
assert config.allow_origins == []
|
||||
assert config.allow_origin_regex == r"https?://localhost:\d+"
|
||||
|
||||
|
||||
def test_cors_config_wildcard_credentials_error():
|
||||
with pytest.raises(ValueError, match="Cannot use wildcard origins with credentials enabled"):
|
||||
CORSConfig(allow_origins=["*"], allow_credentials=True)
|
||||
|
||||
with pytest.raises(ValueError, match="Cannot use wildcard origins with credentials enabled"):
|
||||
CORSConfig(allow_origins=["https://example.com", "*"], allow_credentials=True)
|
||||
|
||||
|
||||
def test_process_cors_config_false():
|
||||
result = process_cors_config(False)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_process_cors_config_true():
|
||||
result = process_cors_config(True)
|
||||
|
||||
assert isinstance(result, CORSConfig)
|
||||
assert result.allow_origins == []
|
||||
assert result.allow_origin_regex == r"https?://localhost:\d+"
|
||||
assert result.allow_credentials is False
|
||||
expected_methods = ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
|
||||
for method in expected_methods:
|
||||
assert method in result.allow_methods
|
||||
|
||||
|
||||
def test_process_cors_config_passthrough():
|
||||
original = CORSConfig(allow_origins=["https://example.com"], allow_methods=["GET"])
|
||||
result = process_cors_config(original)
|
||||
|
||||
assert result is original
|
||||
|
||||
|
||||
def test_process_cors_config_invalid_type():
|
||||
with pytest.raises(ValueError, match="Expected bool or CORSConfig, got str"):
|
||||
process_cors_config("invalid")
|
||||
|
||||
|
||||
def test_cors_config_model_dump():
|
||||
cors_config = CORSConfig(
|
||||
allow_origins=["https://example.com"],
|
||||
allow_methods=["GET", "POST"],
|
||||
allow_headers=["Content-Type"],
|
||||
allow_credentials=True,
|
||||
max_age=3600,
|
||||
)
|
||||
|
||||
config_dict = cors_config.model_dump()
|
||||
|
||||
assert config_dict["allow_origins"] == ["https://example.com"]
|
||||
assert config_dict["allow_methods"] == ["GET", "POST"]
|
||||
assert config_dict["allow_headers"] == ["Content-Type"]
|
||||
assert config_dict["allow_credentials"] is True
|
||||
assert config_dict["max_age"] == 3600
|
||||
|
||||
expected_keys = {
|
||||
"allow_origins",
|
||||
"allow_origin_regex",
|
||||
"allow_methods",
|
||||
"allow_headers",
|
||||
"allow_credentials",
|
||||
"expose_headers",
|
||||
"max_age",
|
||||
}
|
||||
assert set(config_dict.keys()) == expected_keys
|
Loading…
Add table
Add a link
Reference in a new issue