precommit

Omar Abdelwahab 2025-11-06 12:02:45 -08:00
parent ac9442eb92
commit 5ce48d2c6a
3 changed files with 35 additions and 90 deletions

View file

@@ -45,9 +45,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
         if mcp_endpoint is None:
             raise ValueError("mcp_endpoint is required")
         headers, authorization = await self.get_headers_from_request(mcp_endpoint.uri)
-        return await list_mcp_tools(
-            endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization
-        )
+        return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization)

     async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
         tool = await self.tool_store.get_tool(tool_name)
@@ -66,9 +64,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
             authorization=authorization,
         )

-    async def get_headers_from_request(
-        self, mcp_endpoint_uri: str
-    ) -> tuple[dict[str, str], str | None]:
+    async def get_headers_from_request(self, mcp_endpoint_uri: str) -> tuple[dict[str, str], str | None]:
         """
         Extract headers and authorization from request provider data.


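The hunks above only reflow get_headers_from_request; its body never appears in this diff. Below is a minimal, standalone sketch of the extraction such a method might perform, inferred from the signature and docstring alone; the mcp_headers mapping shape, the extract_headers name, and the Authorization split are assumptions, not code from this commit.

def extract_headers(mcp_headers: dict[str, dict[str, str]], mcp_endpoint_uri: str) -> tuple[dict[str, str], str | None]:
    """Hypothetical stand-in for the core logic of get_headers_from_request."""
    headers: dict[str, str] = {}
    authorization: str | None = None
    # Assumption: request provider data maps endpoint URIs to per-endpoint headers.
    for key, value in mcp_headers.get(mcp_endpoint_uri, {}).items():
        # Split Authorization out so it can be passed separately, matching the
        # (headers, authorization) tuple in the signature above.
        if key.lower() == "authorization":
            authorization = value
        else:
            headers[key] = value
    return headers, authorization

# Usage: one endpoint carrying a bearer token plus a custom header.
headers, authorization = extract_headers(
    {"http://localhost:8000/sse": {"Authorization": "Bearer abc123", "X-Trace-Id": "t-1"}},
    "http://localhost:8000/sse",
)
assert headers == {"X-Trace-Id": "t-1"}
assert authorization == "Bearer abc123"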
View file

@@ -27,6 +27,7 @@ from llama_stack.providers.utils.tools.ttl_dict import TTLDict

 logger = get_logger(__name__, category="tools")

+
 def prepare_mcp_headers(base_headers: dict[str, str] | None, authorization: str | None) -> dict[str, str]:
     """
     Prepare headers for MCP requests with authorization support.

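This file's hunk shows only the signature and docstring of prepare_mcp_headers. A minimal sketch of a body consistent with that docstring follows; it is an assumption, not the implementation from this commit.

def prepare_mcp_headers(base_headers: dict[str, str] | None, authorization: str | None) -> dict[str, str]:
    """Prepare headers for MCP requests with authorization support."""
    # Copy so the caller's dict is never mutated.
    headers = dict(base_headers or {})
    if authorization:
        # Assumption: the authorization value is sent verbatim in the standard
        # Authorization header (e.g. an already-formed "Bearer <token>" string).
        headers["Authorization"] = authorization
    return headers

# Usage sketch:
assert prepare_mcp_headers({"X-Trace-Id": "t-1"}, "Bearer abc123") == {
    "X-Trace-Id": "t-1",
    "Authorization": "Bearer abc123",
}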
View file

@@ -63,9 +63,7 @@ def test_response_non_streaming_file_search(
     if isinstance(compat_client, LlamaStackAsLibraryClient):
         pytest.skip("Responses API file search is not yet supported in library client.")

-    vector_store = new_vector_store(
-        compat_client, "test_vector_store", embedding_model_id, embedding_dimension
-    )
+    vector_store = new_vector_store(compat_client, "test_vector_store", embedding_model_id, embedding_dimension)

     if case.file_content:
         file_name = "test_response_non_streaming_file_search.txt"
@@ -122,9 +120,7 @@ def test_response_non_streaming_file_search_empty_vector_store(
     if isinstance(compat_client, LlamaStackAsLibraryClient):
         pytest.skip("Responses API file search is not yet supported in library client.")

-    vector_store = new_vector_store(
-        compat_client, "test_vector_store", embedding_model_id, embedding_dimension
-    )
+    vector_store = new_vector_store(compat_client, "test_vector_store", embedding_model_id, embedding_dimension)

     # Create the response request, which should query our vector store
     response = compat_client.responses.create(
@@ -153,9 +149,7 @@ def test_response_sequential_file_search(
     if isinstance(compat_client, LlamaStackAsLibraryClient):
         pytest.skip("Responses API file search is not yet supported in library client.")

-    vector_store = new_vector_store(
-        compat_client, "test_vector_store", embedding_model_id, embedding_dimension
-    )
+    vector_store = new_vector_store(compat_client, "test_vector_store", embedding_model_id, embedding_dimension)

     # Create a test file with content
     file_content = "The Llama 4 Maverick model has 128 experts in its mixture of experts architecture."
@@ -506,18 +500,10 @@ def test_response_function_call_ordering_2(compat_client, text_model_id):
         stream=False,
     )
     for output in response.output:
-        if (
-            output.type == "function_call"
-            and output.status == "completed"
-            and output.name == "get_weather"
-        ):
+        if output.type == "function_call" and output.status == "completed" and output.name == "get_weather":
             inputs.append(output)
     for output in response.output:
-        if (
-            output.type == "function_call"
-            and output.status == "completed"
-            and output.name == "get_weather"
-        ):
+        if output.type == "function_call" and output.status == "completed" and output.name == "get_weather":
             weather = "It is raining."
             if "Los Angeles" in output.arguments:
                 weather = "It is cloudy."
@@ -539,9 +525,7 @@ def test_response_function_call_ordering_2(compat_client, text_model_id):


 @pytest.mark.parametrize("case", multi_turn_tool_execution_test_cases)
-def test_response_non_streaming_multi_turn_tool_execution(
-    compat_client, text_model_id, case
-):
+def test_response_non_streaming_multi_turn_tool_execution(compat_client, text_model_id, case):
     """Test multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
     if not isinstance(compat_client, LlamaStackAsLibraryClient):
         pytest.skip("in-process MCP server is only supported in library client")
@@ -556,18 +540,12 @@ def test_response_non_streaming_multi_turn_tool_execution(
     )

     # Verify we have MCP tool calls in the output
-    mcp_list_tools = [
-        output for output in response.output if output.type == "mcp_list_tools"
-    ]
+    mcp_list_tools = [output for output in response.output if output.type == "mcp_list_tools"]
     mcp_calls = [output for output in response.output if output.type == "mcp_call"]
-    message_outputs = [
-        output for output in response.output if output.type == "message"
-    ]
+    message_outputs = [output for output in response.output if output.type == "message"]

     # Should have exactly 1 MCP list tools message (at the beginning)
-    assert (
-        len(mcp_list_tools) == 1
-    ), f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
+    assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
     assert mcp_list_tools[0].server_label == "localmcp"
     assert len(mcp_list_tools[0].tools) == 5  # Updated for dependency tools
     expected_tool_names = {
@@ -579,37 +557,25 @@ def test_response_non_streaming_multi_turn_tool_execution(
     }
     assert {t.name for t in mcp_list_tools[0].tools} == expected_tool_names

-    assert (
-        len(mcp_calls) >= 1
-    ), f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
+    assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"

     for mcp_call in mcp_calls:
-        assert (
-            mcp_call.error is None
-        ), f"MCP call should not have errors, got: {mcp_call.error}"
+        assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"

-    assert (
-        len(message_outputs) >= 1
-    ), f"Expected at least 1 message output, got {len(message_outputs)}"
+    assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"
     final_message = message_outputs[-1]
-    assert (
-        final_message.role == "assistant"
-    ), f"Final message should be from assistant, got {final_message.role}"
-    assert (
-        final_message.status == "completed"
-    ), f"Final message should be completed, got {final_message.status}"
+    assert final_message.role == "assistant", f"Final message should be from assistant, got {final_message.role}"
+    assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
     assert len(final_message.content) > 0, "Final message should have content"

     expected_output = case.expected
-    assert (
-        expected_output.lower() in response.output_text.lower()
-    ), f"Expected '{expected_output}' to appear in response: {response.output_text}"
+    assert expected_output.lower() in response.output_text.lower(), (
+        f"Expected '{expected_output}' to appear in response: {response.output_text}"
+    )


 @pytest.mark.parametrize("case", multi_turn_tool_execution_streaming_test_cases)
-def test_response_streaming_multi_turn_tool_execution(
-    compat_client, text_model_id, case
-):
+def test_response_streaming_multi_turn_tool_execution(compat_client, text_model_id, case):
     """Test streaming multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
     if not isinstance(compat_client, LlamaStackAsLibraryClient):
         pytest.skip("in-process MCP server is only supported in library client")
@@ -642,22 +608,12 @@ def test_response_streaming_multi_turn_tool_execution(
     final_response = final_chunk.response

     # Verify multi-turn MCP tool execution results
-    mcp_list_tools = [
-        output
-        for output in final_response.output
-        if output.type == "mcp_list_tools"
-    ]
-    mcp_calls = [
-        output for output in final_response.output if output.type == "mcp_call"
-    ]
-    message_outputs = [
-        output for output in final_response.output if output.type == "message"
-    ]
+    mcp_list_tools = [output for output in final_response.output if output.type == "mcp_list_tools"]
+    mcp_calls = [output for output in final_response.output if output.type == "mcp_call"]
+    message_outputs = [output for output in final_response.output if output.type == "message"]

     # Should have exactly 1 MCP list tools message (at the beginning)
-    assert (
-        len(mcp_list_tools) == 1
-    ), f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
+    assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
     assert mcp_list_tools[0].server_label == "localmcp"
     assert len(mcp_list_tools[0].tools) == 5  # Updated for dependency tools
     expected_tool_names = {
@@ -670,33 +626,25 @@ def test_response_streaming_multi_turn_tool_execution(
     assert {t.name for t in mcp_list_tools[0].tools} == expected_tool_names

     # Should have at least 1 MCP call (the model should call at least one tool)
-    assert (
-        len(mcp_calls) >= 1
-    ), f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
+    assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"

     # All MCP calls should be completed (verifies our tool execution works)
     for mcp_call in mcp_calls:
-        assert (
-            mcp_call.error is None
-        ), f"MCP call should not have errors, got: {mcp_call.error}"
+        assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"

     # Should have at least one final message response
-    assert (
-        len(message_outputs) >= 1
-    ), f"Expected at least 1 message output, got {len(message_outputs)}"
+    assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"

     # Final message should be from assistant and completed
     final_message = message_outputs[-1]
-    assert (
-        final_message.role == "assistant"
-    ), f"Final message should be from assistant, got {final_message.role}"
-    assert (
-        final_message.status == "completed"
-    ), f"Final message should be completed, got {final_message.status}"
+    assert final_message.role == "assistant", (
+        f"Final message should be from assistant, got {final_message.role}"
+    )
+    assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
     assert len(final_message.content) > 0, "Final message should have content"

     # Check that the expected output appears in the response
     expected_output = case.expected
-    assert (
-        expected_output.lower() in final_response.output_text.lower()
-    ), f"Expected '{expected_output}' to appear in response: {final_response.output_text}"
+    assert expected_output.lower() in final_response.output_text.lower(), (
+        f"Expected '{expected_output}' to appear in response: {final_response.output_text}"
+    )