From ff60bb31e6d7ae310957442a9a0be2c986930953 Mon Sep 17 00:00:00 2001 From: Mike Sager Date: Thu, 13 Nov 2025 14:22:24 -0500 Subject: [PATCH] Restore responses unit tests --- .../meta_reference/fixtures/__init__.py | 23 + .../fixtures/simple_chat_completion.yaml | 9 + .../fixtures/tool_call_completion.yaml | 14 + .../meta_reference/test_openai_responses.py | 1244 +++++++++++++++++ .../test_openai_responses_conversations.py | 249 ++++ .../test_response_conversion_utils.py | 367 +++++ .../test_response_tool_context.py | 183 +++ .../test_responses_safety_utils.py | 155 ++ 8 files changed, 2244 insertions(+) create mode 100644 tests/unit/providers/agents/meta_reference/fixtures/__init__.py create mode 100644 tests/unit/providers/agents/meta_reference/fixtures/simple_chat_completion.yaml create mode 100644 tests/unit/providers/agents/meta_reference/fixtures/tool_call_completion.yaml create mode 100644 tests/unit/providers/agents/meta_reference/test_openai_responses.py create mode 100644 tests/unit/providers/agents/meta_reference/test_openai_responses_conversations.py create mode 100644 tests/unit/providers/agents/meta_reference/test_response_conversion_utils.py create mode 100644 tests/unit/providers/agents/meta_reference/test_response_tool_context.py create mode 100644 tests/unit/providers/agents/meta_reference/test_responses_safety_utils.py diff --git a/tests/unit/providers/agents/meta_reference/fixtures/__init__.py b/tests/unit/providers/agents/meta_reference/fixtures/__init__.py new file mode 100644 index 000000000..2ebcd9970 --- /dev/null +++ b/tests/unit/providers/agents/meta_reference/fixtures/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import os + +import yaml + +from llama_stack.apis.inference import ( + OpenAIChatCompletion, +) + +FIXTURES_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def load_chat_completion_fixture(filename: str) -> OpenAIChatCompletion: + fixture_path = os.path.join(FIXTURES_DIR, filename) + + with open(fixture_path) as f: + data = yaml.safe_load(f) + return OpenAIChatCompletion(**data) diff --git a/tests/unit/providers/agents/meta_reference/fixtures/simple_chat_completion.yaml b/tests/unit/providers/agents/meta_reference/fixtures/simple_chat_completion.yaml new file mode 100644 index 000000000..4959349a0 --- /dev/null +++ b/tests/unit/providers/agents/meta_reference/fixtures/simple_chat_completion.yaml @@ -0,0 +1,9 @@ +id: chat-completion-123 +choices: + - message: + content: "Dublin" + role: assistant + finish_reason: stop + index: 0 +created: 1234567890 +model: meta-llama/Llama-3.1-8B-Instruct diff --git a/tests/unit/providers/agents/meta_reference/fixtures/tool_call_completion.yaml b/tests/unit/providers/agents/meta_reference/fixtures/tool_call_completion.yaml new file mode 100644 index 000000000..f6532e3a9 --- /dev/null +++ b/tests/unit/providers/agents/meta_reference/fixtures/tool_call_completion.yaml @@ -0,0 +1,14 @@ +id: chat-completion-123 +choices: + - message: + tool_calls: + - id: tool_call_123 + type: function + function: + name: web_search + arguments: '{"query":"What is the capital of Ireland?"}' + role: assistant + finish_reason: stop + index: 0 +created: 1234567890 +model: meta-llama/Llama-3.1-8B-Instruct diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py new file mode 100644 index 000000000..ba914d808 --- /dev/null +++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py @@ -0,0 +1,1244 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
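# --- Illustrative aside (not part of the patch) -------------------------------
# A short usage sketch of the fixture loader defined in fixtures/__init__.py
# above: it parses one of the YAML files in fixtures/ into an
# OpenAIChatCompletion, so the tests below can reuse canned completions instead
# of building them by hand. The assertions simply restate what
# tool_call_completion.yaml contains.
from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture

_completion = load_chat_completion_fixture("tool_call_completion.yaml")
assert _completion.id == "chat-completion-123"
assert _completion.choices[0].message.tool_calls[0].function.name == "web_search"
# -------------------------------------------------------------------------------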
+ +from unittest.mock import AsyncMock, patch + +import pytest +from openai.types.chat.chat_completion_chunk import ( + ChatCompletionChunk, + Choice, + ChoiceDelta, + ChoiceDeltaToolCall, + ChoiceDeltaToolCallFunction, +) + +from llama_stack.apis.agents import Order +from llama_stack.apis.agents.openai_responses import ( + ListOpenAIResponseInputItem, + OpenAIResponseInputMessageContentText, + OpenAIResponseInputToolFunction, + OpenAIResponseInputToolMCP, + OpenAIResponseInputToolWebSearch, + OpenAIResponseMessage, + OpenAIResponseOutputMessageContentOutputText, + OpenAIResponseOutputMessageFunctionToolCall, + OpenAIResponseOutputMessageMCPCall, + OpenAIResponseOutputMessageWebSearchToolCall, + OpenAIResponseText, + OpenAIResponseTextFormat, + WebSearchToolTypes, +) +from llama_stack.apis.inference import ( + OpenAIAssistantMessageParam, + OpenAIChatCompletionContentPartTextParam, + OpenAIChatCompletionRequestWithExtraBody, + OpenAIDeveloperMessageParam, + OpenAIJSONSchema, + OpenAIResponseFormatJSONObject, + OpenAIResponseFormatJSONSchema, + OpenAIUserMessageParam, +) +from llama_stack.apis.tools.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime +from llama_stack.core.access_control.access_control import default_policy +from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqliteSqlStoreConfig +from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import ( + OpenAIResponsesImpl, +) +from llama_stack.providers.utils.responses.responses_store import ( + ResponsesStore, + _OpenAIResponseObjectWithInputAndMessages, +) +from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends +from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture + + +@pytest.fixture +def mock_inference_api(): + inference_api = AsyncMock() + return inference_api + + +@pytest.fixture +def mock_tool_groups_api(): + tool_groups_api = AsyncMock(spec=ToolGroups) + return tool_groups_api + + +@pytest.fixture +def mock_tool_runtime_api(): + tool_runtime_api = AsyncMock(spec=ToolRuntime) + return tool_runtime_api + + +@pytest.fixture +def mock_responses_store(): + responses_store = AsyncMock(spec=ResponsesStore) + return responses_store + + +@pytest.fixture +def mock_vector_io_api(): + vector_io_api = AsyncMock() + return vector_io_api + + +@pytest.fixture +def mock_conversations_api(): + """Mock conversations API for testing.""" + mock_api = AsyncMock() + return mock_api + + +@pytest.fixture +def mock_safety_api(): + safety_api = AsyncMock() + return safety_api + + +@pytest.fixture +def openai_responses_impl( + mock_inference_api, + mock_tool_groups_api, + mock_tool_runtime_api, + mock_responses_store, + mock_vector_io_api, + mock_safety_api, + mock_conversations_api, +): + return OpenAIResponsesImpl( + inference_api=mock_inference_api, + tool_groups_api=mock_tool_groups_api, + tool_runtime_api=mock_tool_runtime_api, + responses_store=mock_responses_store, + vector_io_api=mock_vector_io_api, + safety_api=mock_safety_api, + conversations_api=mock_conversations_api, + ) + + +async def fake_stream(fixture: str = "simple_chat_completion.yaml"): + value = load_chat_completion_fixture(fixture) + yield ChatCompletionChunk( + id=value.id, + choices=[ + Choice( + index=0, + delta=ChoiceDelta( + content=c.message.content, + role=c.message.role, + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + id=t.id, + function=ChoiceDeltaToolCallFunction( + name=t.function.name, + 
arguments=t.function.arguments, + ), + ) + for t in (c.message.tool_calls or []) + ], + ), + ) + for c in value.choices + ], + created=1, + model=value.model, + object="chat.completion.chunk", + ) + + +async def test_create_openai_response_with_string_input(openai_responses_impl, mock_inference_api): + """Test creating an OpenAI response with a simple string input.""" + # Setup + input_text = "What is the capital of Ireland?" + model = "meta-llama/Llama-3.1-8B-Instruct" + + # Load the chat completion fixture + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Execute + result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + temperature=0.1, + stream=True, # Enable streaming to test content part events + ) + + # For streaming response, collect all chunks + chunks = [chunk async for chunk in result] + + mock_inference_api.openai_chat_completion.assert_called_once_with( + OpenAIChatCompletionRequestWithExtraBody( + model=model, + messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)], + response_format=None, + tools=None, + stream=True, + temperature=0.1, + stream_options={ + "include_usage": True, + }, + ) + ) + + # Should have content part events for text streaming + # Expected: response.created, response.in_progress, content_part.added, output_text.delta, content_part.done, response.completed + assert len(chunks) >= 5 + assert chunks[0].type == "response.created" + assert any(chunk.type == "response.in_progress" for chunk in chunks) + + # Check for content part events + content_part_added_events = [c for c in chunks if c.type == "response.content_part.added"] + content_part_done_events = [c for c in chunks if c.type == "response.content_part.done"] + text_delta_events = [c for c in chunks if c.type == "response.output_text.delta"] + + assert len(content_part_added_events) >= 1, "Should have content_part.added event for text" + assert len(content_part_done_events) >= 1, "Should have content_part.done event for text" + assert len(text_delta_events) >= 1, "Should have text delta events" + + added_event = content_part_added_events[0] + done_event = content_part_done_events[0] + assert added_event.content_index == 0 + assert done_event.content_index == 0 + assert added_event.output_index == done_event.output_index == 0 + assert added_event.item_id == done_event.item_id + assert added_event.response_id == done_event.response_id + + # Verify final event is completion + assert chunks[-1].type == "response.completed" + + # When streaming, the final response is in the last chunk + final_response = chunks[-1].response + assert final_response.model == model + assert len(final_response.output) == 1 + assert isinstance(final_response.output[0], OpenAIResponseMessage) + assert final_response.output[0].id == added_event.item_id + assert final_response.id == added_event.response_id + + openai_responses_impl.responses_store.store_response_object.assert_called_once() + assert final_response.output[0].content[0].text == "Dublin" + + +async def test_create_openai_response_with_string_input_with_tools(openai_responses_impl, mock_inference_api): + """Test creating an OpenAI response with a simple string input and tools.""" + # Setup + input_text = "What is the capital of Ireland?" 
+ model = "meta-llama/Llama-3.1-8B-Instruct" + + openai_responses_impl.tool_groups_api.get_tool.return_value = ToolDef( + name="web_search", + toolgroup_id="web_search", + description="Search the web for information", + input_schema={ + "type": "object", + "properties": {"query": {"type": "string", "description": "The query to search for"}}, + "required": ["query"], + }, + ) + + openai_responses_impl.tool_runtime_api.invoke_tool.return_value = ToolInvocationResult( + status="completed", + content="Dublin", + ) + + # Execute + for tool_name in WebSearchToolTypes: + # Reset mock states as we loop through each tool type + mock_inference_api.openai_chat_completion.side_effect = [ + fake_stream("tool_call_completion.yaml"), + fake_stream(), + ] + openai_responses_impl.tool_groups_api.get_tool.reset_mock() + openai_responses_impl.tool_runtime_api.invoke_tool.reset_mock() + openai_responses_impl.responses_store.store_response_object.reset_mock() + + result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + temperature=0.1, + tools=[ + OpenAIResponseInputToolWebSearch( + name=tool_name, + ) + ], + ) + + # Verify + first_call = mock_inference_api.openai_chat_completion.call_args_list[0] + first_params = first_call.args[0] + assert first_params.messages[0].content == "What is the capital of Ireland?" + assert first_params.tools is not None + assert first_params.temperature == 0.1 + + second_call = mock_inference_api.openai_chat_completion.call_args_list[1] + second_params = second_call.args[0] + assert second_params.messages[-1].content == "Dublin" + assert second_params.temperature == 0.1 + + openai_responses_impl.tool_groups_api.get_tool.assert_called_once_with("web_search") + openai_responses_impl.tool_runtime_api.invoke_tool.assert_called_once_with( + tool_name="web_search", + kwargs={"query": "What is the capital of Ireland?"}, + ) + + openai_responses_impl.responses_store.store_response_object.assert_called_once() + + # Check that we got the content from our mocked tool execution result + assert len(result.output) >= 1 + assert isinstance(result.output[1], OpenAIResponseMessage) + assert result.output[1].content[0].text == "Dublin" + assert result.output[1].content[0].annotations == [] + + +async def test_create_openai_response_with_tool_call_type_none(openai_responses_impl, mock_inference_api): + """Test creating an OpenAI response with a tool call response that has a type of None.""" + # Setup + input_text = "How hot it is in San Francisco today?" 
+ model = "meta-llama/Llama-3.1-8B-Instruct" + + async def fake_stream_toolcall(): + yield ChatCompletionChunk( + id="123", + choices=[ + Choice( + index=0, + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + id="tc_123", + function=ChoiceDeltaToolCallFunction(name="get_weather", arguments="{}"), + type=None, + ) + ] + ), + ), + ], + created=1, + model=model, + object="chat.completion.chunk", + ) + + mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall() + + # Execute + result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + stream=True, + temperature=0.1, + tools=[ + OpenAIResponseInputToolFunction( + name="get_weather", + description="Get current temperature for a given location.", + parameters={ + "location": "string", + }, + ) + ], + ) + + # Check that we got the content from our mocked tool execution result + chunks = [chunk async for chunk in result] + + # Verify event types + # Should have: response.created, response.in_progress, output_item.added, + # function_call_arguments.delta, function_call_arguments.done, output_item.done, response.completed + assert len(chunks) == 7 + + event_types = [chunk.type for chunk in chunks] + assert event_types == [ + "response.created", + "response.in_progress", + "response.output_item.added", + "response.function_call_arguments.delta", + "response.function_call_arguments.done", + "response.output_item.done", + "response.completed", + ] + + # Verify inference API was called correctly (after iterating over result) + first_call = mock_inference_api.openai_chat_completion.call_args_list[0] + first_params = first_call.args[0] + assert first_params.messages[0].content == input_text + assert first_params.tools is not None + assert first_params.temperature == 0.1 + + # Check response.created event (should have empty output) + assert len(chunks[0].response.output) == 0 + + # Check response.completed event (should have the tool call) + completed_chunk = chunks[-1] + assert completed_chunk.type == "response.completed" + assert len(completed_chunk.response.output) == 1 + assert completed_chunk.response.output[0].type == "function_call" + assert completed_chunk.response.output[0].name == "get_weather" + + +async def test_create_openai_response_with_tool_call_function_arguments_none(openai_responses_impl, mock_inference_api): + """Test creating an OpenAI response with tool calls that omit arguments.""" + + input_text = "What is the time right now?" 
+ model = "meta-llama/Llama-3.1-8B-Instruct" + + async def fake_stream_toolcall(): + yield ChatCompletionChunk( + id="123", + choices=[ + Choice( + index=0, + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + id="tc_123", + function=ChoiceDeltaToolCallFunction(name="get_current_time", arguments=None), + type=None, + ) + ] + ), + ), + ], + created=1, + model=model, + object="chat.completion.chunk", + ) + + def assert_common_expectations(chunks) -> None: + first_call = mock_inference_api.openai_chat_completion.call_args_list[0] + first_params = first_call.args[0] + assert first_params.messages[0].content == input_text + assert first_params.tools is not None + assert first_params.temperature == 0.1 + assert len(chunks[0].response.output) == 0 + completed_chunk = chunks[-1] + assert completed_chunk.type == "response.completed" + assert len(completed_chunk.response.output) == 1 + assert completed_chunk.response.output[0].type == "function_call" + assert completed_chunk.response.output[0].name == "get_current_time" + assert completed_chunk.response.output[0].arguments == "{}" + + # Function does not accept arguments + mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall() + result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + stream=True, + temperature=0.1, + tools=[ + OpenAIResponseInputToolFunction( + name="get_current_time", description="Get current time for system's timezone", parameters={} + ) + ], + ) + chunks = [chunk async for chunk in result] + assert [chunk.type for chunk in chunks] == [ + "response.created", + "response.in_progress", + "response.output_item.added", + "response.function_call_arguments.done", + "response.output_item.done", + "response.completed", + ] + assert_common_expectations(chunks) + + # Function accepts optional arguments + mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall() + result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + stream=True, + temperature=0.1, + tools=[ + OpenAIResponseInputToolFunction( + name="get_current_time", + description="Get current time for system's timezone", + parameters={"timezone": "string"}, + ) + ], + ) + chunks = [chunk async for chunk in result] + assert [chunk.type for chunk in chunks] == [ + "response.created", + "response.in_progress", + "response.output_item.added", + "response.function_call_arguments.done", + "response.output_item.done", + "response.completed", + ] + assert_common_expectations(chunks) + + # Function accepts optional arguments with additional optional fields + mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall() + result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + stream=True, + temperature=0.1, + tools=[ + OpenAIResponseInputToolFunction( + name="get_current_time", + description="Get current time for system's timezone", + parameters={"timezone": "string", "location": "string"}, + ) + ], + ) + chunks = [chunk async for chunk in result] + assert [chunk.type for chunk in chunks] == [ + "response.created", + "response.in_progress", + "response.output_item.added", + "response.function_call_arguments.done", + "response.output_item.done", + "response.completed", + ] + assert_common_expectations(chunks) + mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall() + + +async def test_create_openai_response_with_multiple_messages(openai_responses_impl, 
mock_inference_api): + """Test creating an OpenAI response with multiple messages.""" + # Setup + input_messages = [ + OpenAIResponseMessage(role="developer", content="You are a helpful assistant", name=None), + OpenAIResponseMessage(role="user", content="Name some towns in Ireland", name=None), + OpenAIResponseMessage( + role="assistant", + content=[ + OpenAIResponseInputMessageContentText(text="Galway, Longford, Sligo"), + OpenAIResponseInputMessageContentText(text="Dublin"), + ], + name=None, + ), + OpenAIResponseMessage(role="user", content="Which is the largest town in Ireland?", name=None), + ] + model = "meta-llama/Llama-3.1-8B-Instruct" + + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Execute + await openai_responses_impl.create_openai_response( + input=input_messages, + model=model, + temperature=0.1, + ) + + # Verify the the correct messages were sent to the inference API i.e. + # All of the responses message were convered to the chat completion message objects + call_args = mock_inference_api.openai_chat_completion.call_args_list[0] + params = call_args.args[0] + inference_messages = params.messages + for i, m in enumerate(input_messages): + if isinstance(m.content, str): + assert inference_messages[i].content == m.content + else: + assert inference_messages[i].content[0].text == m.content[0].text + assert isinstance(inference_messages[i].content[0], OpenAIChatCompletionContentPartTextParam) + assert inference_messages[i].role == m.role + if m.role == "user": + assert isinstance(inference_messages[i], OpenAIUserMessageParam) + elif m.role == "assistant": + assert isinstance(inference_messages[i], OpenAIAssistantMessageParam) + else: + assert isinstance(inference_messages[i], OpenAIDeveloperMessageParam) + + +async def test_prepend_previous_response_basic(openai_responses_impl, mock_responses_store): + """Test prepending a basic previous response to a new response.""" + + input_item_message = OpenAIResponseMessage( + id="123", + content=[OpenAIResponseInputMessageContentText(text="fake_previous_input")], + role="user", + ) + response_output_message = OpenAIResponseMessage( + id="123", + content=[OpenAIResponseOutputMessageContentOutputText(text="fake_response")], + status="completed", + role="assistant", + ) + previous_response = _OpenAIResponseObjectWithInputAndMessages( + created_at=1, + id="resp_123", + model="fake_model", + output=[response_output_message], + status="completed", + text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")), + input=[input_item_message], + messages=[OpenAIUserMessageParam(content="fake_previous_input")], + ) + mock_responses_store.get_response_object.return_value = previous_response + + input = await openai_responses_impl._prepend_previous_response("fake_input", previous_response) + + assert len(input) == 3 + # Check for previous input + assert isinstance(input[0], OpenAIResponseMessage) + assert input[0].content[0].text == "fake_previous_input" + # Check for previous output + assert isinstance(input[1], OpenAIResponseMessage) + assert input[1].content[0].text == "fake_response" + # Check for new input + assert isinstance(input[2], OpenAIResponseMessage) + assert input[2].content == "fake_input" + + +async def test_prepend_previous_response_web_search(openai_responses_impl, mock_responses_store): + """Test prepending a web search previous response to a new response.""" + input_item_message = OpenAIResponseMessage( + id="123", + 
content=[OpenAIResponseInputMessageContentText(text="fake_previous_input")], + role="user", + ) + output_web_search = OpenAIResponseOutputMessageWebSearchToolCall( + id="ws_123", + status="completed", + ) + output_message = OpenAIResponseMessage( + id="123", + content=[OpenAIResponseOutputMessageContentOutputText(text="fake_web_search_response")], + status="completed", + role="assistant", + ) + response = _OpenAIResponseObjectWithInputAndMessages( + created_at=1, + id="resp_123", + model="fake_model", + output=[output_web_search, output_message], + status="completed", + text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")), + input=[input_item_message], + messages=[OpenAIUserMessageParam(content="test input")], + ) + mock_responses_store.get_response_object.return_value = response + + input_messages = [OpenAIResponseMessage(content="fake_input", role="user")] + input = await openai_responses_impl._prepend_previous_response(input_messages, response) + + assert len(input) == 4 + # Check for previous input + assert isinstance(input[0], OpenAIResponseMessage) + assert input[0].content[0].text == "fake_previous_input" + # Check for previous output web search tool call + assert isinstance(input[1], OpenAIResponseOutputMessageWebSearchToolCall) + # Check for previous output web search response + assert isinstance(input[2], OpenAIResponseMessage) + assert input[2].content[0].text == "fake_web_search_response" + # Check for new input + assert isinstance(input[3], OpenAIResponseMessage) + assert input[3].content == "fake_input" + + +async def test_prepend_previous_response_mcp_tool_call(openai_responses_impl, mock_responses_store): + """Test prepending a previous response which included an mcp tool call to a new response.""" + input_item_message = OpenAIResponseMessage( + id="123", + content=[OpenAIResponseInputMessageContentText(text="fake_previous_input")], + role="user", + ) + output_tool_call = OpenAIResponseOutputMessageMCPCall( + id="ws_123", + name="fake-tool", + arguments="fake-arguments", + server_label="fake-label", + ) + output_message = OpenAIResponseMessage( + id="123", + content=[OpenAIResponseOutputMessageContentOutputText(text="fake_tool_call_response")], + status="completed", + role="assistant", + ) + response = _OpenAIResponseObjectWithInputAndMessages( + created_at=1, + id="resp_123", + model="fake_model", + output=[output_tool_call, output_message], + status="completed", + text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")), + input=[input_item_message], + messages=[OpenAIUserMessageParam(content="test input")], + ) + mock_responses_store.get_response_object.return_value = response + + input_messages = [OpenAIResponseMessage(content="fake_input", role="user")] + input = await openai_responses_impl._prepend_previous_response(input_messages, response) + + assert len(input) == 4 + # Check for previous input + assert isinstance(input[0], OpenAIResponseMessage) + assert input[0].content[0].text == "fake_previous_input" + # Check for previous output MCP tool call + assert isinstance(input[1], OpenAIResponseOutputMessageMCPCall) + # Check for previous output web search response + assert isinstance(input[2], OpenAIResponseMessage) + assert input[2].content[0].text == "fake_tool_call_response" + # Check for new input + assert isinstance(input[3], OpenAIResponseMessage) + assert input[3].content == "fake_input" + + +async def test_create_openai_response_with_instructions(openai_responses_impl, mock_inference_api): + # Setup + input_text = "What is the 
capital of Ireland?" + model = "meta-llama/Llama-3.1-8B-Instruct" + instructions = "You are a geography expert. Provide concise answers." + + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Execute + await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + instructions=instructions, + ) + + # Verify + mock_inference_api.openai_chat_completion.assert_called_once() + call_args = mock_inference_api.openai_chat_completion.call_args + params = call_args.args[0] + sent_messages = params.messages + + # Check that instructions were prepended as a system message + assert len(sent_messages) == 2 + assert sent_messages[0].role == "system" + assert sent_messages[0].content == instructions + assert sent_messages[1].role == "user" + assert sent_messages[1].content == input_text + + +async def test_create_openai_response_with_instructions_and_multiple_messages( + openai_responses_impl, mock_inference_api +): + # Setup + input_messages = [ + OpenAIResponseMessage(role="user", content="Name some towns in Ireland", name=None), + OpenAIResponseMessage( + role="assistant", + content="Galway, Longford, Sligo", + name=None, + ), + OpenAIResponseMessage(role="user", content="Which is the largest?", name=None), + ] + model = "meta-llama/Llama-3.1-8B-Instruct" + instructions = "You are a geography expert. Provide concise answers." + + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Execute + await openai_responses_impl.create_openai_response( + input=input_messages, + model=model, + instructions=instructions, + ) + + # Verify + mock_inference_api.openai_chat_completion.assert_called_once() + call_args = mock_inference_api.openai_chat_completion.call_args + params = call_args.args[0] + sent_messages = params.messages + + # Check that instructions were prepended as a system message + assert len(sent_messages) == 4 # 1 system + 3 input messages + assert sent_messages[0].role == "system" + assert sent_messages[0].content == instructions + + # Check the rest of the messages were converted correctly + assert sent_messages[1].role == "user" + assert sent_messages[1].content == "Name some towns in Ireland" + assert sent_messages[2].role == "assistant" + assert sent_messages[2].content == "Galway, Longford, Sligo" + assert sent_messages[3].role == "user" + assert sent_messages[3].content == "Which is the largest?" + + +async def test_create_openai_response_with_instructions_and_previous_response( + openai_responses_impl, mock_responses_store, mock_inference_api +): + """Test prepending both instructions and previous response.""" + + input_item_message = OpenAIResponseMessage( + id="123", + content="Name some towns in Ireland", + role="user", + ) + response_output_message = OpenAIResponseMessage( + id="123", + content="Galway, Longford, Sligo", + status="completed", + role="assistant", + ) + response = _OpenAIResponseObjectWithInputAndMessages( + created_at=1, + id="resp_123", + model="fake_model", + output=[response_output_message], + status="completed", + text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")), + input=[input_item_message], + messages=[ + OpenAIUserMessageParam(content="Name some towns in Ireland"), + OpenAIAssistantMessageParam(content="Galway, Longford, Sligo"), + ], + ) + mock_responses_store.get_response_object.return_value = response + + model = "meta-llama/Llama-3.1-8B-Instruct" + instructions = "You are a geography expert. Provide concise answers." 
+ + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Execute + await openai_responses_impl.create_openai_response( + input="Which is the largest?", model=model, instructions=instructions, previous_response_id="123" + ) + + # Verify + mock_inference_api.openai_chat_completion.assert_called_once() + call_args = mock_inference_api.openai_chat_completion.call_args + params = call_args.args[0] + sent_messages = params.messages + + # Check that instructions were prepended as a system message + assert len(sent_messages) == 4, sent_messages + assert sent_messages[0].role == "system" + assert sent_messages[0].content == instructions + + # Check the rest of the messages were converted correctly + assert sent_messages[1].role == "user" + assert sent_messages[1].content == "Name some towns in Ireland" + assert sent_messages[2].role == "assistant" + assert sent_messages[2].content == "Galway, Longford, Sligo" + assert sent_messages[3].role == "user" + assert sent_messages[3].content == "Which is the largest?" + + +async def test_create_openai_response_with_previous_response_instructions( + openai_responses_impl, mock_responses_store, mock_inference_api +): + """Test prepending instructions and previous response with instructions.""" + + input_item_message = OpenAIResponseMessage( + id="123", + content="Name some towns in Ireland", + role="user", + ) + response_output_message = OpenAIResponseMessage( + id="123", + content="Galway, Longford, Sligo", + status="completed", + role="assistant", + ) + response = _OpenAIResponseObjectWithInputAndMessages( + created_at=1, + id="resp_123", + model="fake_model", + output=[response_output_message], + status="completed", + text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")), + input=[input_item_message], + messages=[ + OpenAIUserMessageParam(content="Name some towns in Ireland"), + OpenAIAssistantMessageParam(content="Galway, Longford, Sligo"), + ], + instructions="You are a helpful assistant.", + ) + mock_responses_store.get_response_object.return_value = response + + model = "meta-llama/Llama-3.1-8B-Instruct" + instructions = "You are a geography expert. Provide concise answers." + + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Execute + await openai_responses_impl.create_openai_response( + input="Which is the largest?", model=model, instructions=instructions, previous_response_id="123" + ) + + # Verify + mock_inference_api.openai_chat_completion.assert_called_once() + call_args = mock_inference_api.openai_chat_completion.call_args + params = call_args.args[0] + sent_messages = params.messages + + # Check that instructions were prepended as a system message + # and that the previous response instructions were not carried over + assert len(sent_messages) == 4, sent_messages + assert sent_messages[0].role == "system" + assert sent_messages[0].content == instructions + + # Check the rest of the messages were converted correctly + assert sent_messages[1].role == "user" + assert sent_messages[1].content == "Name some towns in Ireland" + assert sent_messages[2].role == "assistant" + assert sent_messages[2].content == "Galway, Longford, Sligo" + assert sent_messages[3].role == "user" + assert sent_messages[3].content == "Which is the largest?" 
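# --- Illustrative aside (not part of the patch) -------------------------------
# A minimal sketch of the behaviour the instruction tests above assert on:
# `instructions` is expected to surface as a single system message prepended to
# the converted input messages, and instructions stored on a previous response
# are not carried over. `sketch_prepend_instructions` is a hypothetical helper
# written only for illustration; it is not the OpenAIResponsesImpl code, and it
# assumes OpenAISystemMessageParam defaults its role to "system" like the user
# and assistant params used in these tests.
from llama_stack.apis.inference import OpenAISystemMessageParam, OpenAIUserMessageParam


def sketch_prepend_instructions(instructions, chat_messages):
    """Hypothetical helper: prepend instructions as a system message, if given."""
    if instructions:
        return [OpenAISystemMessageParam(content=instructions), *chat_messages]
    return list(chat_messages)


# Mirrors test_create_openai_response_with_instructions: one system + one user message.
_sketch_messages = sketch_prepend_instructions(
    "You are a geography expert. Provide concise answers.",
    [OpenAIUserMessageParam(content="What is the capital of Ireland?")],
)
assert [m.role for m in _sketch_messages] == ["system", "user"]
# -------------------------------------------------------------------------------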
+ + +async def test_list_openai_response_input_items_delegation(openai_responses_impl, mock_responses_store): + """Test that list_openai_response_input_items properly delegates to responses_store with correct parameters.""" + # Setup + response_id = "resp_123" + after = "msg_after" + before = "msg_before" + include = ["metadata"] + limit = 5 + order = Order.asc + + input_message = OpenAIResponseMessage( + id="msg_123", + content="Test message", + role="user", + ) + + expected_result = ListOpenAIResponseInputItem(data=[input_message]) + mock_responses_store.list_response_input_items.return_value = expected_result + + # Execute with all parameters to test delegation + result = await openai_responses_impl.list_openai_response_input_items( + response_id, after=after, before=before, include=include, limit=limit, order=order + ) + + # Verify all parameters are passed through correctly to the store + mock_responses_store.list_response_input_items.assert_called_once_with( + response_id, after, before, include, limit, order + ) + + # Verify the result is returned as-is from the store + assert result.object == "list" + assert len(result.data) == 1 + assert result.data[0].id == "msg_123" + + +async def test_responses_store_list_input_items_logic(): + """Test ResponsesStore list_response_input_items logic - mocks get_response_object to test actual ordering/limiting.""" + + # Create mock store and response store + mock_sql_store = AsyncMock() + backend_name = "sql_responses_test" + register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path="mock_db_path")}) + responses_store = ResponsesStore( + ResponsesStoreReference(backend=backend_name, table_name="responses"), policy=default_policy() + ) + responses_store.sql_store = mock_sql_store + + # Setup test data - multiple input items + input_items = [ + OpenAIResponseMessage(id="msg_1", content="First message", role="user"), + OpenAIResponseMessage(id="msg_2", content="Second message", role="user"), + OpenAIResponseMessage(id="msg_3", content="Third message", role="user"), + OpenAIResponseMessage(id="msg_4", content="Fourth message", role="user"), + ] + + response_with_input = _OpenAIResponseObjectWithInputAndMessages( + id="resp_123", + model="test_model", + created_at=1234567890, + object="response", + status="completed", + output=[], + text=OpenAIResponseText(format=(OpenAIResponseTextFormat(type="text"))), + input=input_items, + messages=[OpenAIUserMessageParam(content="First message")], + ) + + # Mock the get_response_object method to return our test data + mock_sql_store.fetch_one.return_value = {"response_object": response_with_input.model_dump()} + + # Test 1: Default behavior (no limit, desc order) + result = await responses_store.list_response_input_items("resp_123") + assert result.object == "list" + assert len(result.data) == 4 + # Should be reversed for desc order + assert result.data[0].id == "msg_4" + assert result.data[1].id == "msg_3" + assert result.data[2].id == "msg_2" + assert result.data[3].id == "msg_1" + + # Test 2: With limit=2, desc order + result = await responses_store.list_response_input_items("resp_123", limit=2, order=Order.desc) + assert result.object == "list" + assert len(result.data) == 2 + # Should be first 2 items in desc order + assert result.data[0].id == "msg_4" + assert result.data[1].id == "msg_3" + + # Test 3: With limit=2, asc order + result = await responses_store.list_response_input_items("resp_123", limit=2, order=Order.asc) + assert result.object == "list" + assert len(result.data) == 2 + # Should 
be first 2 items in original order (asc) + assert result.data[0].id == "msg_1" + assert result.data[1].id == "msg_2" + + # Test 4: Asc order without limit + result = await responses_store.list_response_input_items("resp_123", order=Order.asc) + assert result.object == "list" + assert len(result.data) == 4 + # Should be in original order (asc) + assert result.data[0].id == "msg_1" + assert result.data[1].id == "msg_2" + assert result.data[2].id == "msg_3" + assert result.data[3].id == "msg_4" + + # Test 5: Large limit (larger than available items) + result = await responses_store.list_response_input_items("resp_123", limit=10, order=Order.desc) + assert result.object == "list" + assert len(result.data) == 4 # Should return all available items + assert result.data[0].id == "msg_4" + + # Test 6: Zero limit edge case + result = await responses_store.list_response_input_items("resp_123", limit=0, order=Order.asc) + assert result.object == "list" + assert len(result.data) == 0 # Should return no items + + +async def test_store_response_uses_rehydrated_input_with_previous_response( + openai_responses_impl, mock_responses_store, mock_inference_api +): + """Test that _store_response uses the full re-hydrated input (including previous responses) + rather than just the original input when previous_response_id is provided.""" + + # Setup - Create a previous response that should be included in the stored input + previous_response = _OpenAIResponseObjectWithInputAndMessages( + id="resp-previous-123", + object="response", + created_at=1234567890, + model="meta-llama/Llama-3.1-8B-Instruct", + status="completed", + text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")), + input=[ + OpenAIResponseMessage( + id="msg-prev-user", role="user", content=[OpenAIResponseInputMessageContentText(text="What is 2+2?")] + ) + ], + output=[ + OpenAIResponseMessage( + id="msg-prev-assistant", + role="assistant", + content=[OpenAIResponseOutputMessageContentOutputText(text="2+2 equals 4.")], + ) + ], + messages=[ + OpenAIUserMessageParam(content="What is 2+2?"), + OpenAIAssistantMessageParam(content="2+2 equals 4."), + ], + ) + + mock_responses_store.get_response_object.return_value = previous_response + + current_input = "Now what is 3+3?" + model = "meta-llama/Llama-3.1-8B-Instruct" + + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Execute - Create response with previous_response_id + result = await openai_responses_impl.create_openai_response( + input=current_input, + model=model, + previous_response_id="resp-previous-123", + store=True, + ) + + store_call_args = mock_responses_store.store_response_object.call_args + stored_input = store_call_args.kwargs["input"] + + # Verify that the stored input contains the full re-hydrated conversation: + # 1. Previous user message + # 2. Previous assistant response + # 3. Current user message + assert len(stored_input) == 3 + + assert stored_input[0].role == "user" + assert stored_input[0].content[0].text == "What is 2+2?" + + assert stored_input[1].role == "assistant" + assert stored_input[1].content[0].text == "2+2 equals 4." + + assert stored_input[2].role == "user" + assert stored_input[2].content == "Now what is 3+3?" 
+ + # Verify the response itself is correct + assert result.model == model + assert result.status == "completed" + + +@patch("llama_stack.providers.utils.tools.mcp.list_mcp_tools") +async def test_reuse_mcp_tool_list( + mock_list_mcp_tools, openai_responses_impl, mock_responses_store, mock_inference_api +): + """Test that mcp_list_tools can be reused where appropriate.""" + + mock_inference_api.openai_chat_completion.return_value = fake_stream() + mock_list_mcp_tools.return_value = ListToolDefsResponse( + data=[ToolDef(name="test_tool", description="a test tool", input_schema={}, output_schema={})] + ) + + res1 = await openai_responses_impl.create_openai_response( + input="What is 2+2?", + model="meta-llama/Llama-3.1-8B-Instruct", + store=True, + tools=[ + OpenAIResponseInputToolFunction(name="fake", parameters=None), + OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"), + ], + ) + args = mock_responses_store.store_response_object.call_args + data = args.kwargs["response_object"].model_dump() + data["input"] = [input_item.model_dump() for input_item in args.kwargs["input"]] + data["messages"] = [msg.model_dump() for msg in args.kwargs["messages"]] + stored = _OpenAIResponseObjectWithInputAndMessages(**data) + mock_responses_store.get_response_object.return_value = stored + + res2 = await openai_responses_impl.create_openai_response( + previous_response_id=res1.id, + input="Now what is 3+3?", + model="meta-llama/Llama-3.1-8B-Instruct", + store=True, + tools=[ + OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"), + ], + ) + assert len(mock_inference_api.openai_chat_completion.call_args_list) == 2 + second_call = mock_inference_api.openai_chat_completion.call_args_list[1] + second_params = second_call.args[0] + tools_seen = second_params.tools + assert len(tools_seen) == 1 + assert tools_seen[0]["function"]["name"] == "test_tool" + assert tools_seen[0]["function"]["description"] == "a test tool" + + assert mock_list_mcp_tools.call_count == 1 + listings = [obj for obj in res2.output if obj.type == "mcp_list_tools"] + assert len(listings) == 1 + assert listings[0].server_label == "alabel" + assert len(listings[0].tools) == 1 + assert listings[0].tools[0].name == "test_tool" + + +@pytest.mark.parametrize( + "text_format, response_format", + [ + (OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")), None), + ( + OpenAIResponseText(format=OpenAIResponseTextFormat(name="Test", schema={"foo": "bar"}, type="json_schema")), + OpenAIResponseFormatJSONSchema(json_schema=OpenAIJSONSchema(name="Test", schema={"foo": "bar"})), + ), + (OpenAIResponseText(format=OpenAIResponseTextFormat(type="json_object")), OpenAIResponseFormatJSONObject()), + # ensure text param with no format specified defaults to None + (OpenAIResponseText(format=None), None), + # ensure text param of None defaults to None + (None, None), + ], +) +async def test_create_openai_response_with_text_format( + openai_responses_impl, mock_inference_api, text_format, response_format +): + """Test creating Responses with text formats.""" + # Setup + input_text = "How hot it is in San Francisco today?" 
+ model = "meta-llama/Llama-3.1-8B-Instruct" + + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Execute + _result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + text=text_format, + ) + + # Verify + first_call = mock_inference_api.openai_chat_completion.call_args_list[0] + first_params = first_call.args[0] + assert first_params.messages[0].content == input_text + assert first_params.response_format == response_format + + +async def test_create_openai_response_with_invalid_text_format(openai_responses_impl, mock_inference_api): + """Test creating an OpenAI response with an invalid text format.""" + # Setup + input_text = "How hot it is in San Francisco today?" + model = "meta-llama/Llama-3.1-8B-Instruct" + + # Execute + with pytest.raises(ValueError): + _result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + text=OpenAIResponseText(format={"type": "invalid"}), + ) + + +async def test_create_openai_response_with_output_types_as_input( + openai_responses_impl, mock_inference_api, mock_responses_store +): + """Test that response outputs can be used as inputs in multi-turn conversations. + + Before adding OpenAIResponseOutput types to OpenAIResponseInput, + creating a _OpenAIResponseObjectWithInputAndMessages with some output types + in the input field would fail with a Pydantic ValidationError. + + This test simulates storing a response where the input contains output message + types (MCP calls, function calls), which happens in multi-turn conversations. + """ + model = "meta-llama/Llama-3.1-8B-Instruct" + + # Mock the inference response + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Create a response with store=True to trigger the storage path + result = await openai_responses_impl.create_openai_response( + input="What's the weather?", + model=model, + stream=True, + temperature=0.1, + store=True, + ) + + # Consume the stream + _ = [chunk async for chunk in result] + + # Verify store was called + assert mock_responses_store.store_response_object.called + + # Get the stored data + store_call_args = mock_responses_store.store_response_object.call_args + stored_response = store_call_args.kwargs["response_object"] + + # Now simulate a multi-turn conversation where outputs become inputs + input_with_output_types = [ + OpenAIResponseMessage(role="user", content="What's the weather?", name=None), + # These output types need to be valid OpenAIResponseInput + OpenAIResponseOutputMessageFunctionToolCall( + call_id="call_123", + name="get_weather", + arguments='{"city": "Tokyo"}', + type="function_call", + ), + OpenAIResponseOutputMessageMCPCall( + id="mcp_456", + type="mcp_call", + server_label="weather_server", + name="get_temperature", + arguments='{"location": "Tokyo"}', + output="25°C", + ), + ] + + # This simulates storing a response in a multi-turn conversation + # where previous outputs are included in the input. 
+ stored_with_outputs = _OpenAIResponseObjectWithInputAndMessages( + id=stored_response.id, + created_at=stored_response.created_at, + model=stored_response.model, + status=stored_response.status, + output=stored_response.output, + input=input_with_output_types, # This will trigger Pydantic validation + messages=None, + ) + + assert stored_with_outputs.input == input_with_output_types + assert len(stored_with_outputs.input) == 3 diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses_conversations.py b/tests/unit/providers/agents/meta_reference/test_openai_responses_conversations.py new file mode 100644 index 000000000..c2c113c1b --- /dev/null +++ b/tests/unit/providers/agents/meta_reference/test_openai_responses_conversations.py @@ -0,0 +1,249 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +import pytest + +from llama_stack.apis.agents.openai_responses import ( + OpenAIResponseMessage, + OpenAIResponseObject, + OpenAIResponseObjectStreamResponseCompleted, + OpenAIResponseObjectStreamResponseOutputItemDone, + OpenAIResponseOutputMessageContentOutputText, +) +from llama_stack.apis.common.errors import ( + ConversationNotFoundError, + InvalidConversationIdError, +) +from llama_stack.apis.conversations.conversations import ( + ConversationItemList, +) + +# Import existing fixtures from the main responses test file +pytest_plugins = ["tests.unit.providers.agents.meta_reference.test_openai_responses"] + +from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import ( + OpenAIResponsesImpl, +) + + +@pytest.fixture +def responses_impl_with_conversations( + mock_inference_api, + mock_tool_groups_api, + mock_tool_runtime_api, + mock_responses_store, + mock_vector_io_api, + mock_conversations_api, + mock_safety_api, +): + """Create OpenAIResponsesImpl instance with conversations API.""" + return OpenAIResponsesImpl( + inference_api=mock_inference_api, + tool_groups_api=mock_tool_groups_api, + tool_runtime_api=mock_tool_runtime_api, + responses_store=mock_responses_store, + vector_io_api=mock_vector_io_api, + conversations_api=mock_conversations_api, + safety_api=mock_safety_api, + ) + + +class TestConversationValidation: + """Test conversation ID validation logic.""" + + async def test_nonexistent_conversation_raises_error( + self, responses_impl_with_conversations, mock_conversations_api + ): + """Test that ConversationNotFoundError is raised for non-existent conversation.""" + conv_id = "conv_nonexistent" + + # Mock conversation not found + mock_conversations_api.list_items.side_effect = ConversationNotFoundError("conv_nonexistent") + + with pytest.raises(ConversationNotFoundError): + await responses_impl_with_conversations.create_openai_response( + input="Hello", model="test-model", conversation=conv_id, stream=False + ) + + +class TestMessageSyncing: + """Test message syncing to conversations.""" + + async def test_sync_response_to_conversation_simple( + self, responses_impl_with_conversations, mock_conversations_api + ): + """Test syncing simple response to conversation.""" + conv_id = "conv_test123" + input_text = "What are the 5 Ds of dodgeball?" 
+ + # Output items (what the model generated) + output_items = [ + OpenAIResponseMessage( + id="msg_response", + content=[ + OpenAIResponseOutputMessageContentOutputText( + text="The 5 Ds are: Dodge, Duck, Dip, Dive, and Dodge.", type="output_text", annotations=[] + ) + ], + role="assistant", + status="completed", + type="message", + ) + ] + + await responses_impl_with_conversations._sync_response_to_conversation(conv_id, input_text, output_items) + + # should call add_items with user input and assistant response + mock_conversations_api.add_items.assert_called_once() + call_args = mock_conversations_api.add_items.call_args + + assert call_args[0][0] == conv_id # conversation_id + items = call_args[0][1] # conversation_items + + assert len(items) == 2 + # User message + assert items[0].type == "message" + assert items[0].role == "user" + assert items[0].content[0].type == "input_text" + assert items[0].content[0].text == input_text + + # Assistant message + assert items[1].type == "message" + assert items[1].role == "assistant" + + async def test_sync_response_to_conversation_api_error( + self, responses_impl_with_conversations, mock_conversations_api + ): + mock_conversations_api.add_items.side_effect = Exception("API Error") + output_items = [] + + # matching the behavior of OpenAI here + with pytest.raises(Exception, match="API Error"): + await responses_impl_with_conversations._sync_response_to_conversation( + "conv_test123", "Hello", output_items + ) + + async def test_sync_with_list_input(self, responses_impl_with_conversations, mock_conversations_api): + """Test syncing with list of input messages.""" + conv_id = "conv_test123" + input_messages = [ + OpenAIResponseMessage(role="user", content=[{"type": "input_text", "text": "First message"}]), + ] + output_items = [ + OpenAIResponseMessage( + id="msg_response", + content=[OpenAIResponseOutputMessageContentOutputText(text="Response", type="output_text")], + role="assistant", + status="completed", + type="message", + ) + ] + + await responses_impl_with_conversations._sync_response_to_conversation(conv_id, input_messages, output_items) + + mock_conversations_api.add_items.assert_called_once() + call_args = mock_conversations_api.add_items.call_args + + items = call_args[0][1] + # Should have input message + output message + assert len(items) == 2 + + +class TestIntegrationWorkflow: + """Integration tests for the full conversation workflow.""" + + async def test_create_response_with_valid_conversation( + self, responses_impl_with_conversations, mock_conversations_api + ): + """Test creating a response with a valid conversation parameter.""" + mock_conversations_api.list_items.return_value = ConversationItemList( + data=[], first_id=None, has_more=False, last_id=None, object="list" + ) + + async def mock_streaming_response(*args, **kwargs): + message_item = OpenAIResponseMessage( + id="msg_response", + content=[ + OpenAIResponseOutputMessageContentOutputText( + text="Test response", type="output_text", annotations=[] + ) + ], + role="assistant", + status="completed", + type="message", + ) + + # Emit output_item.done event first (needed for conversation sync) + yield OpenAIResponseObjectStreamResponseOutputItemDone( + response_id="resp_test123", + item=message_item, + output_index=0, + sequence_number=1, + type="response.output_item.done", + ) + + # Then emit response.completed + mock_response = OpenAIResponseObject( + id="resp_test123", + created_at=1234567890, + model="test-model", + object="response", + output=[message_item], + 
status="completed", + ) + + yield OpenAIResponseObjectStreamResponseCompleted(response=mock_response, type="response.completed") + + responses_impl_with_conversations._create_streaming_response = mock_streaming_response + + input_text = "Hello, how are you?" + conversation_id = "conv_test123" + + response = await responses_impl_with_conversations.create_openai_response( + input=input_text, model="test-model", conversation=conversation_id, stream=False + ) + + assert response is not None + assert response.id == "resp_test123" + + # Note: conversation sync happens inside _create_streaming_response, + # which we're mocking here, so we can't test it in this unit test. + # The sync logic is tested separately in TestMessageSyncing. + + async def test_create_response_with_invalid_conversation_id(self, responses_impl_with_conversations): + """Test creating a response with an invalid conversation ID.""" + with pytest.raises(InvalidConversationIdError) as exc_info: + await responses_impl_with_conversations.create_openai_response( + input="Hello", model="test-model", conversation="invalid_id", stream=False + ) + + assert "Expected an ID that begins with 'conv_'" in str(exc_info.value) + + async def test_create_response_with_nonexistent_conversation( + self, responses_impl_with_conversations, mock_conversations_api + ): + """Test creating a response with a non-existent conversation.""" + mock_conversations_api.list_items.side_effect = ConversationNotFoundError("conv_nonexistent") + + with pytest.raises(ConversationNotFoundError) as exc_info: + await responses_impl_with_conversations.create_openai_response( + input="Hello", model="test-model", conversation="conv_nonexistent", stream=False + ) + + assert "not found" in str(exc_info.value) + + async def test_conversation_and_previous_response_id( + self, responses_impl_with_conversations, mock_conversations_api, mock_responses_store + ): + with pytest.raises(ValueError) as exc_info: + await responses_impl_with_conversations.create_openai_response( + input="test", model="test", conversation="conv_123", previous_response_id="resp_123" + ) + + assert "Mutually exclusive parameters" in str(exc_info.value) + assert "previous_response_id" in str(exc_info.value) + assert "conversation" in str(exc_info.value) diff --git a/tests/unit/providers/agents/meta_reference/test_response_conversion_utils.py b/tests/unit/providers/agents/meta_reference/test_response_conversion_utils.py new file mode 100644 index 000000000..2698b88c8 --- /dev/null +++ b/tests/unit/providers/agents/meta_reference/test_response_conversion_utils.py @@ -0,0 +1,367 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ + +import pytest + +from llama_stack.apis.agents.openai_responses import ( + OpenAIResponseAnnotationFileCitation, + OpenAIResponseInputFunctionToolCallOutput, + OpenAIResponseInputMessageContentImage, + OpenAIResponseInputMessageContentText, + OpenAIResponseInputToolFunction, + OpenAIResponseInputToolWebSearch, + OpenAIResponseMessage, + OpenAIResponseOutputMessageContentOutputText, + OpenAIResponseOutputMessageFunctionToolCall, + OpenAIResponseText, + OpenAIResponseTextFormat, +) +from llama_stack.apis.inference import ( + OpenAIAssistantMessageParam, + OpenAIChatCompletionContentPartImageParam, + OpenAIChatCompletionContentPartTextParam, + OpenAIChatCompletionToolCall, + OpenAIChatCompletionToolCallFunction, + OpenAIChoice, + OpenAIDeveloperMessageParam, + OpenAIResponseFormatJSONObject, + OpenAIResponseFormatJSONSchema, + OpenAIResponseFormatText, + OpenAISystemMessageParam, + OpenAIToolMessageParam, + OpenAIUserMessageParam, +) +from llama_stack.providers.inline.agents.meta_reference.responses.utils import ( + _extract_citations_from_text, + convert_chat_choice_to_response_message, + convert_response_content_to_chat_content, + convert_response_input_to_chat_messages, + convert_response_text_to_chat_response_format, + get_message_type_by_role, + is_function_tool_call, +) + + +class TestConvertChatChoiceToResponseMessage: + async def test_convert_string_content(self): + choice = OpenAIChoice( + message=OpenAIAssistantMessageParam(content="Test message"), + finish_reason="stop", + index=0, + ) + + result = await convert_chat_choice_to_response_message(choice) + + assert result.role == "assistant" + assert result.status == "completed" + assert len(result.content) == 1 + assert isinstance(result.content[0], OpenAIResponseOutputMessageContentOutputText) + assert result.content[0].text == "Test message" + + async def test_convert_text_param_content(self): + choice = OpenAIChoice( + message=OpenAIAssistantMessageParam( + content=[OpenAIChatCompletionContentPartTextParam(text="Test text param")] + ), + finish_reason="stop", + index=0, + ) + + with pytest.raises(ValueError) as exc_info: + await convert_chat_choice_to_response_message(choice) + + assert "does not yet support output content type" in str(exc_info.value) + + +class TestConvertResponseContentToChatContent: + async def test_convert_string_content(self): + result = await convert_response_content_to_chat_content("Simple string") + assert result == "Simple string" + + async def test_convert_text_content_parts(self): + content = [ + OpenAIResponseInputMessageContentText(text="First part"), + OpenAIResponseOutputMessageContentOutputText(text="Second part"), + ] + + result = await convert_response_content_to_chat_content(content) + + assert len(result) == 2 + assert isinstance(result[0], OpenAIChatCompletionContentPartTextParam) + assert result[0].text == "First part" + assert isinstance(result[1], OpenAIChatCompletionContentPartTextParam) + assert result[1].text == "Second part" + + async def test_convert_image_content(self): + content = [OpenAIResponseInputMessageContentImage(image_url="https://example.com/image.jpg", detail="high")] + + result = await convert_response_content_to_chat_content(content) + + assert len(result) == 1 + assert isinstance(result[0], OpenAIChatCompletionContentPartImageParam) + assert result[0].image_url.url == "https://example.com/image.jpg" + assert result[0].image_url.detail == "high" + + +class TestConvertResponseInputToChatMessages: + async def test_convert_string_input(self): + result = await 
+
+        assert len(result) == 1
+        assert isinstance(result[0], OpenAIUserMessageParam)
+        assert result[0].content == "User message"
+
+    async def test_convert_function_tool_call_output(self):
+        input_items = [
+            OpenAIResponseOutputMessageFunctionToolCall(
+                call_id="call_123",
+                name="test_function",
+                arguments='{"param": "value"}',
+            ),
+            OpenAIResponseInputFunctionToolCallOutput(
+                output="Tool output",
+                call_id="call_123",
+            ),
+        ]
+
+        result = await convert_response_input_to_chat_messages(input_items)
+
+        assert len(result) == 2
+        assert isinstance(result[0], OpenAIAssistantMessageParam)
+        assert result[0].tool_calls[0].id == "call_123"
+        assert result[0].tool_calls[0].function.name == "test_function"
+        assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
+        assert isinstance(result[1], OpenAIToolMessageParam)
+        assert result[1].content == "Tool output"
+        assert result[1].tool_call_id == "call_123"
+
+    async def test_convert_function_tool_call(self):
+        input_items = [
+            OpenAIResponseOutputMessageFunctionToolCall(
+                call_id="call_456",
+                name="test_function",
+                arguments='{"param": "value"}',
+            )
+        ]
+
+        result = await convert_response_input_to_chat_messages(input_items)
+
+        assert len(result) == 1
+        assert isinstance(result[0], OpenAIAssistantMessageParam)
+        assert len(result[0].tool_calls) == 1
+        assert result[0].tool_calls[0].id == "call_456"
+        assert result[0].tool_calls[0].function.name == "test_function"
+        assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
+
+    async def test_convert_function_call_ordering(self):
+        input_items = [
+            OpenAIResponseOutputMessageFunctionToolCall(
+                call_id="call_123",
+                name="test_function_a",
+                arguments='{"param": "value"}',
+            ),
+            OpenAIResponseOutputMessageFunctionToolCall(
+                call_id="call_456",
+                name="test_function_b",
+                arguments='{"param": "value"}',
+            ),
+            OpenAIResponseInputFunctionToolCallOutput(
+                output="AAA",
+                call_id="call_123",
+            ),
+            OpenAIResponseInputFunctionToolCallOutput(
+                output="BBB",
+                call_id="call_456",
+            ),
+        ]
+
+        result = await convert_response_input_to_chat_messages(input_items)
+        assert len(result) == 4
+        assert isinstance(result[0], OpenAIAssistantMessageParam)
+        assert len(result[0].tool_calls) == 1
+        assert result[0].tool_calls[0].id == "call_123"
+        assert result[0].tool_calls[0].function.name == "test_function_a"
+        assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
+        assert isinstance(result[1], OpenAIToolMessageParam)
+        assert result[1].content == "AAA"
+        assert result[1].tool_call_id == "call_123"
+        assert isinstance(result[2], OpenAIAssistantMessageParam)
+        assert len(result[2].tool_calls) == 1
+        assert result[2].tool_calls[0].id == "call_456"
+        assert result[2].tool_calls[0].function.name == "test_function_b"
+        assert result[2].tool_calls[0].function.arguments == '{"param": "value"}'
+        assert isinstance(result[3], OpenAIToolMessageParam)
+        assert result[3].content == "BBB"
+        assert result[3].tool_call_id == "call_456"
+
+    async def test_convert_response_message(self):
+        input_items = [
+            OpenAIResponseMessage(
+                role="user",
+                content=[OpenAIResponseInputMessageContentText(text="User text")],
+            )
+        ]
+
+        result = await convert_response_input_to_chat_messages(input_items)
+
+        assert len(result) == 1
+        assert isinstance(result[0], OpenAIUserMessageParam)
+        # Content should be converted to chat content format
+        assert len(result[0].content) == 1
+        assert result[0].content[0].text == "User text"
+
+
+class TestConvertResponseTextToChatResponseFormat:
+    async def test_convert_text_format(self):
+        text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
+        result = await convert_response_text_to_chat_response_format(text)
+
+        assert isinstance(result, OpenAIResponseFormatText)
+        assert result.type == "text"
+
+    async def test_convert_json_object_format(self):
+        text = OpenAIResponseText(format={"type": "json_object"})
+        result = await convert_response_text_to_chat_response_format(text)
+
+        assert isinstance(result, OpenAIResponseFormatJSONObject)
+
+    async def test_convert_json_schema_format(self):
+        schema_def = {"type": "object", "properties": {"test": {"type": "string"}}}
+        text = OpenAIResponseText(
+            format={
+                "type": "json_schema",
+                "name": "test_schema",
+                "schema": schema_def,
+            }
+        )
+        result = await convert_response_text_to_chat_response_format(text)
+
+        assert isinstance(result, OpenAIResponseFormatJSONSchema)
+        assert result.json_schema["name"] == "test_schema"
+        assert result.json_schema["schema"] == schema_def
+
+    async def test_default_text_format(self):
+        text = OpenAIResponseText()
+        result = await convert_response_text_to_chat_response_format(text)
+
+        assert isinstance(result, OpenAIResponseFormatText)
+        assert result.type == "text"
+
+
+class TestGetMessageTypeByRole:
+    async def test_user_role(self):
+        result = await get_message_type_by_role("user")
+        assert result == OpenAIUserMessageParam
+
+    async def test_system_role(self):
+        result = await get_message_type_by_role("system")
+        assert result == OpenAISystemMessageParam
+
+    async def test_assistant_role(self):
+        result = await get_message_type_by_role("assistant")
+        assert result == OpenAIAssistantMessageParam
+
+    async def test_developer_role(self):
+        result = await get_message_type_by_role("developer")
+        assert result == OpenAIDeveloperMessageParam
+
+    async def test_unknown_role(self):
+        result = await get_message_type_by_role("unknown")
+        assert result is None
+
+
+class TestIsFunctionToolCall:
+    def test_is_function_tool_call_true(self):
+        tool_call = OpenAIChatCompletionToolCall(
+            index=0,
+            id="call_123",
+            function=OpenAIChatCompletionToolCallFunction(
+                name="test_function",
+                arguments="{}",
+            ),
+        )
+        tools = [
+            OpenAIResponseInputToolFunction(
+                type="function", name="test_function", parameters={"type": "object", "properties": {}}
+            ),
+            OpenAIResponseInputToolWebSearch(type="web_search"),
+        ]
+
+        result = is_function_tool_call(tool_call, tools)
+        assert result is True
+
+    def test_is_function_tool_call_false_different_name(self):
+        tool_call = OpenAIChatCompletionToolCall(
+            index=0,
+            id="call_123",
+            function=OpenAIChatCompletionToolCallFunction(
+                name="other_function",
+                arguments="{}",
+            ),
+        )
+        tools = [
+            OpenAIResponseInputToolFunction(
+                type="function", name="test_function", parameters={"type": "object", "properties": {}}
+            ),
+        ]
+
+        result = is_function_tool_call(tool_call, tools)
+        assert result is False
+
+    def test_is_function_tool_call_false_no_function(self):
+        tool_call = OpenAIChatCompletionToolCall(
+            index=0,
+            id="call_123",
+            function=None,
+        )
+        tools = [
+            OpenAIResponseInputToolFunction(
+                type="function", name="test_function", parameters={"type": "object", "properties": {}}
+            ),
+        ]
+
+        result = is_function_tool_call(tool_call, tools)
+        assert result is False
+
+    def test_is_function_tool_call_false_wrong_type(self):
+        tool_call = OpenAIChatCompletionToolCall(
+            index=0,
+            id="call_123",
+            function=OpenAIChatCompletionToolCallFunction(
name="web_search", + arguments="{}", + ), + ) + tools = [ + OpenAIResponseInputToolWebSearch(type="web_search"), + ] + + result = is_function_tool_call(tool_call, tools) + assert result is False + + +class TestExtractCitationsFromText: + def test_extract_citations_and_annotations(self): + text = "Start [not-a-file]. New source <|file-abc123|>. " + text += "Other source <|file-def456|>? Repeat source <|file-abc123|>! No citation." + file_mapping = {"file-abc123": "doc1.pdf", "file-def456": "doc2.txt"} + + annotations, cleaned_text = _extract_citations_from_text(text, file_mapping) + + expected_annotations = [ + OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=30), + OpenAIResponseAnnotationFileCitation(file_id="file-def456", filename="doc2.txt", index=44), + OpenAIResponseAnnotationFileCitation(file_id="file-abc123", filename="doc1.pdf", index=59), + ] + expected_clean_text = "Start [not-a-file]. New source. Other source? Repeat source! No citation." + + assert cleaned_text == expected_clean_text + assert annotations == expected_annotations + # OpenAI cites at the end of the sentence + assert cleaned_text[expected_annotations[0].index] == "." + assert cleaned_text[expected_annotations[1].index] == "?" + assert cleaned_text[expected_annotations[2].index] == "!" diff --git a/tests/unit/providers/agents/meta_reference/test_response_tool_context.py b/tests/unit/providers/agents/meta_reference/test_response_tool_context.py new file mode 100644 index 000000000..e966ad41e --- /dev/null +++ b/tests/unit/providers/agents/meta_reference/test_response_tool_context.py @@ -0,0 +1,183 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+
+from llama_stack.apis.agents.openai_responses import (
+    MCPListToolsTool,
+    OpenAIResponseInputToolFileSearch,
+    OpenAIResponseInputToolFunction,
+    OpenAIResponseInputToolMCP,
+    OpenAIResponseInputToolWebSearch,
+    OpenAIResponseObject,
+    OpenAIResponseOutputMessageMCPListTools,
+    OpenAIResponseToolMCP,
+)
+from llama_stack.providers.inline.agents.meta_reference.responses.types import ToolContext
+
+
+class TestToolContext:
+    def test_no_tools(self):
+        tools = []
+        context = ToolContext(tools)
+        previous_response = OpenAIResponseObject(created_at=1234, id="test", model="mymodel", output=[], status="")
+        context.recover_tools_from_previous_response(previous_response)
+
+        assert len(context.tools_to_process) == 0
+        assert len(context.previous_tools) == 0
+        assert len(context.previous_tool_listings) == 0
+
+    def test_no_previous_tools(self):
+        tools = [
+            OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
+            OpenAIResponseInputToolMCP(server_label="label", server_url="url"),
+        ]
+        context = ToolContext(tools)
+        previous_response = OpenAIResponseObject(created_at=1234, id="test", model="mymodel", output=[], status="")
+        context.recover_tools_from_previous_response(previous_response)
+
+        assert len(context.tools_to_process) == 2
+        assert len(context.previous_tools) == 0
+        assert len(context.previous_tool_listings) == 0
+
+    def test_reusable_server(self):
+        tools = [
+            OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
+            OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
+        ]
+        context = ToolContext(tools)
+        output = [
+            OpenAIResponseOutputMessageMCPListTools(
+                id="test", server_label="alabel", tools=[MCPListToolsTool(name="test_tool", input_schema={})]
+            )
+        ]
+        previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
+        previous_response.tools = [
+            OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
+            OpenAIResponseToolMCP(server_label="alabel"),
+        ]
+        context.recover_tools_from_previous_response(previous_response)
+
+        assert len(context.tools_to_process) == 1
+        assert context.tools_to_process[0].type == "file_search"
+        assert len(context.previous_tools) == 1
+        assert context.previous_tools["test_tool"].server_label == "alabel"
+        assert context.previous_tools["test_tool"].server_url == "aurl"
+        assert len(context.previous_tool_listings) == 1
+        assert len(context.previous_tool_listings[0].tools) == 1
+        assert context.previous_tool_listings[0].server_label == "alabel"
+
+    def test_multiple_reusable_servers(self):
+        tools = [
+            OpenAIResponseInputToolFunction(name="fake", parameters=None),
+            OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
+            OpenAIResponseInputToolWebSearch(),
+            OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
+        ]
+        context = ToolContext(tools)
+        output = [
+            OpenAIResponseOutputMessageMCPListTools(
+                id="test1", server_label="alabel", tools=[MCPListToolsTool(name="test_tool", input_schema={})]
+            ),
+            OpenAIResponseOutputMessageMCPListTools(
+                id="test2",
+                server_label="anotherlabel",
+                tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
+            ),
+        ]
+        previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
+        previous_response.tools = [
+            OpenAIResponseInputToolFunction(name="fake", parameters=None),
+            OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
+            OpenAIResponseInputToolWebSearch(type="web_search"),
OpenAIResponseToolMCP(server_label="alabel", server_url="aurl"), + ] + context.recover_tools_from_previous_response(previous_response) + + assert len(context.tools_to_process) == 2 + assert context.tools_to_process[0].type == "function" + assert context.tools_to_process[1].type == "web_search" + assert len(context.previous_tools) == 2 + assert context.previous_tools["test_tool"].server_label == "alabel" + assert context.previous_tools["test_tool"].server_url == "aurl" + assert context.previous_tools["some_other_tool"].server_label == "anotherlabel" + assert context.previous_tools["some_other_tool"].server_url == "anotherurl" + assert len(context.previous_tool_listings) == 2 + assert len(context.previous_tool_listings[0].tools) == 1 + assert context.previous_tool_listings[0].server_label == "alabel" + assert len(context.previous_tool_listings[1].tools) == 1 + assert context.previous_tool_listings[1].server_label == "anotherlabel" + + def test_multiple_servers_only_one_reusable(self): + tools = [ + OpenAIResponseInputToolFunction(name="fake", parameters=None), + OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"), + OpenAIResponseInputToolWebSearch(type="web_search"), + OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"), + ] + context = ToolContext(tools) + output = [ + OpenAIResponseOutputMessageMCPListTools( + id="test2", + server_label="anotherlabel", + tools=[MCPListToolsTool(name="some_other_tool", input_schema={})], + ) + ] + previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="") + previous_response.tools = [ + OpenAIResponseInputToolFunction(name="fake", parameters=None), + OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"), + OpenAIResponseInputToolWebSearch(type="web_search"), + ] + context.recover_tools_from_previous_response(previous_response) + + assert len(context.tools_to_process) == 3 + assert context.tools_to_process[0].type == "function" + assert context.tools_to_process[1].type == "web_search" + assert context.tools_to_process[2].type == "mcp" + assert len(context.previous_tools) == 1 + assert context.previous_tools["some_other_tool"].server_label == "anotherlabel" + assert context.previous_tools["some_other_tool"].server_url == "anotherurl" + assert len(context.previous_tool_listings) == 1 + assert len(context.previous_tool_listings[0].tools) == 1 + assert context.previous_tool_listings[0].server_label == "anotherlabel" + + def test_mismatched_allowed_tools(self): + tools = [ + OpenAIResponseInputToolFunction(name="fake", parameters=None), + OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"), + OpenAIResponseInputToolWebSearch(type="web_search"), + OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl", allowed_tools=["test_tool_2"]), + ] + context = ToolContext(tools) + output = [ + OpenAIResponseOutputMessageMCPListTools( + id="test1", server_label="alabel", tools=[MCPListToolsTool(name="test_tool_1", input_schema={})] + ), + OpenAIResponseOutputMessageMCPListTools( + id="test2", + server_label="anotherlabel", + tools=[MCPListToolsTool(name="some_other_tool", input_schema={})], + ), + ] + previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="") + previous_response.tools = [ + OpenAIResponseInputToolFunction(name="fake", parameters=None), + OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"), + 
OpenAIResponseInputToolWebSearch(type="web_search"), + OpenAIResponseToolMCP(server_label="alabel", server_url="aurl"), + ] + context.recover_tools_from_previous_response(previous_response) + + assert len(context.tools_to_process) == 3 + assert context.tools_to_process[0].type == "function" + assert context.tools_to_process[1].type == "web_search" + assert context.tools_to_process[2].type == "mcp" + assert len(context.previous_tools) == 1 + assert context.previous_tools["some_other_tool"].server_label == "anotherlabel" + assert context.previous_tools["some_other_tool"].server_url == "anotherurl" + assert len(context.previous_tool_listings) == 1 + assert len(context.previous_tool_listings[0].tools) == 1 + assert context.previous_tool_listings[0].server_label == "anotherlabel" diff --git a/tests/unit/providers/agents/meta_reference/test_responses_safety_utils.py b/tests/unit/providers/agents/meta_reference/test_responses_safety_utils.py new file mode 100644 index 000000000..9c5cc853c --- /dev/null +++ b/tests/unit/providers/agents/meta_reference/test_responses_safety_utils.py @@ -0,0 +1,155 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from unittest.mock import AsyncMock + +import pytest + +from llama_stack.apis.agents.agents import ResponseGuardrailSpec +from llama_stack.apis.safety import ModerationObject, ModerationObjectResults +from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import ( + OpenAIResponsesImpl, +) +from llama_stack.providers.inline.agents.meta_reference.responses.utils import ( + extract_guardrail_ids, + run_guardrails, +) + + +@pytest.fixture +def mock_apis(): + """Create mock APIs for testing.""" + return { + "inference_api": AsyncMock(), + "tool_groups_api": AsyncMock(), + "tool_runtime_api": AsyncMock(), + "responses_store": AsyncMock(), + "vector_io_api": AsyncMock(), + "conversations_api": AsyncMock(), + "safety_api": AsyncMock(), + } + + +@pytest.fixture +def responses_impl(mock_apis): + """Create OpenAIResponsesImpl instance with mocked dependencies.""" + return OpenAIResponsesImpl(**mock_apis) + + +def test_extract_guardrail_ids_from_strings(responses_impl): + """Test extraction from simple string guardrail IDs.""" + guardrails = ["llama-guard", "content-filter", "nsfw-detector"] + result = extract_guardrail_ids(guardrails) + assert result == ["llama-guard", "content-filter", "nsfw-detector"] + + +def test_extract_guardrail_ids_from_objects(responses_impl): + """Test extraction from ResponseGuardrailSpec objects.""" + guardrails = [ + ResponseGuardrailSpec(type="llama-guard"), + ResponseGuardrailSpec(type="content-filter"), + ] + result = extract_guardrail_ids(guardrails) + assert result == ["llama-guard", "content-filter"] + + +def test_extract_guardrail_ids_mixed_formats(responses_impl): + """Test extraction from mixed string and object formats.""" + guardrails = [ + "llama-guard", + ResponseGuardrailSpec(type="content-filter"), + "nsfw-detector", + ] + result = extract_guardrail_ids(guardrails) + assert result == ["llama-guard", "content-filter", "nsfw-detector"] + + +def test_extract_guardrail_ids_none_input(responses_impl): + """Test extraction with None input.""" + result = extract_guardrail_ids(None) + assert result == [] + + +def test_extract_guardrail_ids_empty_list(responses_impl): + """Test extraction with empty list.""" + result = extract_guardrail_ids([]) + 
+    assert result == []
+
+
+def test_extract_guardrail_ids_unknown_format(responses_impl):
+    """Test extraction with unknown guardrail format raises ValueError."""
+    # Create an object that's neither string nor ResponseGuardrailSpec
+    unknown_object = {"invalid": "format"}  # Plain dict, not ResponseGuardrailSpec
+    guardrails = ["valid-guardrail", unknown_object, "another-guardrail"]
+    with pytest.raises(ValueError, match="Unknown guardrail format.*expected str or ResponseGuardrailSpec"):
+        extract_guardrail_ids(guardrails)
+
+
+@pytest.fixture
+def mock_safety_api():
+    """Create mock safety API for guardrails testing."""
+    safety_api = AsyncMock()
+    # Mock the routing table and shields list for guardrails lookup
+    safety_api.routing_table = AsyncMock()
+    shield = AsyncMock()
+    shield.identifier = "llama-guard"
+    shield.provider_resource_id = "llama-guard-model"
+    safety_api.routing_table.list_shields.return_value = AsyncMock(data=[shield])
+    return safety_api
+
+
+async def test_run_guardrails_no_violation(mock_safety_api):
+    """Test guardrails validation with no violations."""
+    text = "Hello world"
+    guardrail_ids = ["llama-guard"]
+
+    # Mock moderation to return non-flagged content
+    unflagged_result = ModerationObjectResults(flagged=False, categories={"violence": False})
+    mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[unflagged_result])
+    mock_safety_api.run_moderation.return_value = mock_moderation_object
+
+    result = await run_guardrails(mock_safety_api, text, guardrail_ids)
+
+    assert result is None
+    # Verify run_moderation was called with the correct model
+    mock_safety_api.run_moderation.assert_called_once()
+    call_args = mock_safety_api.run_moderation.call_args
+    assert call_args[1]["model"] == "llama-guard-model"
+
+
+async def test_run_guardrails_with_violation(mock_safety_api):
+    """Test guardrails validation with safety violation."""
+    text = "Harmful content"
+    guardrail_ids = ["llama-guard"]
+
+    # Mock moderation to return flagged content
+    flagged_result = ModerationObjectResults(
+        flagged=True,
+        categories={"violence": True},
+        user_message="Content flagged by moderation",
+        metadata={"violation_type": ["S1"]},
+    )
+    mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[flagged_result])
+    mock_safety_api.run_moderation.return_value = mock_moderation_object
+
+    result = await run_guardrails(mock_safety_api, text, guardrail_ids)
+
+    assert result == "Content flagged by moderation (flagged for: violence) (violation type: S1)"
+
+
+async def test_run_guardrails_empty_inputs(mock_safety_api):
+    """Test guardrails validation with empty inputs."""
+    # Test empty guardrail_ids
+    result = await run_guardrails(mock_safety_api, "test", [])
+    assert result is None
+
+    # Test empty text
+    result = await run_guardrails(mock_safety_api, "", ["llama-guard"])
+    assert result is None
+
+    # Test both empty
+    result = await run_guardrails(mock_safety_api, "", [])
+    assert result is None