mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 01:48:05 +00:00
fix: add authorization parameter to all ToolRuntime provider implementations
Updated all ToolRuntime provider implementations to match the protocol signature: - BraveSearchToolRuntimeImpl - TavilySearchToolRuntimeImpl - BingSearchToolRuntimeImpl - WolframAlphaToolRuntimeImpl - MemoryToolRuntimeImpl This fixes the signature mismatch error in CI where protocol had 'authorization' parameter but implementations didn't.
This commit is contained in:
parent
84baa5c406
commit
d0ec3b07b5
5 changed files with 35 additions and 10 deletions
|
|
@ -279,7 +279,10 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime):
|
|||
)
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
self,
|
||||
tool_group_id: str | None = None,
|
||||
mcp_endpoint: URL | None = None,
|
||||
authorization: str | None = None,
|
||||
) -> ListToolDefsResponse:
|
||||
# Parameters are not listed since these methods are not yet invoked automatically
|
||||
# by the LLM. The method is only implemented so things like /tools can list without
|
||||
|
|
@ -307,7 +310,9 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime):
|
|||
]
|
||||
)
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
|
||||
) -> ToolInvocationResult:
|
||||
vector_store_ids = kwargs.get("vector_store_ids", [])
|
||||
query_config = kwargs.get("query_config")
|
||||
if query_config:
|
||||
|
|
|
|||
|
|
@ -49,7 +49,10 @@ class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsReq
|
|||
return provider_data.bing_search_api_key
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
self,
|
||||
tool_group_id: str | None = None,
|
||||
mcp_endpoint: URL | None = None,
|
||||
authorization: str | None = None,
|
||||
) -> ListToolDefsResponse:
|
||||
return ListToolDefsResponse(
|
||||
data=[
|
||||
|
|
@ -70,7 +73,9 @@ class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsReq
|
|||
]
|
||||
)
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
|
||||
) -> ToolInvocationResult:
|
||||
api_key = self._get_api_key()
|
||||
headers = {
|
||||
"Ocp-Apim-Subscription-Key": api_key,
|
||||
|
|
|
|||
|
|
@ -48,7 +48,10 @@ class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRe
|
|||
return provider_data.brave_search_api_key
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
self,
|
||||
tool_group_id: str | None = None,
|
||||
mcp_endpoint: URL | None = None,
|
||||
authorization: str | None = None,
|
||||
) -> ListToolDefsResponse:
|
||||
return ListToolDefsResponse(
|
||||
data=[
|
||||
|
|
@ -70,7 +73,9 @@ class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRe
|
|||
]
|
||||
)
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
|
||||
) -> ToolInvocationResult:
|
||||
api_key = self._get_api_key()
|
||||
url = "https://api.search.brave.com/res/v1/web/search"
|
||||
headers = {
|
||||
|
|
|
|||
|
|
@ -48,7 +48,10 @@ class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
|
|||
return provider_data.tavily_search_api_key
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
self,
|
||||
tool_group_id: str | None = None,
|
||||
mcp_endpoint: URL | None = None,
|
||||
authorization: str | None = None,
|
||||
) -> ListToolDefsResponse:
|
||||
return ListToolDefsResponse(
|
||||
data=[
|
||||
|
|
@ -69,7 +72,9 @@ class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
|
|||
]
|
||||
)
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
|
||||
) -> ToolInvocationResult:
|
||||
api_key = self._get_api_key()
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
|
|
|
|||
|
|
@ -49,7 +49,10 @@ class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
|
|||
return provider_data.wolfram_alpha_api_key
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
self,
|
||||
tool_group_id: str | None = None,
|
||||
mcp_endpoint: URL | None = None,
|
||||
authorization: str | None = None,
|
||||
) -> ListToolDefsResponse:
|
||||
return ListToolDefsResponse(
|
||||
data=[
|
||||
|
|
@ -70,7 +73,9 @@ class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
|
|||
]
|
||||
)
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||
async def invoke_tool(
|
||||
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
|
||||
) -> ToolInvocationResult:
|
||||
api_key = self._get_api_key()
|
||||
params = {
|
||||
"input": kwargs["query"],
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue