mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
fix(tools): do not index tools, only index toolgroups (#2261)
When registering a MCP endpoint, we cannot list tools (like we used to) since the MCP endpoint may be behind an auth wall. Registration can happen much sooner (via run.yaml). Instead, we do listing only when the _user_ actually calls listing. Furthermore, we cache the list in-memory in the server. Currently, the cache is not invalidated -- we may want to periodically re-list for MCP servers. Note that they must call `list_tools` before calling `invoke_tool` -- we use this critically. This will enable us to list MCP servers in run.yaml ## Test Plan Existing tests, updated tests accordingly.
This commit is contained in:
parent
5a422e236c
commit
ce33d02443
19 changed files with 131 additions and 153 deletions
|
@ -7,11 +7,8 @@
|
|||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups, ToolHost
|
||||
from llama_stack.distribution.datatypes import (
|
||||
ToolGroupWithACL,
|
||||
ToolWithACL,
|
||||
)
|
||||
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
|
||||
from llama_stack.distribution.datatypes import ToolGroupWithACL
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
@ -20,11 +17,51 @@ logger = get_logger(name=__name__, category="core")
|
|||
|
||||
|
||||
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||
toolgroups_to_tools: dict[str, list[Tool]] = {}
|
||||
tool_to_toolgroup: dict[str, str] = {}
|
||||
|
||||
# overridden
|
||||
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
|
||||
# we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
|
||||
# TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?
|
||||
if routing_key in self.tool_to_toolgroup:
|
||||
routing_key = self.tool_to_toolgroup[routing_key]
|
||||
return super().get_provider_impl(routing_key, provider_id)
|
||||
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
|
||||
tools = await self.get_all_with_type("tool")
|
||||
if toolgroup_id:
|
||||
tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
|
||||
return ListToolsResponse(data=tools)
|
||||
toolgroups = [await self.get_tool_group(toolgroup_id)]
|
||||
else:
|
||||
toolgroups = await self.get_all_with_type("tool_group")
|
||||
|
||||
all_tools = []
|
||||
for toolgroup in toolgroups:
|
||||
group_id = toolgroup.identifier
|
||||
if group_id not in self.toolgroups_to_tools:
|
||||
provider_impl = super().get_provider_impl(group_id, toolgroup.provider_id)
|
||||
tooldefs_response = await provider_impl.list_runtime_tools(group_id, toolgroup.mcp_endpoint)
|
||||
|
||||
# TODO: kill this Tool vs ToolDef distinction
|
||||
tooldefs = tooldefs_response.data
|
||||
tools = []
|
||||
for t in tooldefs:
|
||||
tools.append(
|
||||
Tool(
|
||||
identifier=t.name,
|
||||
toolgroup_id=group_id,
|
||||
description=t.description or "",
|
||||
parameters=t.parameters or [],
|
||||
metadata=t.metadata,
|
||||
provider_id=toolgroup.provider_id,
|
||||
)
|
||||
)
|
||||
|
||||
self.toolgroups_to_tools[group_id] = tools
|
||||
for tool in tools:
|
||||
self.tool_to_toolgroup[tool.identifier] = group_id
|
||||
all_tools.extend(self.toolgroups_to_tools[group_id])
|
||||
|
||||
return ListToolsResponse(data=all_tools)
|
||||
|
||||
async def list_tool_groups(self) -> ListToolGroupsResponse:
|
||||
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
|
||||
|
@ -36,7 +73,13 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
return tool_group
|
||||
|
||||
async def get_tool(self, tool_name: str) -> Tool:
|
||||
return await self.get_object_by_identifier("tool", tool_name)
|
||||
if tool_name in self.tool_to_toolgroup:
|
||||
toolgroup_id = self.tool_to_toolgroup[tool_name]
|
||||
tools = self.toolgroups_to_tools[toolgroup_id]
|
||||
for tool in tools:
|
||||
if tool.identifier == tool_name:
|
||||
return tool
|
||||
raise ValueError(f"Tool '{tool_name}' not found")
|
||||
|
||||
async def register_tool_group(
|
||||
self,
|
||||
|
@ -45,53 +88,20 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
mcp_endpoint: URL | None = None,
|
||||
args: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
tools = []
|
||||
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
|
||||
tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
|
||||
|
||||
for tool_def in tool_defs.data:
|
||||
tools.append(
|
||||
ToolWithACL(
|
||||
identifier=tool_def.name,
|
||||
toolgroup_id=toolgroup_id,
|
||||
description=tool_def.description or "",
|
||||
parameters=tool_def.parameters or [],
|
||||
provider_id=provider_id,
|
||||
provider_resource_id=tool_def.name,
|
||||
metadata=tool_def.metadata,
|
||||
tool_host=tool_host,
|
||||
)
|
||||
)
|
||||
for tool in tools:
|
||||
existing_tool = await self.get_tool(tool.identifier)
|
||||
# Compare existing and new object if one exists
|
||||
if existing_tool:
|
||||
existing_dict = existing_tool.model_dump()
|
||||
new_dict = tool.model_dump()
|
||||
|
||||
if existing_dict != new_dict:
|
||||
raise ValueError(
|
||||
f"Object {tool.identifier} already exists in registry. Please use a different identifier."
|
||||
)
|
||||
await self.register_object(tool)
|
||||
|
||||
await self.dist_registry.register(
|
||||
ToolGroupWithACL(
|
||||
identifier=toolgroup_id,
|
||||
provider_id=provider_id,
|
||||
provider_resource_id=toolgroup_id,
|
||||
mcp_endpoint=mcp_endpoint,
|
||||
args=args,
|
||||
)
|
||||
toolgroup = ToolGroupWithACL(
|
||||
identifier=toolgroup_id,
|
||||
provider_id=provider_id,
|
||||
provider_resource_id=toolgroup_id,
|
||||
mcp_endpoint=mcp_endpoint,
|
||||
args=args,
|
||||
)
|
||||
await self.register_object(toolgroup)
|
||||
return toolgroup
|
||||
|
||||
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
|
||||
tool_group = await self.get_tool_group(toolgroup_id)
|
||||
if tool_group is None:
|
||||
raise ValueError(f"Tool group {toolgroup_id} not found")
|
||||
tools = await self.list_tools(toolgroup_id)
|
||||
for tool in getattr(tools, "data", []):
|
||||
await self.unregister_object(tool)
|
||||
await self.unregister_object(tool_group)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue