diff --git a/src/llama_stack/apis/tools/tools.py b/src/llama_stack/apis/tools/tools.py index de39a4294..06580dc74 100644 --- a/src/llama_stack/apis/tools/tools.py +++ b/src/llama_stack/apis/tools/tools.py @@ -199,11 +199,13 @@ class ToolRuntime(Protocol): self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None, + authorization: str | None = None, ) -> ListToolDefsResponse: """List all tools in the runtime. :param tool_group_id: The ID of the tool group to list tools for. :param mcp_endpoint: The MCP endpoint to use for the tool group. + :param authorization: (Optional) OAuth access token for authenticating with the MCP server. :returns: A ListToolDefsResponse. """ ... diff --git a/src/llama_stack/core/routers/tool_runtime.py b/src/llama_stack/core/routers/tool_runtime.py index cd690985e..3cfe584c5 100644 --- a/src/llama_stack/core/routers/tool_runtime.py +++ b/src/llama_stack/core/routers/tool_runtime.py @@ -46,6 +46,6 @@ class ToolRuntimeRouter(ToolRuntime): ) async def list_runtime_tools( - self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None + self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None, authorization: str | None = None ) -> ListToolDefsResponse: - return await self.routing_table.list_tools(tool_group_id) + return await self.routing_table.list_tools(tool_group_id, authorization=authorization) diff --git a/src/llama_stack/core/routing_tables/toolgroups.py b/src/llama_stack/core/routing_tables/toolgroups.py index 573c3444d..0761c5582 100644 --- a/src/llama_stack/core/routing_tables/toolgroups.py +++ b/src/llama_stack/core/routing_tables/toolgroups.py @@ -43,7 +43,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): routing_key = self.tool_to_toolgroup[routing_key] return await super().get_provider_impl(routing_key, provider_id) - async def list_tools(self, toolgroup_id: str | None = None) -> ListToolDefsResponse: + async def list_tools(self, toolgroup_id: str | None = None, authorization: str | None = None) -> ListToolDefsResponse: if toolgroup_id: if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id): toolgroup_id = group_id @@ -55,7 +55,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): for toolgroup in toolgroups: if toolgroup.identifier not in self.toolgroups_to_tools: try: - await self._index_tools(toolgroup) + await self._index_tools(toolgroup, authorization=authorization) except AuthenticationRequiredError: # Send authentication errors back to the client so it knows # that it needs to supply credentials for remote MCP servers. @@ -70,10 +70,10 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): return ListToolDefsResponse(data=all_tools) - async def _index_tools(self, toolgroup: ToolGroup): + async def _index_tools(self, toolgroup: ToolGroup, authorization: str | None = None): provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id) tooldefs_response = await provider_impl.list_runtime_tools( - toolgroup.identifier, toolgroup.mcp_endpoint + toolgroup.identifier, toolgroup.mcp_endpoint, authorization=authorization ) tooldefs = tooldefs_response.data diff --git a/src/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py b/src/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py index 31247aa76..e8ab6dc90 100644 --- a/src/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py +++ b/src/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py @@ -52,6 +52,7 @@ class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsReq self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None, + authorization: str | None = None, ) -> ListToolDefsResponse: return ListToolDefsResponse( data=[ diff --git a/src/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py b/src/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py index cf9b5f6b6..081082add 100644 --- a/src/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py +++ b/src/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py @@ -51,6 +51,7 @@ class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRe self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None, + authorization: str | None = None, ) -> ListToolDefsResponse: return ListToolDefsResponse( data=[ diff --git a/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index 4035df5c1..4ad2d4b3a 100644 --- a/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -44,14 +44,16 @@ class ModelContextProtocolToolRuntimeImpl( self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None, + authorization: str | None = None, ) -> ListToolDefsResponse: # this endpoint should be retrieved by getting the tool group right? if mcp_endpoint is None: raise ValueError("mcp_endpoint is required") - # MCP tool listing typically doesn't require authorization + + # Use authorization parameter for MCP servers that require auth headers = {} return await list_mcp_tools( - endpoint=mcp_endpoint.uri, headers=headers, authorization=None + endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization ) async def invoke_tool( diff --git a/src/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py b/src/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py index 26429053f..1b49f8a03 100644 --- a/src/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py +++ b/src/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py @@ -51,6 +51,7 @@ class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None, + authorization: str | None = None, ) -> ListToolDefsResponse: return ListToolDefsResponse( data=[ diff --git a/src/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py b/src/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py index 6fd6b60b1..9bacfaa1c 100644 --- a/src/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py +++ b/src/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py @@ -52,6 +52,7 @@ class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None, + authorization: str | None = None, ) -> ListToolDefsResponse: return ListToolDefsResponse( data=[