fix(tool-runtime): Remove authorization from list_runtime_tools()

The authorization parameter should only be on invoke_tool(), not on
list_runtime_tools(). Tool listing typically doesn't require authentication,
and the client SDK doesn't have this parameter yet.

Changes:
1. Removed authorization parameter from ToolRuntime.list_runtime_tools() protocol method
2. Updated all implementations to remove the authorization parameter:
   - MCPProviderImpl.list_runtime_tools()
   - ToolRuntimeRouter.list_runtime_tools()
   - ToolGroupsRoutingTable.list_tools() and _index_tools()
3. Updated test to remove authorization from list_tools() call

This ensures compatibility with the llama-stack-client SDK which doesn't
support authorization on list_tools() yet. Only invoke_tool() requires
and accepts the authorization parameter for authenticated tool execution.
This commit is contained in:
Omar Abdelwahab 2025-11-12 16:17:53 -08:00
parent c0295a2495
commit 18f197763b
5 changed files with 8 additions and 13 deletions

View file

@ -199,13 +199,11 @@ class ToolRuntime(Protocol):
self, self,
tool_group_id: str | None = None, tool_group_id: str | None = None,
mcp_endpoint: URL | None = None, mcp_endpoint: URL | None = None,
authorization: str | None = None,
) -> ListToolDefsResponse: ) -> ListToolDefsResponse:
"""List all tools in the runtime. """List all tools in the runtime.
:param tool_group_id: The ID of the tool group to list tools for. :param tool_group_id: The ID of the tool group to list tools for.
:param mcp_endpoint: The MCP endpoint to use for the tool group. :param mcp_endpoint: The MCP endpoint to use for the tool group.
:param authorization: (Optional) OAuth access token for authenticating with the MCP server.
:returns: A ListToolDefsResponse. :returns: A ListToolDefsResponse.
""" """
... ...

View file

@ -46,7 +46,6 @@ class ToolRuntimeRouter(ToolRuntime):
) )
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None, authorization: str | None = None self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse: ) -> ListToolDefsResponse:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}") return await self.routing_table.list_tools(tool_group_id)
return await self.routing_table.list_tools(tool_group_id, authorization=authorization)

View file

@ -43,7 +43,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
routing_key = self.tool_to_toolgroup[routing_key] routing_key = self.tool_to_toolgroup[routing_key]
return await super().get_provider_impl(routing_key, provider_id) return await super().get_provider_impl(routing_key, provider_id)
async def list_tools(self, toolgroup_id: str | None = None, authorization: str | None = None) -> ListToolDefsResponse: async def list_tools(self, toolgroup_id: str | None = None) -> ListToolDefsResponse:
if toolgroup_id: if toolgroup_id:
if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id): if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
toolgroup_id = group_id toolgroup_id = group_id
@ -55,7 +55,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
for toolgroup in toolgroups: for toolgroup in toolgroups:
if toolgroup.identifier not in self.toolgroups_to_tools: if toolgroup.identifier not in self.toolgroups_to_tools:
try: try:
await self._index_tools(toolgroup, authorization=authorization) await self._index_tools(toolgroup)
except AuthenticationRequiredError: except AuthenticationRequiredError:
# Send authentication errors back to the client so it knows # Send authentication errors back to the client so it knows
# that it needs to supply credentials for remote MCP servers. # that it needs to supply credentials for remote MCP servers.
@ -70,10 +70,10 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
return ListToolDefsResponse(data=all_tools) return ListToolDefsResponse(data=all_tools)
async def _index_tools(self, toolgroup: ToolGroup, authorization: str | None = None): async def _index_tools(self, toolgroup: ToolGroup):
provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id) provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
tooldefs_response = await provider_impl.list_runtime_tools( tooldefs_response = await provider_impl.list_runtime_tools(
toolgroup.identifier, toolgroup.mcp_endpoint, authorization=authorization toolgroup.identifier, toolgroup.mcp_endpoint
) )
tooldefs = tooldefs_response.data tooldefs = tooldefs_response.data

View file

@ -44,15 +44,14 @@ class ModelContextProtocolToolRuntimeImpl(
self, self,
tool_group_id: str | None = None, tool_group_id: str | None = None,
mcp_endpoint: URL | None = None, mcp_endpoint: URL | None = None,
authorization: str | None = None,
) -> ListToolDefsResponse: ) -> ListToolDefsResponse:
# this endpoint should be retrieved by getting the tool group right? # this endpoint should be retrieved by getting the tool group right?
if mcp_endpoint is None: if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required") raise ValueError("mcp_endpoint is required")
# Authorization now comes from request body parameter (not provider-data) # MCP tool listing typically doesn't require authorization
headers = {} headers = {}
return await list_mcp_tools( return await list_mcp_tools(
endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization endpoint=mcp_endpoint.uri, headers=headers, authorization=None
) )
async def invoke_tool( async def invoke_tool(

View file

@ -196,7 +196,6 @@ class TestMCPToolsInChatCompletion:
# Get the tools from MCP # Get the tools from MCP
tools_response = llama_stack_client.tool_runtime.list_tools( tools_response = llama_stack_client.tool_runtime.list_tools(
tool_group_id=test_toolgroup_id, tool_group_id=test_toolgroup_id,
authorization=AUTH_TOKEN,
) )
# Convert to OpenAI format for inference # Convert to OpenAI format for inference