updated the comments

This commit is contained in:
Omar Abdelwahab 2025-11-14 16:56:04 -08:00
parent daa5a79b24
commit 0d4fa16ab9
4 changed files with 10 additions and 19 deletions

View file

@ -48,7 +48,6 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
if mcp_endpoint is None: if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required") raise ValueError("mcp_endpoint is required")
# Phase 2: Only use the dedicated authorization parameter
# Get other headers from provider data (but NOT authorization) # Get other headers from provider data (but NOT authorization)
provider_headers = await self.get_headers_from_request(mcp_endpoint.uri) provider_headers = await self.get_headers_from_request(mcp_endpoint.uri)
@ -64,7 +63,6 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
if urlparse(endpoint).scheme not in ("http", "https"): if urlparse(endpoint).scheme not in ("http", "https"):
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL") raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
# Phase 2: Only use the dedicated authorization parameter
# Get other headers from provider data (but NOT authorization) # Get other headers from provider data (but NOT authorization)
provider_headers = await self.get_headers_from_request(endpoint) provider_headers = await self.get_headers_from_request(endpoint)
@ -80,7 +78,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
""" """
Extract headers from request provider data, excluding authorization. Extract headers from request provider data, excluding authorization.
Phase 2: Authorization must be provided via the dedicated authorization parameter. Authorization must be provided via the dedicated authorization parameter.
If Authorization is found in mcp_headers, raise an error to guide users to the correct approach. If Authorization is found in mcp_headers, raise an error to guide users to the correct approach.
Args: Args:
@ -104,7 +102,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri): if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
continue continue
# Phase 2: Reject Authorization in mcp_headers - must use authorization parameter # Reject Authorization in mcp_headers - must use authorization parameter
for key in values.keys(): for key in values.keys():
if key.lower() == "authorization": if key.lower() == "authorization":
raise ValueError( raise ValueError(

View file

@ -191,7 +191,7 @@ class TestMCPToolsInChatCompletion:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
# Phase 2: Use the dedicated authorization parameter # Use the dedicated authorization parameter
# Get the tools from MCP # Get the tools from MCP
tools_response = llama_stack_client.tool_runtime.list_tools( tools_response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,

View file

@ -35,11 +35,8 @@ def test_mcp_invocation(llama_stack_client, text_model_id, mcp_server):
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
# Phase 2: Use the dedicated authorization parameter (no more provider_data headers) # Use the dedicated authorization parameter (no more provider_data headers)
# This tests direct tool_runtime.invoke_tool API calls # This tests direct tool_runtime.invoke_tool API calls
# Note: tools.list() is the ToolGroups API and doesn't have authorization parameter
# Use tool_runtime.list_tools() for authorization support
tools_list = llama_stack_client.tool_runtime.list_tools( tools_list = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
authorization=AUTH_TOKEN, # Use dedicated authorization parameter authorization=AUTH_TOKEN, # Use dedicated authorization parameter

View file

@ -120,7 +120,7 @@ class TestMCPSchemaPreservation:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
# Phase 2: Use the dedicated authorization parameter # Use the dedicated authorization parameter
# List runtime tools # List runtime tools
response = llama_stack_client.tool_runtime.list_tools( response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
@ -160,7 +160,7 @@ class TestMCPSchemaPreservation:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
# Phase 2: Use the dedicated authorization parameter # Use the dedicated authorization parameter
# List tools # List tools
response = llama_stack_client.tool_runtime.list_tools( response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
@ -206,7 +206,7 @@ class TestMCPSchemaPreservation:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
# Phase 2: Use the dedicated authorization parameter # Use the dedicated authorization parameter
response = llama_stack_client.tool_runtime.list_tools( response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
authorization=AUTH_TOKEN, authorization=AUTH_TOKEN,
@ -249,8 +249,7 @@ class TestMCPToolInvocation:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
# Phase 2: Use the dedicated authorization parameter # Use the dedicated authorization parameter
# List tools to populate the tool index
llama_stack_client.tool_runtime.list_tools( llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
authorization=AUTH_TOKEN, authorization=AUTH_TOKEN,
@ -291,8 +290,7 @@ class TestMCPToolInvocation:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
# Phase 2: Use the dedicated authorization parameter # Use the dedicated authorization parameter
# List tools to populate the tool index
llama_stack_client.tool_runtime.list_tools( llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
authorization=AUTH_TOKEN, authorization=AUTH_TOKEN,
@ -343,9 +341,7 @@ class TestAgentWithMCPTools:
mcp_endpoint=dict(uri=uri), mcp_endpoint=dict(uri=uri),
) )
# Phase 2: Use the dedicated authorization parameter # Use the dedicated authorization parameter
# Note: tools.list() is the ToolGroups API and doesn't have authorization parameter
# Use tool_runtime.list_tools() instead
tools_list = llama_stack_client.tool_runtime.list_tools( tools_list = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
authorization=AUTH_TOKEN, authorization=AUTH_TOKEN,