mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
fix(tool-runtime): Remove authorization from list_runtime_tools()
The authorization parameter should only be on invoke_tool(), not on list_runtime_tools(). Tool listing typically doesn't require authentication, and the client SDK doesn't have this parameter yet. Changes: 1. Removed authorization parameter from ToolRuntime.list_runtime_tools() protocol method 2. Updated all implementations to remove the authorization parameter: - MCPProviderImpl.list_runtime_tools() - ToolRuntimeRouter.list_runtime_tools() - ToolGroupsRoutingTable.list_tools() and _index_tools() 3. Updated test to remove authorization from list_tools() call This ensures compatibility with the llama-stack-client SDK which doesn't support authorization on list_tools() yet. Only invoke_tool() requires and accepts the authorization parameter for authenticated tool execution.
This commit is contained in:
parent
c0295a2495
commit
18f197763b
5 changed files with 8 additions and 13 deletions
|
|
@ -199,13 +199,11 @@ class ToolRuntime(Protocol):
|
||||||
self,
|
self,
|
||||||
tool_group_id: str | None = None,
|
tool_group_id: str | None = None,
|
||||||
mcp_endpoint: URL | None = None,
|
mcp_endpoint: URL | None = None,
|
||||||
authorization: str | None = None,
|
|
||||||
) -> ListToolDefsResponse:
|
) -> ListToolDefsResponse:
|
||||||
"""List all tools in the runtime.
|
"""List all tools in the runtime.
|
||||||
|
|
||||||
:param tool_group_id: The ID of the tool group to list tools for.
|
:param tool_group_id: The ID of the tool group to list tools for.
|
||||||
:param mcp_endpoint: The MCP endpoint to use for the tool group.
|
:param mcp_endpoint: The MCP endpoint to use for the tool group.
|
||||||
:param authorization: (Optional) OAuth access token for authenticating with the MCP server.
|
|
||||||
:returns: A ListToolDefsResponse.
|
:returns: A ListToolDefsResponse.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
|
||||||
|
|
@ -46,7 +46,6 @@ class ToolRuntimeRouter(ToolRuntime):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def list_runtime_tools(
|
async def list_runtime_tools(
|
||||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None, authorization: str | None = None
|
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||||
) -> ListToolDefsResponse:
|
) -> ListToolDefsResponse:
|
||||||
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
return await self.routing_table.list_tools(tool_group_id)
|
||||||
return await self.routing_table.list_tools(tool_group_id, authorization=authorization)
|
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
routing_key = self.tool_to_toolgroup[routing_key]
|
routing_key = self.tool_to_toolgroup[routing_key]
|
||||||
return await super().get_provider_impl(routing_key, provider_id)
|
return await super().get_provider_impl(routing_key, provider_id)
|
||||||
|
|
||||||
async def list_tools(self, toolgroup_id: str | None = None, authorization: str | None = None) -> ListToolDefsResponse:
|
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolDefsResponse:
|
||||||
if toolgroup_id:
|
if toolgroup_id:
|
||||||
if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
|
if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
|
||||||
toolgroup_id = group_id
|
toolgroup_id = group_id
|
||||||
|
|
@ -55,7 +55,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
for toolgroup in toolgroups:
|
for toolgroup in toolgroups:
|
||||||
if toolgroup.identifier not in self.toolgroups_to_tools:
|
if toolgroup.identifier not in self.toolgroups_to_tools:
|
||||||
try:
|
try:
|
||||||
await self._index_tools(toolgroup, authorization=authorization)
|
await self._index_tools(toolgroup)
|
||||||
except AuthenticationRequiredError:
|
except AuthenticationRequiredError:
|
||||||
# Send authentication errors back to the client so it knows
|
# Send authentication errors back to the client so it knows
|
||||||
# that it needs to supply credentials for remote MCP servers.
|
# that it needs to supply credentials for remote MCP servers.
|
||||||
|
|
@ -70,10 +70,10 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
|
|
||||||
return ListToolDefsResponse(data=all_tools)
|
return ListToolDefsResponse(data=all_tools)
|
||||||
|
|
||||||
async def _index_tools(self, toolgroup: ToolGroup, authorization: str | None = None):
|
async def _index_tools(self, toolgroup: ToolGroup):
|
||||||
provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
|
provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
|
||||||
tooldefs_response = await provider_impl.list_runtime_tools(
|
tooldefs_response = await provider_impl.list_runtime_tools(
|
||||||
toolgroup.identifier, toolgroup.mcp_endpoint, authorization=authorization
|
toolgroup.identifier, toolgroup.mcp_endpoint
|
||||||
)
|
)
|
||||||
|
|
||||||
tooldefs = tooldefs_response.data
|
tooldefs = tooldefs_response.data
|
||||||
|
|
|
||||||
|
|
@ -44,15 +44,14 @@ class ModelContextProtocolToolRuntimeImpl(
|
||||||
self,
|
self,
|
||||||
tool_group_id: str | None = None,
|
tool_group_id: str | None = None,
|
||||||
mcp_endpoint: URL | None = None,
|
mcp_endpoint: URL | None = None,
|
||||||
authorization: str | None = None,
|
|
||||||
) -> ListToolDefsResponse:
|
) -> ListToolDefsResponse:
|
||||||
# this endpoint should be retrieved by getting the tool group right?
|
# this endpoint should be retrieved by getting the tool group right?
|
||||||
if mcp_endpoint is None:
|
if mcp_endpoint is None:
|
||||||
raise ValueError("mcp_endpoint is required")
|
raise ValueError("mcp_endpoint is required")
|
||||||
# Authorization now comes from request body parameter (not provider-data)
|
# MCP tool listing typically doesn't require authorization
|
||||||
headers = {}
|
headers = {}
|
||||||
return await list_mcp_tools(
|
return await list_mcp_tools(
|
||||||
endpoint=mcp_endpoint.uri, headers=headers, authorization=authorization
|
endpoint=mcp_endpoint.uri, headers=headers, authorization=None
|
||||||
)
|
)
|
||||||
|
|
||||||
async def invoke_tool(
|
async def invoke_tool(
|
||||||
|
|
|
||||||
|
|
@ -196,7 +196,6 @@ class TestMCPToolsInChatCompletion:
|
||||||
# Get the tools from MCP
|
# Get the tools from MCP
|
||||||
tools_response = llama_stack_client.tool_runtime.list_tools(
|
tools_response = llama_stack_client.tool_runtime.list_tools(
|
||||||
tool_group_id=test_toolgroup_id,
|
tool_group_id=test_toolgroup_id,
|
||||||
authorization=AUTH_TOKEN,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert to OpenAI format for inference
|
# Convert to OpenAI format for inference
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue