mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
fix(tools): do not index tools, only index toolgroups
This commit is contained in:
parent
298721c238
commit
bf8a73e09a
6 changed files with 43 additions and 77 deletions
13
docs/_static/llama-stack-spec.html
vendored
13
docs/_static/llama-stack-spec.html
vendored
|
@ -9555,9 +9555,6 @@
|
||||||
"toolgroup_id": {
|
"toolgroup_id": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
"tool_host": {
|
|
||||||
"$ref": "#/components/schemas/ToolHost"
|
|
||||||
},
|
|
||||||
"description": {
|
"description": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
|
@ -9599,21 +9596,11 @@
|
||||||
"provider_id",
|
"provider_id",
|
||||||
"type",
|
"type",
|
||||||
"toolgroup_id",
|
"toolgroup_id",
|
||||||
"tool_host",
|
|
||||||
"description",
|
"description",
|
||||||
"parameters"
|
"parameters"
|
||||||
],
|
],
|
||||||
"title": "Tool"
|
"title": "Tool"
|
||||||
},
|
},
|
||||||
"ToolHost": {
|
|
||||||
"type": "string",
|
|
||||||
"enum": [
|
|
||||||
"distribution",
|
|
||||||
"client",
|
|
||||||
"model_context_protocol"
|
|
||||||
],
|
|
||||||
"title": "ToolHost"
|
|
||||||
},
|
|
||||||
"ToolGroup": {
|
"ToolGroup": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
|
10
docs/_static/llama-stack-spec.yaml
vendored
10
docs/_static/llama-stack-spec.yaml
vendored
|
@ -6713,8 +6713,6 @@ components:
|
||||||
default: tool
|
default: tool
|
||||||
toolgroup_id:
|
toolgroup_id:
|
||||||
type: string
|
type: string
|
||||||
tool_host:
|
|
||||||
$ref: '#/components/schemas/ToolHost'
|
|
||||||
description:
|
description:
|
||||||
type: string
|
type: string
|
||||||
parameters:
|
parameters:
|
||||||
|
@ -6737,17 +6735,9 @@ components:
|
||||||
- provider_id
|
- provider_id
|
||||||
- type
|
- type
|
||||||
- toolgroup_id
|
- toolgroup_id
|
||||||
- tool_host
|
|
||||||
- description
|
- description
|
||||||
- parameters
|
- parameters
|
||||||
title: Tool
|
title: Tool
|
||||||
ToolHost:
|
|
||||||
type: string
|
|
||||||
enum:
|
|
||||||
- distribution
|
|
||||||
- client
|
|
||||||
- model_context_protocol
|
|
||||||
title: ToolHost
|
|
||||||
ToolGroup:
|
ToolGroup:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
|
|
@ -27,18 +27,10 @@ class ToolParameter(BaseModel):
|
||||||
default: Any | None = None
|
default: Any | None = None
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ToolHost(Enum):
|
|
||||||
distribution = "distribution"
|
|
||||||
client = "client"
|
|
||||||
model_context_protocol = "model_context_protocol"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class Tool(Resource):
|
class Tool(Resource):
|
||||||
type: Literal[ResourceType.tool] = ResourceType.tool
|
type: Literal[ResourceType.tool] = ResourceType.tool
|
||||||
toolgroup_id: str
|
toolgroup_id: str
|
||||||
tool_host: ToolHost
|
|
||||||
description: str
|
description: str
|
||||||
parameters: list[ToolParameter]
|
parameters: list[ToolParameter]
|
||||||
metadata: dict[str, Any] | None = None
|
metadata: dict[str, Any] | None = None
|
||||||
|
|
|
@ -19,7 +19,8 @@ from llama_stack.apis.tools import (
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import RoutingTable
|
|
||||||
|
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
|
@ -28,7 +29,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
||||||
class RagToolImpl(RAGToolRuntime):
|
class RagToolImpl(RAGToolRuntime):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
routing_table: RoutingTable,
|
routing_table: ToolGroupsRoutingTable,
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
|
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
|
||||||
self.routing_table = routing_table
|
self.routing_table = routing_table
|
||||||
|
@ -59,7 +60,7 @@ class ToolRuntimeRouter(ToolRuntime):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
routing_table: RoutingTable,
|
routing_table: ToolGroupsRoutingTable,
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.debug("Initializing ToolRuntimeRouter")
|
logger.debug("Initializing ToolRuntimeRouter")
|
||||||
self.routing_table = routing_table
|
self.routing_table = routing_table
|
||||||
|
|
|
@ -7,11 +7,8 @@
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL
|
from llama_stack.apis.common.content_types import URL
|
||||||
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups, ToolHost
|
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import ToolGroupWithACL
|
||||||
ToolGroupWithACL,
|
|
||||||
ToolWithACL,
|
|
||||||
)
|
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
from .common import CommonRoutingTableImpl
|
from .common import CommonRoutingTableImpl
|
||||||
|
@ -20,11 +17,37 @@ logger = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
|
|
||||||
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
|
toolgroups_to_tools: dict[str, list[Tool]] = {}
|
||||||
|
tool_to_toolgroup: dict[str, str] = {}
|
||||||
|
|
||||||
|
# overridden
|
||||||
|
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
|
||||||
|
# we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
|
||||||
|
# TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?
|
||||||
|
tool_name = routing_key
|
||||||
|
if tool_name in self.tool_to_toolgroup:
|
||||||
|
routing_key = self.tool_to_toolgroup[tool_name]
|
||||||
|
return super().get_provider_impl(routing_key, provider_id)
|
||||||
|
|
||||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
|
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
|
||||||
tools = await self.get_all_with_type("tool")
|
|
||||||
if toolgroup_id:
|
if toolgroup_id:
|
||||||
tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
|
toolgroups = [await self.get_tool_group(toolgroup_id)]
|
||||||
return ListToolsResponse(data=tools)
|
else:
|
||||||
|
toolgroups = await self.get_all_with_type("tool_group")
|
||||||
|
|
||||||
|
all_tools = []
|
||||||
|
for toolgroup in toolgroups:
|
||||||
|
group_id = toolgroup.identifier
|
||||||
|
if group_id not in self.toolgroups_to_tools:
|
||||||
|
provider_impl = self.get_provider_impl(toolgroup.provider_id)
|
||||||
|
tools = await provider_impl.list_runtime_tools(group_id, toolgroup.mcp_endpoint)
|
||||||
|
|
||||||
|
self.toolgroups_to_tools[group_id] = tools.data
|
||||||
|
for tool in tools.data:
|
||||||
|
self.tool_to_toolgroup[tool.identifier] = group_id
|
||||||
|
all_tools.extend(self.toolgroups_to_tools[group_id])
|
||||||
|
|
||||||
|
return ListToolsResponse(data=all_tools)
|
||||||
|
|
||||||
async def list_tool_groups(self) -> ListToolGroupsResponse:
|
async def list_tool_groups(self) -> ListToolGroupsResponse:
|
||||||
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
|
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
|
||||||
|
@ -36,7 +59,13 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
return tool_group
|
return tool_group
|
||||||
|
|
||||||
async def get_tool(self, tool_name: str) -> Tool:
|
async def get_tool(self, tool_name: str) -> Tool:
|
||||||
return await self.get_object_by_identifier("tool", tool_name)
|
if tool_name in self.tool_to_toolgroup:
|
||||||
|
toolgroup_id = self.tool_to_toolgroup[tool_name]
|
||||||
|
tools = self.toolgroups_to_tools[toolgroup_id]
|
||||||
|
for tool in tools:
|
||||||
|
if tool.identifier == tool_name:
|
||||||
|
return tool
|
||||||
|
raise ValueError(f"Tool '{tool_name}' not found")
|
||||||
|
|
||||||
async def register_tool_group(
|
async def register_tool_group(
|
||||||
self,
|
self,
|
||||||
|
@ -45,36 +74,6 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
mcp_endpoint: URL | None = None,
|
mcp_endpoint: URL | None = None,
|
||||||
args: dict[str, Any] | None = None,
|
args: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
tools = []
|
|
||||||
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
|
|
||||||
tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
|
|
||||||
|
|
||||||
for tool_def in tool_defs.data:
|
|
||||||
tools.append(
|
|
||||||
ToolWithACL(
|
|
||||||
identifier=tool_def.name,
|
|
||||||
toolgroup_id=toolgroup_id,
|
|
||||||
description=tool_def.description or "",
|
|
||||||
parameters=tool_def.parameters or [],
|
|
||||||
provider_id=provider_id,
|
|
||||||
provider_resource_id=tool_def.name,
|
|
||||||
metadata=tool_def.metadata,
|
|
||||||
tool_host=tool_host,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
for tool in tools:
|
|
||||||
existing_tool = await self.get_tool(tool.identifier)
|
|
||||||
# Compare existing and new object if one exists
|
|
||||||
if existing_tool:
|
|
||||||
existing_dict = existing_tool.model_dump()
|
|
||||||
new_dict = tool.model_dump()
|
|
||||||
|
|
||||||
if existing_dict != new_dict:
|
|
||||||
raise ValueError(
|
|
||||||
f"Object {tool.identifier} already exists in registry. Please use a different identifier."
|
|
||||||
)
|
|
||||||
await self.register_object(tool)
|
|
||||||
|
|
||||||
await self.dist_registry.register(
|
await self.dist_registry.register(
|
||||||
ToolGroupWithACL(
|
ToolGroupWithACL(
|
||||||
identifier=toolgroup_id,
|
identifier=toolgroup_id,
|
||||||
|
@ -89,9 +88,6 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
tool_group = await self.get_tool_group(toolgroup_id)
|
tool_group = await self.get_tool_group(toolgroup_id)
|
||||||
if tool_group is None:
|
if tool_group is None:
|
||||||
raise ValueError(f"Tool group {toolgroup_id} not found")
|
raise ValueError(f"Tool group {toolgroup_id} not found")
|
||||||
tools = await self.list_tools(toolgroup_id)
|
|
||||||
for tool in getattr(tools, "data", []):
|
|
||||||
await self.unregister_object(tool)
|
|
||||||
await self.unregister_object(tool_group)
|
await self.unregister_object(tool_group)
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
|
|
|
@ -36,7 +36,7 @@ class DistributionRegistry(Protocol):
|
||||||
|
|
||||||
|
|
||||||
REGISTER_PREFIX = "distributions:registry"
|
REGISTER_PREFIX = "distributions:registry"
|
||||||
KEY_VERSION = "v8"
|
KEY_VERSION = "v9"
|
||||||
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
|
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue