diff --git a/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py
index 61402707c..92a7d788e 100644
--- a/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py
+++ b/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py
@@ -45,9 +45,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
         if mcp_endpoint is None:
             raise ValueError("mcp_endpoint is required")
         headers, authorization = await self.get_headers_from_request(mcp_endpoint.uri)
-        return await list_mcp_tools(
-            endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization
-        )
+        return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization)
 
     async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
         tool = await self.tool_store.get_tool(tool_name)
@@ -66,9 +64,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
             authorization=authorization,
         )
 
-    async def get_headers_from_request(
-        self, mcp_endpoint_uri: str
-    ) -> tuple[dict[str, str], str | None]:
+    async def get_headers_from_request(self, mcp_endpoint_uri: str) -> tuple[dict[str, str], str | None]:
         """
         Extract headers and authorization from request provider data.
 
diff --git a/src/llama_stack/providers/utils/tools/mcp.py b/src/llama_stack/providers/utils/tools/mcp.py
index 2bc6cbf96..573054e25 100644
--- a/src/llama_stack/providers/utils/tools/mcp.py
+++ b/src/llama_stack/providers/utils/tools/mcp.py
@@ -27,6 +27,7 @@ from llama_stack.providers.utils.tools.ttl_dict import TTLDict
 
 logger = get_logger(__name__, category="tools")
 
+
 def prepare_mcp_headers(base_headers: dict[str, str] | None, authorization: str | None) -> dict[str, str]:
     """
     Prepare headers for MCP requests with authorization support.
diff --git a/tests/integration/responses/test_tool_responses.py b/tests/integration/responses/test_tool_responses.py
index 4501961f3..1228f8a85 100644
--- a/tests/integration/responses/test_tool_responses.py
+++ b/tests/integration/responses/test_tool_responses.py
@@ -63,9 +63,7 @@ def test_response_non_streaming_file_search(
     if isinstance(compat_client, LlamaStackAsLibraryClient):
         pytest.skip("Responses API file search is not yet supported in library client.")
 
-    vector_store = new_vector_store(
-        compat_client, "test_vector_store", embedding_model_id, embedding_dimension
-    )
+    vector_store = new_vector_store(compat_client, "test_vector_store", embedding_model_id, embedding_dimension)
 
     if case.file_content:
         file_name = "test_response_non_streaming_file_search.txt"
@@ -122,9 +120,7 @@ def test_response_non_streaming_file_search_empty_vector_store(
     if isinstance(compat_client, LlamaStackAsLibraryClient):
         pytest.skip("Responses API file search is not yet supported in library client.")
 
-    vector_store = new_vector_store(
-        compat_client, "test_vector_store", embedding_model_id, embedding_dimension
-    )
+    vector_store = new_vector_store(compat_client, "test_vector_store", embedding_model_id, embedding_dimension)
 
     # Create the response request, which should query our vector store
     response = compat_client.responses.create(
@@ -153,9 +149,7 @@ def test_response_sequential_file_search(
     if isinstance(compat_client, LlamaStackAsLibraryClient):
         pytest.skip("Responses API file search is not yet supported in library client.")
 
-    vector_store = new_vector_store(
-        compat_client, "test_vector_store", embedding_model_id, embedding_dimension
-    )
+    vector_store = new_vector_store(compat_client, "test_vector_store", embedding_model_id, embedding_dimension)
 
     # Create a test file with content
     file_content = "The Llama 4 Maverick model has 128 experts in its mixture of experts architecture."
@@ -506,18 +500,10 @@ def test_response_function_call_ordering_2(compat_client, text_model_id):
         stream=False,
     )
     for output in response.output:
-        if (
-            output.type == "function_call"
-            and output.status == "completed"
-            and output.name == "get_weather"
-        ):
+        if output.type == "function_call" and output.status == "completed" and output.name == "get_weather":
             inputs.append(output)
     for output in response.output:
-        if (
-            output.type == "function_call"
-            and output.status == "completed"
-            and output.name == "get_weather"
-        ):
+        if output.type == "function_call" and output.status == "completed" and output.name == "get_weather":
             weather = "It is raining."
             if "Los Angeles" in output.arguments:
                 weather = "It is cloudy."
@@ -539,9 +525,7 @@ def test_response_function_call_ordering_2(compat_client, text_model_id):
 
 
 @pytest.mark.parametrize("case", multi_turn_tool_execution_test_cases)
-def test_response_non_streaming_multi_turn_tool_execution(
-    compat_client, text_model_id, case
-):
+def test_response_non_streaming_multi_turn_tool_execution(compat_client, text_model_id, case):
     """Test multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
     if not isinstance(compat_client, LlamaStackAsLibraryClient):
         pytest.skip("in-process MCP server is only supported in library client")
@@ -556,18 +540,12 @@ def test_response_non_streaming_multi_turn_tool_execution(
     )
 
     # Verify we have MCP tool calls in the output
-    mcp_list_tools = [
-        output for output in response.output if output.type == "mcp_list_tools"
-    ]
+    mcp_list_tools = [output for output in response.output if output.type == "mcp_list_tools"]
     mcp_calls = [output for output in response.output if output.type == "mcp_call"]
-    message_outputs = [
-        output for output in response.output if output.type == "message"
-    ]
+    message_outputs = [output for output in response.output if output.type == "message"]
 
     # Should have exactly 1 MCP list tools message (at the beginning)
-    assert (
-        len(mcp_list_tools) == 1
-    ), f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
+    assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
     assert mcp_list_tools[0].server_label == "localmcp"
     assert len(mcp_list_tools[0].tools) == 5  # Updated for dependency tools
     expected_tool_names = {
@@ -579,37 +557,25 @@
     }
     assert {t.name for t in mcp_list_tools[0].tools} == expected_tool_names
 
-    assert (
-        len(mcp_calls) >= 1
-    ), f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
+    assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
     for mcp_call in mcp_calls:
-        assert (
-            mcp_call.error is None
-        ), f"MCP call should not have errors, got: {mcp_call.error}"
+        assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"
 
-    assert (
-        len(message_outputs) >= 1
-    ), f"Expected at least 1 message output, got {len(message_outputs)}"
+    assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"
     final_message = message_outputs[-1]
-    assert (
-        final_message.role == "assistant"
-    ), f"Final message should be from assistant, got {final_message.role}"
-    assert (
-        final_message.status == "completed"
-    ), f"Final message should be completed, got {final_message.status}"
+    assert final_message.role == "assistant", f"Final message should be from assistant, got {final_message.role}"
+    assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
     assert len(final_message.content) > 0, "Final message should have content"
 
     expected_output = case.expected
-    assert (
-        expected_output.lower() in response.output_text.lower()
-    ), f"Expected '{expected_output}' to appear in response: {response.output_text}"
+    assert expected_output.lower() in response.output_text.lower(), (
+        f"Expected '{expected_output}' to appear in response: {response.output_text}"
+    )
 
 
 @pytest.mark.parametrize("case", multi_turn_tool_execution_streaming_test_cases)
-def test_response_streaming_multi_turn_tool_execution(
-    compat_client, text_model_id, case
-):
+def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_id, case):
     """Test streaming multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
     if not isinstance(compat_client, LlamaStackAsLibraryClient):
         pytest.skip("in-process MCP server is only supported in library client")
@@ -642,22 +608,12 @@ def test_response_streaming_multi_turn_tool_execution(
     final_response = final_chunk.response
 
     # Verify multi-turn MCP tool execution results
-    mcp_list_tools = [
-        output
-        for output in final_response.output
-        if output.type == "mcp_list_tools"
-    ]
-    mcp_calls = [
-        output for output in final_response.output if output.type == "mcp_call"
-    ]
-    message_outputs = [
-        output for output in final_response.output if output.type == "message"
-    ]
+    mcp_list_tools = [output for output in final_response.output if output.type == "mcp_list_tools"]
+    mcp_calls = [output for output in final_response.output if output.type == "mcp_call"]
+    message_outputs = [output for output in final_response.output if output.type == "message"]
 
     # Should have exactly 1 MCP list tools message (at the beginning)
-    assert (
-        len(mcp_list_tools) == 1
-    ), f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
+    assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
     assert mcp_list_tools[0].server_label == "localmcp"
     assert len(mcp_list_tools[0].tools) == 5  # Updated for dependency tools
     expected_tool_names = {
@@ -670,33 +626,25 @@
     assert {t.name for t in mcp_list_tools[0].tools} == expected_tool_names
 
     # Should have at least 1 MCP call (the model should call at least one tool)
-    assert (
-        len(mcp_calls) >= 1
-    ), f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
+    assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
 
     # All MCP calls should be completed (verifies our tool execution works)
     for mcp_call in mcp_calls:
-        assert (
-            mcp_call.error is None
-        ), f"MCP call should not have errors, got: {mcp_call.error}"
+        assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"
 
     # Should have at least one final message response
-    assert (
-        len(message_outputs) >= 1
-    ), f"Expected at least 1 message output, got {len(message_outputs)}"
+    assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"
 
     # Final message should be from assistant and completed
     final_message = message_outputs[-1]
-    assert (
-        final_message.role == "assistant"
-    ), f"Final message should be from assistant, got {final_message.role}"
-    assert (
-        final_message.status == "completed"
-    ), f"Final message should be completed, got {final_message.status}"
+    assert final_message.role == "assistant", (
+        f"Final message should be from assistant, got {final_message.role}"
+    )
+    assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
    assert len(final_message.content) > 0, "Final message should have content"
 
     # Check that the expected output appears in the response
     expected_output = case.expected
-    assert (
-        expected_output.lower() in final_response.output_text.lower()
-    ), f"Expected '{expected_output}' to appear in response: {final_response.output_text}"
+    assert expected_output.lower() in final_response.output_text.lower(), (
+        f"Expected '{expected_output}' to appear in response: {final_response.output_text}"
+    )
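
Note on prepare_mcp_headers: the mcp.py hunk above only adds a blank line before the function and shows its signature plus the first line of its docstring; the body is not part of this patch. The sketch below is a minimal illustration of the call pattern that is visible in the diff. The header-merging logic and the Bearer-token format are assumptions for illustration, not the actual llama_stack implementation.

# Sketch only: assumed behavior of prepare_mcp_headers, inferred from its signature and docstring.
def prepare_mcp_headers(base_headers: dict[str, str] | None, authorization: str | None) -> dict[str, str]:
    """Prepare headers for MCP requests with authorization support."""
    headers = dict(base_headers or {})
    if authorization:
        # Assumption: the token is attached as a Bearer Authorization header.
        headers["Authorization"] = f"Bearer {authorization}"
    return headers


# Call pattern as it appears in the patched runtime code:
#     headers, authorization = await self.get_headers_from_request(mcp_endpoint.uri)
#     return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization)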