From d0ec3b07b56db27ac1def018c22e14f7480b1fe6 Mon Sep 17 00:00:00 2001 From: Omar Abdelwahab Date: Wed, 12 Nov 2025 14:47:22 -0800 Subject: [PATCH] fix: add authorization parameter to all ToolRuntime provider implementations Updated all ToolRuntime provider implementations to match the protocol signature: - BraveSearchToolRuntimeImpl - TavilySearchToolRuntimeImpl - BingSearchToolRuntimeImpl - WolframAlphaToolRuntimeImpl - MemoryToolRuntimeImpl This fixes the signature mismatch error in CI where protocol had 'authorization' parameter but implementations didn't. --- .../providers/inline/tool_runtime/rag/memory.py | 9 +++++++-- .../remote/tool_runtime/bing_search/bing_search.py | 9 +++++++-- .../remote/tool_runtime/brave_search/brave_search.py | 9 +++++++-- .../remote/tool_runtime/tavily_search/tavily_search.py | 9 +++++++-- .../remote/tool_runtime/wolfram_alpha/wolfram_alpha.py | 9 +++++++-- 5 files changed, 35 insertions(+), 10 deletions(-) diff --git a/src/llama_stack/providers/inline/tool_runtime/rag/memory.py b/src/llama_stack/providers/inline/tool_runtime/rag/memory.py index 6a59be0ca..ab3833936 100644 --- a/src/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/src/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -279,7 +279,10 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime): ) async def list_runtime_tools( - self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None + self, + tool_group_id: str | None = None, + mcp_endpoint: URL | None = None, + authorization: str | None = None, ) -> ListToolDefsResponse: # Parameters are not listed since these methods are not yet invoked automatically # by the LLM. The method is only implemented so things like /tools can list without @@ -307,7 +310,9 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime): ] ) - async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult: + async def invoke_tool( + self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None + ) -> ToolInvocationResult: vector_store_ids = kwargs.get("vector_store_ids", []) query_config = kwargs.get("query_config") if query_config: diff --git a/src/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py b/src/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py index 9a98964b7..e8ab6dc90 100644 --- a/src/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py +++ b/src/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py @@ -49,7 +49,10 @@ class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsReq return provider_data.bing_search_api_key async def list_runtime_tools( - self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None + self, + tool_group_id: str | None = None, + mcp_endpoint: URL | None = None, + authorization: str | None = None, ) -> ListToolDefsResponse: return ListToolDefsResponse( data=[ @@ -70,7 +73,9 @@ class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsReq ] ) - async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult: + async def invoke_tool( + self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None + ) -> ToolInvocationResult: api_key = self._get_api_key() headers = { "Ocp-Apim-Subscription-Key": api_key, diff --git a/src/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py b/src/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py index 02e5b5c69..081082add 100644 --- a/src/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py +++ b/src/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py @@ -48,7 +48,10 @@ class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRe return provider_data.brave_search_api_key async def list_runtime_tools( - self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None + self, + tool_group_id: str | None = None, + mcp_endpoint: URL | None = None, + authorization: str | None = None, ) -> ListToolDefsResponse: return ListToolDefsResponse( data=[ @@ -70,7 +73,9 @@ class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRe ] ) - async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult: + async def invoke_tool( + self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None + ) -> ToolInvocationResult: api_key = self._get_api_key() url = "https://api.search.brave.com/res/v1/web/search" headers = { diff --git a/src/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py b/src/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py index ca629fced..1b49f8a03 100644 --- a/src/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py +++ b/src/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py @@ -48,7 +48,10 @@ class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR return provider_data.tavily_search_api_key async def list_runtime_tools( - self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None + self, + tool_group_id: str | None = None, + mcp_endpoint: URL | None = None, + authorization: str | None = None, ) -> ListToolDefsResponse: return ListToolDefsResponse( data=[ @@ -69,7 +72,9 @@ class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR ] ) - async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult: + async def invoke_tool( + self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None + ) -> ToolInvocationResult: api_key = self._get_api_key() async with httpx.AsyncClient() as client: response = await client.post( diff --git a/src/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py b/src/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py index 410e34195..9bacfaa1c 100644 --- a/src/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py +++ b/src/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py @@ -49,7 +49,10 @@ class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR return provider_data.wolfram_alpha_api_key async def list_runtime_tools( - self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None + self, + tool_group_id: str | None = None, + mcp_endpoint: URL | None = None, + authorization: str | None = None, ) -> ListToolDefsResponse: return ListToolDefsResponse( data=[ @@ -70,7 +73,9 @@ class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR ] ) - async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult: + async def invoke_tool( + self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None + ) -> ToolInvocationResult: api_key = self._get_api_key() params = { "input": kwargs["query"],