fix(tools): do not index tools, only index toolgroups

This commit is contained in:
Ashwin Bharambe 2025-05-25 00:20:36 -07:00
parent 298721c238
commit bf8a73e09a
6 changed files with 43 additions and 77 deletions

View file

@ -9555,9 +9555,6 @@
"toolgroup_id": {
"type": "string"
},
"tool_host": {
"$ref": "#/components/schemas/ToolHost"
},
"description": {
"type": "string"
},
@ -9599,21 +9596,11 @@
"provider_id",
"type",
"toolgroup_id",
"tool_host",
"description",
"parameters"
],
"title": "Tool"
},
"ToolHost": {
"type": "string",
"enum": [
"distribution",
"client",
"model_context_protocol"
],
"title": "ToolHost"
},
"ToolGroup": {
"type": "object",
"properties": {

View file

@ -6713,8 +6713,6 @@ components:
default: tool
toolgroup_id:
type: string
tool_host:
$ref: '#/components/schemas/ToolHost'
description:
type: string
parameters:
@ -6737,17 +6735,9 @@ components:
- provider_id
- type
- toolgroup_id
- tool_host
- description
- parameters
title: Tool
ToolHost:
type: string
enum:
- distribution
- client
- model_context_protocol
title: ToolHost
ToolGroup:
type: object
properties:

View file

@ -27,18 +27,10 @@ class ToolParameter(BaseModel):
default: Any | None = None
@json_schema_type
class ToolHost(Enum):
distribution = "distribution"
client = "client"
model_context_protocol = "model_context_protocol"
@json_schema_type
class Tool(Resource):
type: Literal[ResourceType.tool] = ResourceType.tool
toolgroup_id: str
tool_host: ToolHost
description: str
parameters: list[ToolParameter]
metadata: dict[str, Any] | None = None

View file

@ -19,7 +19,8 @@ from llama_stack.apis.tools import (
ToolRuntime,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
logger = get_logger(name=__name__, category="core")
@ -28,7 +29,7 @@ class ToolRuntimeRouter(ToolRuntime):
class RagToolImpl(RAGToolRuntime):
def __init__(
self,
routing_table: RoutingTable,
routing_table: ToolGroupsRoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
self.routing_table = routing_table
@ -59,7 +60,7 @@ class ToolRuntimeRouter(ToolRuntime):
def __init__(
self,
routing_table: RoutingTable,
routing_table: ToolGroupsRoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter")
self.routing_table = routing_table

View file

@ -7,11 +7,8 @@
from typing import Any
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups, ToolHost
from llama_stack.distribution.datatypes import (
ToolGroupWithACL,
ToolWithACL,
)
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
from llama_stack.distribution.datatypes import ToolGroupWithACL
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
@ -20,11 +17,37 @@ logger = get_logger(name=__name__, category="core")
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
toolgroups_to_tools: dict[str, list[Tool]] = {}
tool_to_toolgroup: dict[str, str] = {}
# overridden
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
# we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
# TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?
tool_name = routing_key
if tool_name in self.tool_to_toolgroup:
routing_key = self.tool_to_toolgroup[tool_name]
return super().get_provider_impl(routing_key, provider_id)
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
tools = await self.get_all_with_type("tool")
if toolgroup_id:
tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
return ListToolsResponse(data=tools)
toolgroups = [await self.get_tool_group(toolgroup_id)]
else:
toolgroups = await self.get_all_with_type("tool_group")
all_tools = []
for toolgroup in toolgroups:
group_id = toolgroup.identifier
if group_id not in self.toolgroups_to_tools:
provider_impl = self.get_provider_impl(toolgroup.provider_id)
tools = await provider_impl.list_runtime_tools(group_id, toolgroup.mcp_endpoint)
self.toolgroups_to_tools[group_id] = tools.data
for tool in tools.data:
self.tool_to_toolgroup[tool.identifier] = group_id
all_tools.extend(self.toolgroups_to_tools[group_id])
return ListToolsResponse(data=all_tools)
async def list_tool_groups(self) -> ListToolGroupsResponse:
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
@ -36,7 +59,13 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
return tool_group
async def get_tool(self, tool_name: str) -> Tool:
return await self.get_object_by_identifier("tool", tool_name)
if tool_name in self.tool_to_toolgroup:
toolgroup_id = self.tool_to_toolgroup[tool_name]
tools = self.toolgroups_to_tools[toolgroup_id]
for tool in tools:
if tool.identifier == tool_name:
return tool
raise ValueError(f"Tool '{tool_name}' not found")
async def register_tool_group(
self,
@ -45,36 +74,6 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
mcp_endpoint: URL | None = None,
args: dict[str, Any] | None = None,
) -> None:
tools = []
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
for tool_def in tool_defs.data:
tools.append(
ToolWithACL(
identifier=tool_def.name,
toolgroup_id=toolgroup_id,
description=tool_def.description or "",
parameters=tool_def.parameters or [],
provider_id=provider_id,
provider_resource_id=tool_def.name,
metadata=tool_def.metadata,
tool_host=tool_host,
)
)
for tool in tools:
existing_tool = await self.get_tool(tool.identifier)
# Compare existing and new object if one exists
if existing_tool:
existing_dict = existing_tool.model_dump()
new_dict = tool.model_dump()
if existing_dict != new_dict:
raise ValueError(
f"Object {tool.identifier} already exists in registry. Please use a different identifier."
)
await self.register_object(tool)
await self.dist_registry.register(
ToolGroupWithACL(
identifier=toolgroup_id,
@ -89,9 +88,6 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
tool_group = await self.get_tool_group(toolgroup_id)
if tool_group is None:
raise ValueError(f"Tool group {toolgroup_id} not found")
tools = await self.list_tools(toolgroup_id)
for tool in getattr(tools, "data", []):
await self.unregister_object(tool)
await self.unregister_object(tool_group)
async def shutdown(self) -> None:

View file

@ -36,7 +36,7 @@ class DistributionRegistry(Protocol):
REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v8"
KEY_VERSION = "v9"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"