mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 01:48:05 +00:00
feat(tool-runtime): Add authorization parameter to list_runtime_tools
Add authorization parameter to list_runtime_tools() method to support MCP servers that require authentication for listing tools. Changes: - Updated ToolRuntime protocol to include authorization parameter on list_runtime_tools() - Updated all provider implementations (MCP, Tavily, Brave, Bing, Wolfram Alpha) - Updated router and routing table to pass authorization through - Updated API recorder patched methods to include authorization parameter This enables authenticated tool listing for enterprise MCP deployments where IT administrators pre-configure connectors requiring authentication. Note: Client SDK will need to be regenerated from updated OpenAPI spec to support passing this parameter from client code. Tests will pass once client SDK is updated.
This commit is contained in:
parent
e6ebbd8a7b
commit
66ca51ac0d
8 changed files with 16 additions and 8 deletions
|
|
@ -199,11 +199,13 @@ class ToolRuntime(Protocol):
|
|||
self,
|
||||
tool_group_id: str | None = None,
|
||||
mcp_endpoint: URL | None = None,
|
||||
authorization: str | None = None,
|
||||
) -> ListToolDefsResponse:
|
||||
"""List all tools in the runtime.
|
||||
|
||||
:param tool_group_id: The ID of the tool group to list tools for.
|
||||
:param mcp_endpoint: The MCP endpoint to use for the tool group.
|
||||
:param authorization: (Optional) OAuth access token for authenticating with the MCP server.
|
||||
:returns: A ListToolDefsResponse.
|
||||
"""
|
||||
...
|
||||
|
|
|
|||
|
|
@ -46,6 +46,6 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
)
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None, authorization: str | None = None
|
||||
) -> ListToolDefsResponse:
|
||||
return await self.routing_table.list_tools(tool_group_id)
|
||||
return await self.routing_table.list_tools(tool_group_id, authorization=authorization)
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
routing_key = self.tool_to_toolgroup[routing_key]
|
||||
return await super().get_provider_impl(routing_key, provider_id)
|
||||
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolDefsResponse:
|
||||
async def list_tools(self, toolgroup_id: str | None = None, authorization: str | None = None) -> ListToolDefsResponse:
|
||||
if toolgroup_id:
|
||||
if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
|
||||
toolgroup_id = group_id
|
||||
|
|
@ -55,7 +55,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
for toolgroup in toolgroups:
|
||||
if toolgroup.identifier not in self.toolgroups_to_tools:
|
||||
try:
|
||||
await self._index_tools(toolgroup)
|
||||
await self._index_tools(toolgroup, authorization=authorization)
|
||||
except AuthenticationRequiredError:
|
||||
# Send authentication errors back to the client so it knows
|
||||
# that it needs to supply credentials for remote MCP servers.
|
||||
|
|
@ -70,10 +70,10 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
|
||||
return ListToolDefsResponse(data=all_tools)
|
||||
|
||||
async def _index_tools(self, toolgroup: ToolGroup):
|
||||
async def _index_tools(self, toolgroup: ToolGroup, authorization: str | None = None):
|
||||
provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
|
||||
tooldefs_response = await provider_impl.list_runtime_tools(
|
||||
toolgroup.identifier, toolgroup.mcp_endpoint
|
||||
toolgroup.identifier, toolgroup.mcp_endpoint, authorization=authorization
|
||||
)
|
||||
|
||||
tooldefs = tooldefs_response.data
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsReq
|
|||
self,
|
||||
tool_group_id: str | None = None,
|
||||
mcp_endpoint: URL | None = None,
|
||||
authorization: str | None = None,
|
||||
) -> ListToolDefsResponse:
|
||||
return ListToolDefsResponse(
|
||||
data=[
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@ class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRe
|
|||
self,
|
||||
tool_group_id: str | None = None,
|
||||
mcp_endpoint: URL | None = None,
|
||||
authorization: str | None = None,
|
||||
) -> ListToolDefsResponse:
|
||||
return ListToolDefsResponse(
|
||||
data=[
|
||||
|
|
|
|||
|
|
@ -44,14 +44,16 @@ class ModelContextProtocolToolRuntimeImpl(
|
|||
self,
|
||||
tool_group_id: str | None = None,
|
||||
mcp_endpoint: URL | None = None,
|
||||
authorization: str | None = None,
|
||||
) -> ListToolDefsResponse:
|
||||
# this endpoint should be retrieved by getting the tool group right?
|
||||
if mcp_endpoint is None:
|
||||
raise ValueError("mcp_endpoint is required")
|
||||
# MCP tool listing typically doesn't require authorization
|
||||
|
||||
# Use authorization parameter for MCP servers that require auth
|
||||
headers = {}
|
||||
return await list_mcp_tools(
|
||||
endpoint=mcp_endpoint.uri, headers=headers, authorization=None
|
||||
endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization
|
||||
)
|
||||
|
||||
async def invoke_tool(
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@ class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
|
|||
self,
|
||||
tool_group_id: str | None = None,
|
||||
mcp_endpoint: URL | None = None,
|
||||
authorization: str | None = None,
|
||||
) -> ListToolDefsResponse:
|
||||
return ListToolDefsResponse(
|
||||
data=[
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsR
|
|||
self,
|
||||
tool_group_id: str | None = None,
|
||||
mcp_endpoint: URL | None = None,
|
||||
authorization: str | None = None,
|
||||
) -> ListToolDefsResponse:
|
||||
return ListToolDefsResponse(
|
||||
data=[
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue