mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-31 16:01:46 +00:00
fix(tools): do not index tools, only index toolgroups
This commit is contained in:
parent
298721c238
commit
bf8a73e09a
6 changed files with 43 additions and 77 deletions
13
docs/_static/llama-stack-spec.html
vendored
13
docs/_static/llama-stack-spec.html
vendored
|
@ -9555,9 +9555,6 @@
|
|||
"toolgroup_id": {
|
||||
"type": "string"
|
||||
},
|
||||
"tool_host": {
|
||||
"$ref": "#/components/schemas/ToolHost"
|
||||
},
|
||||
"description": {
|
||||
"type": "string"
|
||||
},
|
||||
|
@ -9599,21 +9596,11 @@
|
|||
"provider_id",
|
||||
"type",
|
||||
"toolgroup_id",
|
||||
"tool_host",
|
||||
"description",
|
||||
"parameters"
|
||||
],
|
||||
"title": "Tool"
|
||||
},
|
||||
"ToolHost": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"distribution",
|
||||
"client",
|
||||
"model_context_protocol"
|
||||
],
|
||||
"title": "ToolHost"
|
||||
},
|
||||
"ToolGroup": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
|
10
docs/_static/llama-stack-spec.yaml
vendored
10
docs/_static/llama-stack-spec.yaml
vendored
|
@ -6713,8 +6713,6 @@ components:
|
|||
default: tool
|
||||
toolgroup_id:
|
||||
type: string
|
||||
tool_host:
|
||||
$ref: '#/components/schemas/ToolHost'
|
||||
description:
|
||||
type: string
|
||||
parameters:
|
||||
|
@ -6737,17 +6735,9 @@ components:
|
|||
- provider_id
|
||||
- type
|
||||
- toolgroup_id
|
||||
- tool_host
|
||||
- description
|
||||
- parameters
|
||||
title: Tool
|
||||
ToolHost:
|
||||
type: string
|
||||
enum:
|
||||
- distribution
|
||||
- client
|
||||
- model_context_protocol
|
||||
title: ToolHost
|
||||
ToolGroup:
|
||||
type: object
|
||||
properties:
|
||||
|
|
|
@ -27,18 +27,10 @@ class ToolParameter(BaseModel):
|
|||
default: Any | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolHost(Enum):
|
||||
distribution = "distribution"
|
||||
client = "client"
|
||||
model_context_protocol = "model_context_protocol"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Tool(Resource):
|
||||
type: Literal[ResourceType.tool] = ResourceType.tool
|
||||
toolgroup_id: str
|
||||
tool_host: ToolHost
|
||||
description: str
|
||||
parameters: list[ToolParameter]
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
|
|
@ -19,7 +19,8 @@ from llama_stack.apis.tools import (
|
|||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
|
||||
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
@ -28,7 +29,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
class RagToolImpl(RAGToolRuntime):
|
||||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
routing_table: ToolGroupsRoutingTable,
|
||||
) -> None:
|
||||
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
|
||||
self.routing_table = routing_table
|
||||
|
@ -59,7 +60,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
routing_table: ToolGroupsRoutingTable,
|
||||
) -> None:
|
||||
logger.debug("Initializing ToolRuntimeRouter")
|
||||
self.routing_table = routing_table
|
||||
|
|
|
@ -7,11 +7,8 @@
|
|||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups, ToolHost
|
||||
from llama_stack.distribution.datatypes import (
|
||||
ToolGroupWithACL,
|
||||
ToolWithACL,
|
||||
)
|
||||
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
|
||||
from llama_stack.distribution.datatypes import ToolGroupWithACL
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
@ -20,11 +17,37 @@ logger = get_logger(name=__name__, category="core")
|
|||
|
||||
|
||||
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||
toolgroups_to_tools: dict[str, list[Tool]] = {}
|
||||
tool_to_toolgroup: dict[str, str] = {}
|
||||
|
||||
# overridden
|
||||
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
|
||||
# we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
|
||||
# TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?
|
||||
tool_name = routing_key
|
||||
if tool_name in self.tool_to_toolgroup:
|
||||
routing_key = self.tool_to_toolgroup[tool_name]
|
||||
return super().get_provider_impl(routing_key, provider_id)
|
||||
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
|
||||
tools = await self.get_all_with_type("tool")
|
||||
if toolgroup_id:
|
||||
tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
|
||||
return ListToolsResponse(data=tools)
|
||||
toolgroups = [await self.get_tool_group(toolgroup_id)]
|
||||
else:
|
||||
toolgroups = await self.get_all_with_type("tool_group")
|
||||
|
||||
all_tools = []
|
||||
for toolgroup in toolgroups:
|
||||
group_id = toolgroup.identifier
|
||||
if group_id not in self.toolgroups_to_tools:
|
||||
provider_impl = self.get_provider_impl(toolgroup.provider_id)
|
||||
tools = await provider_impl.list_runtime_tools(group_id, toolgroup.mcp_endpoint)
|
||||
|
||||
self.toolgroups_to_tools[group_id] = tools.data
|
||||
for tool in tools.data:
|
||||
self.tool_to_toolgroup[tool.identifier] = group_id
|
||||
all_tools.extend(self.toolgroups_to_tools[group_id])
|
||||
|
||||
return ListToolsResponse(data=all_tools)
|
||||
|
||||
async def list_tool_groups(self) -> ListToolGroupsResponse:
|
||||
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
|
||||
|
@ -36,7 +59,13 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
return tool_group
|
||||
|
||||
async def get_tool(self, tool_name: str) -> Tool:
|
||||
return await self.get_object_by_identifier("tool", tool_name)
|
||||
if tool_name in self.tool_to_toolgroup:
|
||||
toolgroup_id = self.tool_to_toolgroup[tool_name]
|
||||
tools = self.toolgroups_to_tools[toolgroup_id]
|
||||
for tool in tools:
|
||||
if tool.identifier == tool_name:
|
||||
return tool
|
||||
raise ValueError(f"Tool '{tool_name}' not found")
|
||||
|
||||
async def register_tool_group(
|
||||
self,
|
||||
|
@ -45,36 +74,6 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
mcp_endpoint: URL | None = None,
|
||||
args: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
tools = []
|
||||
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
|
||||
tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
|
||||
|
||||
for tool_def in tool_defs.data:
|
||||
tools.append(
|
||||
ToolWithACL(
|
||||
identifier=tool_def.name,
|
||||
toolgroup_id=toolgroup_id,
|
||||
description=tool_def.description or "",
|
||||
parameters=tool_def.parameters or [],
|
||||
provider_id=provider_id,
|
||||
provider_resource_id=tool_def.name,
|
||||
metadata=tool_def.metadata,
|
||||
tool_host=tool_host,
|
||||
)
|
||||
)
|
||||
for tool in tools:
|
||||
existing_tool = await self.get_tool(tool.identifier)
|
||||
# Compare existing and new object if one exists
|
||||
if existing_tool:
|
||||
existing_dict = existing_tool.model_dump()
|
||||
new_dict = tool.model_dump()
|
||||
|
||||
if existing_dict != new_dict:
|
||||
raise ValueError(
|
||||
f"Object {tool.identifier} already exists in registry. Please use a different identifier."
|
||||
)
|
||||
await self.register_object(tool)
|
||||
|
||||
await self.dist_registry.register(
|
||||
ToolGroupWithACL(
|
||||
identifier=toolgroup_id,
|
||||
|
@ -89,9 +88,6 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
tool_group = await self.get_tool_group(toolgroup_id)
|
||||
if tool_group is None:
|
||||
raise ValueError(f"Tool group {toolgroup_id} not found")
|
||||
tools = await self.list_tools(toolgroup_id)
|
||||
for tool in getattr(tools, "data", []):
|
||||
await self.unregister_object(tool)
|
||||
await self.unregister_object(tool_group)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
|
|
|
@ -36,7 +36,7 @@ class DistributionRegistry(Protocol):
|
|||
|
||||
|
||||
REGISTER_PREFIX = "distributions:registry"
|
||||
KEY_VERSION = "v8"
|
||||
KEY_VERSION = "v9"
|
||||
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue