diff --git a/llama_stack/distribution/routing_tables/toolgroups.py b/llama_stack/distribution/routing_tables/toolgroups.py index 3f103ed22..2f7dc3e06 100644 --- a/llama_stack/distribution/routing_tables/toolgroups.py +++ b/llama_stack/distribution/routing_tables/toolgroups.py @@ -16,6 +16,15 @@ from .common import CommonRoutingTableImpl logger = get_logger(name=__name__, category="core") +def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None: + # handle the funny case like "builtin::rag/knowledge_search" + parts = toolgroup_name_with_maybe_tool_name.split("/") + if len(parts) == 2: + return parts[0] + else: + return None + + class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): toolgroups_to_tools: dict[str, list[Tool]] = {} tool_to_toolgroup: dict[str, str] = {} @@ -24,45 +33,54 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any: # we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id # TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while? + + toolgroup_id = parse_toolgroup_from_toolgroup_name_pair(routing_key) + if toolgroup_id: + routing_key = toolgroup_id + if routing_key in self.tool_to_toolgroup: routing_key = self.tool_to_toolgroup[routing_key] return super().get_provider_impl(routing_key, provider_id) async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse: if toolgroup_id: + if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id): + toolgroup_id = group_id toolgroups = [await self.get_tool_group(toolgroup_id)] else: toolgroups = await self.get_all_with_type("tool_group") all_tools = [] for toolgroup in toolgroups: - group_id = toolgroup.identifier - if group_id not in self.toolgroups_to_tools: - provider_impl = super().get_provider_impl(group_id, toolgroup.provider_id) - tooldefs_response = await provider_impl.list_runtime_tools(group_id, toolgroup.mcp_endpoint) - - # TODO: kill this Tool vs ToolDef distinction - tooldefs = tooldefs_response.data - tools = [] - for t in tooldefs: - tools.append( - Tool( - identifier=t.name, - toolgroup_id=group_id, - description=t.description or "", - parameters=t.parameters or [], - metadata=t.metadata, - provider_id=toolgroup.provider_id, - ) - ) - - self.toolgroups_to_tools[group_id] = tools - for tool in tools: - self.tool_to_toolgroup[tool.identifier] = group_id - all_tools.extend(self.toolgroups_to_tools[group_id]) + if toolgroup.identifier not in self.toolgroups_to_tools: + await self._index_tools(toolgroup) + all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier]) return ListToolsResponse(data=all_tools) + async def _index_tools(self, toolgroup: ToolGroup): + provider_impl = super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id) + tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint) + + # TODO: kill this Tool vs ToolDef distinction + tooldefs = tooldefs_response.data + tools = [] + for t in tooldefs: + tools.append( + Tool( + identifier=t.name, + toolgroup_id=toolgroup.identifier, + description=t.description or "", + parameters=t.parameters or [], + metadata=t.metadata, + provider_id=toolgroup.provider_id, + ) + ) + + self.toolgroups_to_tools[toolgroup.identifier] = tools + for tool in tools: + self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier + async def list_tool_groups(self) -> ListToolGroupsResponse: return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group")) @@ -96,6 +114,12 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): args=args, ) await self.register_object(toolgroup) + + # ideally, indexing of the tools should not be necessary because anyone using + # the tools should first list the tools and then use them. but there are assumptions + # baked in some of the code and tests right now. + if not toolgroup.mcp_endpoint: + await self._index_tools(toolgroup) return toolgroup async def unregister_toolgroup(self, toolgroup_id: str) -> None: