From bf8a73e09a041094806192fb15f64ef9b52128af Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sun, 25 May 2025 00:20:36 -0700 Subject: [PATCH] fix(tools): do not index tools, only index toolgroups --- docs/_static/llama-stack-spec.html | 13 --- docs/_static/llama-stack-spec.yaml | 10 --- llama_stack/apis/tools/tools.py | 8 -- .../distribution/routers/tool_runtime.py | 7 +- .../distribution/routing_tables/toolgroups.py | 80 +++++++++---------- llama_stack/distribution/store/registry.py | 2 +- 6 files changed, 43 insertions(+), 77 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 99ae1c038..043e9467e 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -9555,9 +9555,6 @@ "toolgroup_id": { "type": "string" }, - "tool_host": { - "$ref": "#/components/schemas/ToolHost" - }, "description": { "type": "string" }, @@ -9599,21 +9596,11 @@ "provider_id", "type", "toolgroup_id", - "tool_host", "description", "parameters" ], "title": "Tool" }, - "ToolHost": { - "type": "string", - "enum": [ - "distribution", - "client", - "model_context_protocol" - ], - "title": "ToolHost" - }, "ToolGroup": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 4e4f09eb0..c7ec8db5f 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -6713,8 +6713,6 @@ components: default: tool toolgroup_id: type: string - tool_host: - $ref: '#/components/schemas/ToolHost' description: type: string parameters: @@ -6737,17 +6735,9 @@ components: - provider_id - type - toolgroup_id - - tool_host - description - parameters title: Tool - ToolHost: - type: string - enum: - - distribution - - client - - model_context_protocol - title: ToolHost ToolGroup: type: object properties: diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index 29649495c..0c8d47edf 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -27,18 +27,10 @@ class ToolParameter(BaseModel): default: Any | None = None -@json_schema_type -class ToolHost(Enum): - distribution = "distribution" - client = "client" - model_context_protocol = "model_context_protocol" - - @json_schema_type class Tool(Resource): type: Literal[ResourceType.tool] = ResourceType.tool toolgroup_id: str - tool_host: ToolHost description: str parameters: list[ToolParameter] metadata: dict[str, Any] | None = None diff --git a/llama_stack/distribution/routers/tool_runtime.py b/llama_stack/distribution/routers/tool_runtime.py index 2d4734a2e..3cf86db7f 100644 --- a/llama_stack/distribution/routers/tool_runtime.py +++ b/llama_stack/distribution/routers/tool_runtime.py @@ -19,7 +19,8 @@ from llama_stack.apis.tools import ( ToolRuntime, ) from llama_stack.log import get_logger -from llama_stack.providers.datatypes import RoutingTable + +from ..routing_tables.toolgroups import ToolGroupsRoutingTable logger = get_logger(name=__name__, category="core") @@ -28,7 +29,7 @@ class ToolRuntimeRouter(ToolRuntime): class RagToolImpl(RAGToolRuntime): def __init__( self, - routing_table: RoutingTable, + routing_table: ToolGroupsRoutingTable, ) -> None: logger.debug("Initializing ToolRuntimeRouter.RagToolImpl") self.routing_table = routing_table @@ -59,7 +60,7 @@ class ToolRuntimeRouter(ToolRuntime): def __init__( self, - routing_table: RoutingTable, + routing_table: ToolGroupsRoutingTable, ) -> None: logger.debug("Initializing ToolRuntimeRouter") self.routing_table = routing_table diff --git a/llama_stack/distribution/routing_tables/toolgroups.py b/llama_stack/distribution/routing_tables/toolgroups.py index cb73dc7c2..0cf7c9a45 100644 --- a/llama_stack/distribution/routing_tables/toolgroups.py +++ b/llama_stack/distribution/routing_tables/toolgroups.py @@ -7,11 +7,8 @@ from typing import Any from llama_stack.apis.common.content_types import URL -from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups, ToolHost -from llama_stack.distribution.datatypes import ( - ToolGroupWithACL, - ToolWithACL, -) +from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups +from llama_stack.distribution.datatypes import ToolGroupWithACL from llama_stack.log import get_logger from .common import CommonRoutingTableImpl @@ -20,11 +17,37 @@ logger = get_logger(name=__name__, category="core") class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): + toolgroups_to_tools: dict[str, list[Tool]] = {} + tool_to_toolgroup: dict[str, str] = {} + + # overridden + def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any: + # we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id + # TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while? + tool_name = routing_key + if tool_name in self.tool_to_toolgroup: + routing_key = self.tool_to_toolgroup[tool_name] + return super().get_provider_impl(routing_key, provider_id) + async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse: - tools = await self.get_all_with_type("tool") if toolgroup_id: - tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id] - return ListToolsResponse(data=tools) + toolgroups = [await self.get_tool_group(toolgroup_id)] + else: + toolgroups = await self.get_all_with_type("tool_group") + + all_tools = [] + for toolgroup in toolgroups: + group_id = toolgroup.identifier + if group_id not in self.toolgroups_to_tools: + provider_impl = self.get_provider_impl(toolgroup.provider_id) + tools = await provider_impl.list_runtime_tools(group_id, toolgroup.mcp_endpoint) + + self.toolgroups_to_tools[group_id] = tools.data + for tool in tools.data: + self.tool_to_toolgroup[tool.identifier] = group_id + all_tools.extend(self.toolgroups_to_tools[group_id]) + + return ListToolsResponse(data=all_tools) async def list_tool_groups(self) -> ListToolGroupsResponse: return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group")) @@ -36,7 +59,13 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): return tool_group async def get_tool(self, tool_name: str) -> Tool: - return await self.get_object_by_identifier("tool", tool_name) + if tool_name in self.tool_to_toolgroup: + toolgroup_id = self.tool_to_toolgroup[tool_name] + tools = self.toolgroups_to_tools[toolgroup_id] + for tool in tools: + if tool.identifier == tool_name: + return tool + raise ValueError(f"Tool '{tool_name}' not found") async def register_tool_group( self, @@ -45,36 +74,6 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): mcp_endpoint: URL | None = None, args: dict[str, Any] | None = None, ) -> None: - tools = [] - tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint) - tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution - - for tool_def in tool_defs.data: - tools.append( - ToolWithACL( - identifier=tool_def.name, - toolgroup_id=toolgroup_id, - description=tool_def.description or "", - parameters=tool_def.parameters or [], - provider_id=provider_id, - provider_resource_id=tool_def.name, - metadata=tool_def.metadata, - tool_host=tool_host, - ) - ) - for tool in tools: - existing_tool = await self.get_tool(tool.identifier) - # Compare existing and new object if one exists - if existing_tool: - existing_dict = existing_tool.model_dump() - new_dict = tool.model_dump() - - if existing_dict != new_dict: - raise ValueError( - f"Object {tool.identifier} already exists in registry. Please use a different identifier." - ) - await self.register_object(tool) - await self.dist_registry.register( ToolGroupWithACL( identifier=toolgroup_id, @@ -89,9 +88,6 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): tool_group = await self.get_tool_group(toolgroup_id) if tool_group is None: raise ValueError(f"Tool group {toolgroup_id} not found") - tools = await self.list_tools(toolgroup_id) - for tool in getattr(tools, "data", []): - await self.unregister_object(tool) await self.unregister_object(tool_group) async def shutdown(self) -> None: diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index a6b400136..0e84854c2 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -36,7 +36,7 @@ class DistributionRegistry(Protocol): REGISTER_PREFIX = "distributions:registry" -KEY_VERSION = "v8" +KEY_VERSION = "v9" KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"