diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
index 1507a55c8..486ac9351 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
@@ -17,6 +17,8 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseOutputMessageContent,
     OpenAIResponseOutputMessageContentOutputText,
     OpenAIResponseOutputMessageFunctionToolCall,
+    OpenAIResponseOutputMessageMCPCall,
+    OpenAIResponseOutputMessageMCPListTools,
     OpenAIResponseText,
 )
 from llama_stack.apis.inference import (
@@ -117,6 +119,25 @@ async def convert_response_input_to_chat_messages(
                     ),
                 )
                 messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
+            elif isinstance(input_item, OpenAIResponseOutputMessageMCPCall):
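+                # Replay a prior MCP call as an assistant tool call followed by
+                # its tool result, the call/response pairing that chat
+                # completions expect.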
+                tool_call = OpenAIChatCompletionToolCall(
+                    index=0,
+                    id=input_item.id,
+                    function=OpenAIChatCompletionToolCallFunction(
+                        name=input_item.name,
+                        arguments=input_item.arguments,
+                    ),
+                )
+                messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
+                messages.append(
+                    OpenAIToolMessageParam(
+                        content=input_item.output,
+                        tool_call_id=input_item.id,
+                    )
+                )
+            elif isinstance(input_item, OpenAIResponseOutputMessageMCPListTools):
+                # the tool list will be handled separately
+                pass
             else:
                 content = await convert_response_content_to_chat_content(input_item.content)
                 message_type = await get_message_type_by_role(input_item.role)
diff --git a/tests/integration/non_ci/responses/test_basic_responses.py b/tests/integration/non_ci/responses/test_basic_responses.py
index a8106e593..17d50d348 100644
--- a/tests/integration/non_ci/responses/test_basic_responses.py
+++ b/tests/integration/non_ci/responses/test_basic_responses.py
@@ -7,8 +7,9 @@
 import time
 
 import pytest
-from fixtures.test_cases import basic_test_cases, image_test_cases, multi_turn_image_test_cases, multi_turn_test_cases
-from streaming_assertions import StreamingValidator
+
+from .fixtures.test_cases import basic_test_cases, image_test_cases, multi_turn_image_test_cases, multi_turn_test_cases
+from .streaming_assertions import StreamingValidator
 
 
 @pytest.mark.parametrize("case", basic_test_cases)
diff --git a/tests/integration/non_ci/responses/test_tool_responses.py b/tests/integration/non_ci/responses/test_tool_responses.py
index 33d109863..494b89226 100644
--- a/tests/integration/non_ci/responses/test_tool_responses.py
+++ b/tests/integration/non_ci/responses/test_tool_responses.py
@@ -10,7 +10,12 @@ import os
 import httpx
 import openai
 import pytest
-from fixtures.test_cases import (
+
+from llama_stack import LlamaStackAsLibraryClient
+from llama_stack.core.datatypes import AuthenticationRequiredError
+from tests.common.mcp import dependency_tools, make_mcp_server
+
+from .fixtures.test_cases import (
     custom_tool_test_cases,
     file_search_test_cases,
     mcp_tool_test_cases,
@@ -18,12 +23,8 @@ from fixtures.test_cases import (
     multi_turn_tool_execution_test_cases,
     web_search_test_cases,
 )
-from helpers import new_vector_store, setup_mcp_tools, upload_file, wait_for_file_attachment
-from streaming_assertions import StreamingValidator
-
-from llama_stack import LlamaStackAsLibraryClient
-from llama_stack.core.datatypes import AuthenticationRequiredError
-from tests.common.mcp import dependency_tools, make_mcp_server
+from .helpers import new_vector_store, setup_mcp_tools, upload_file, wait_for_file_attachment
+from .streaming_assertions import StreamingValidator
 
 
 @pytest.mark.parametrize("case", web_search_test_cases)
@@ -195,6 +196,56 @@ def test_response_non_streaming_mcp_tool(compat_client, text_model_id, case):
     assert len(response.output) >= 3
 
 
+@pytest.mark.parametrize("case", mcp_tool_test_cases)
+def test_response_sequential_mcp_tool(compat_client, text_model_id, case):
+    if not isinstance(compat_client, LlamaStackAsLibraryClient):
+        pytest.skip("in-process MCP server is only supported in library client")
+
+    with make_mcp_server() as mcp_server_info:
+        tools = setup_mcp_tools(case.tools, mcp_server_info)
+
+        response = compat_client.responses.create(
+            model=text_model_id,
+            input=case.input,
+            tools=tools,
+            stream=False,
+        )
+
+        assert len(response.output) >= 3
+        list_tools = response.output[0]
+        assert list_tools.type == "mcp_list_tools"
+        assert list_tools.server_label == "localmcp"
+        assert len(list_tools.tools) == 2
+        assert {t.name for t in list_tools.tools} == {
+            "get_boiling_point",
+            "greet_everyone",
+        }
+
+        call = response.output[1]
+        assert call.type == "mcp_call"
+        assert call.name == "get_boiling_point"
+        assert json.loads(call.arguments) == {
+            "liquid_name": "myawesomeliquid",
+            "celsius": True,
+        }
+        assert call.error is None
+        assert "-100" in call.output
+
+        # sometimes the model will call the tool again, so we need to get the last message
+        message = response.output[-1]
+        text_content = message.content[0].text
+        assert "boiling point" in text_content.lower()
+
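+        # Chain a second turn off the first via previous_response_id; this only
+        # works if the stored mcp_call and mcp_list_tools items convert back
+        # into chat messages.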
+        response2 = compat_client.responses.create(
+            model=text_model_id, input=case.input, tools=tools, stream=False, previous_response_id=response.id
+        )
+
+        assert len(response2.output) >= 1
+        message = response2.output[-1]
+        text_content = message.content[0].text
+        assert "boiling point" in text_content.lower()
+
+
 @pytest.mark.parametrize("case", custom_tool_test_cases)
 def test_response_non_streaming_custom_tool(compat_client, text_model_id, case):
     response = compat_client.responses.create(
diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py
index 5ea14d7c7..a964bc219 100644
--- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py
+++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py
@@ -24,6 +24,7 @@ from llama_stack.apis.agents.openai_responses import (
     OpenAIResponseMessage,
     OpenAIResponseObjectWithInput,
     OpenAIResponseOutputMessageContentOutputText,
+    OpenAIResponseOutputMessageMCPCall,
     OpenAIResponseOutputMessageWebSearchToolCall,
     OpenAIResponseText,
     OpenAIResponseTextFormat,
@@ -461,6 +462,53 @@ async def test_prepend_previous_response_web_search(openai_responses_impl, mock_
     assert input[3].content == "fake_input"
 
+
+async def test_prepend_previous_response_mcp_tool_call(openai_responses_impl, mock_responses_store):
+    """Test prepending a previous response which included an MCP tool call to a new response."""
+    input_item_message = OpenAIResponseMessage(
+        id="123",
+        content=[OpenAIResponseInputMessageContentText(text="fake_previous_input")],
+        role="user",
+    )
+    output_tool_call = OpenAIResponseOutputMessageMCPCall(
+        id="ws_123",
+        name="fake-tool",
+        arguments="fake-arguments",
+        server_label="fake-label",
+    )
+    output_message = OpenAIResponseMessage(
+        id="123",
+        content=[OpenAIResponseOutputMessageContentOutputText(text="fake_tool_call_response")],
+        status="completed",
+        role="assistant",
+    )
+    response = OpenAIResponseObjectWithInput(
+        created_at=1,
+        id="resp_123",
+        model="fake_model",
+        output=[output_tool_call, output_message],
+        status="completed",
+        text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
+        input=[input_item_message],
+    )
+    mock_responses_store.get_response_object.return_value = response
+
+    input_messages = [OpenAIResponseMessage(content="fake_input", role="user")]
+    input = await openai_responses_impl._prepend_previous_response(input_messages, "resp_123")
+
+    assert len(input) == 4
+    # Check for previous input
+    assert isinstance(input[0], OpenAIResponseMessage)
+    assert input[0].content[0].text == "fake_previous_input"
+    # Check for previous output MCP tool call
+    assert isinstance(input[1], OpenAIResponseOutputMessageMCPCall)
+    # Check for previous output MCP tool call response
+    assert isinstance(input[2], OpenAIResponseMessage)
+    assert input[2].content[0].text == "fake_tool_call_response"
+    # Check for new input
+    assert isinstance(input[3], OpenAIResponseMessage)
+    assert input[3].content == "fake_input"
+
 
 async def test_create_openai_response_with_instructions(openai_responses_impl, mock_inference_api):
     # Setup
     input_text = "What is the capital of Ireland?"