diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 99ae1c038..043e9467e 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -9555,9 +9555,6 @@
"toolgroup_id": {
"type": "string"
},
- "tool_host": {
- "$ref": "#/components/schemas/ToolHost"
- },
"description": {
"type": "string"
},
@@ -9599,21 +9596,11 @@
"provider_id",
"type",
"toolgroup_id",
- "tool_host",
"description",
"parameters"
],
"title": "Tool"
},
- "ToolHost": {
- "type": "string",
- "enum": [
- "distribution",
- "client",
- "model_context_protocol"
- ],
- "title": "ToolHost"
- },
"ToolGroup": {
"type": "object",
"properties": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 4e4f09eb0..c7ec8db5f 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -6713,8 +6713,6 @@ components:
default: tool
toolgroup_id:
type: string
- tool_host:
- $ref: '#/components/schemas/ToolHost'
description:
type: string
parameters:
@@ -6737,17 +6735,9 @@ components:
- provider_id
- type
- toolgroup_id
- - tool_host
- description
- parameters
title: Tool
- ToolHost:
- type: string
- enum:
- - distribution
- - client
- - model_context_protocol
- title: ToolHost
ToolGroup:
type: object
properties:
diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py
index 29649495c..0c8d47edf 100644
--- a/llama_stack/apis/tools/tools.py
+++ b/llama_stack/apis/tools/tools.py
@@ -27,18 +27,10 @@ class ToolParameter(BaseModel):
default: Any | None = None
-@json_schema_type
-class ToolHost(Enum):
- distribution = "distribution"
- client = "client"
- model_context_protocol = "model_context_protocol"
-
-
@json_schema_type
class Tool(Resource):
type: Literal[ResourceType.tool] = ResourceType.tool
toolgroup_id: str
- tool_host: ToolHost
description: str
parameters: list[ToolParameter]
metadata: dict[str, Any] | None = None
diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index 8b846d051..b7c7cb87f 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -47,7 +47,7 @@ from llama_stack.providers.datatypes import (
RemoteProviderSpec,
ScoringFunctionsProtocolPrivate,
ShieldsProtocolPrivate,
- ToolsProtocolPrivate,
+ ToolGroupsProtocolPrivate,
VectorDBsProtocolPrivate,
)
@@ -93,7 +93,7 @@ def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
def additional_protocols_map() -> dict[Api, Any]:
return {
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
- Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups),
+ Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
diff --git a/llama_stack/distribution/routers/tool_runtime.py b/llama_stack/distribution/routers/tool_runtime.py
index 2d4734a2e..285843dbc 100644
--- a/llama_stack/distribution/routers/tool_runtime.py
+++ b/llama_stack/distribution/routers/tool_runtime.py
@@ -11,7 +11,7 @@ from llama_stack.apis.common.content_types import (
InterleavedContent,
)
from llama_stack.apis.tools import (
- ListToolDefsResponse,
+ ListToolsResponse,
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
@@ -19,7 +19,8 @@ from llama_stack.apis.tools import (
ToolRuntime,
)
from llama_stack.log import get_logger
-from llama_stack.providers.datatypes import RoutingTable
+
+from ..routing_tables.toolgroups import ToolGroupsRoutingTable
logger = get_logger(name=__name__, category="core")
@@ -28,7 +29,7 @@ class ToolRuntimeRouter(ToolRuntime):
class RagToolImpl(RAGToolRuntime):
def __init__(
self,
- routing_table: RoutingTable,
+ routing_table: ToolGroupsRoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
self.routing_table = routing_table
@@ -59,7 +60,7 @@ class ToolRuntimeRouter(ToolRuntime):
def __init__(
self,
- routing_table: RoutingTable,
+ routing_table: ToolGroupsRoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter")
self.routing_table = routing_table
@@ -86,6 +87,6 @@ class ToolRuntimeRouter(ToolRuntime):
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
- ) -> ListToolDefsResponse:
+ ) -> ListToolsResponse:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
- return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
+ return await self.routing_table.list_tools(tool_group_id)
diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py
index 95a92a5ba..8ec87ca50 100644
--- a/llama_stack/distribution/routing_tables/common.py
+++ b/llama_stack/distribution/routing_tables/common.py
@@ -46,7 +46,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
elif api == Api.eval:
return await p.register_benchmark(obj)
elif api == Api.tool_runtime:
- return await p.register_tool(obj)
+ return await p.register_toolgroup(obj)
else:
raise ValueError(f"Unknown API {api} for registering object with provider")
@@ -60,7 +60,7 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
elif api == Api.datasetio:
return await p.unregister_dataset(obj.identifier)
elif api == Api.tool_runtime:
- return await p.unregister_tool(obj.identifier)
+ return await p.unregister_toolgroup(obj.identifier)
else:
raise ValueError(f"Unregister not supported for {api}")
@@ -136,7 +136,7 @@ class CommonRoutingTableImpl(RoutingTable):
elif isinstance(self, BenchmarksRoutingTable):
return ("Eval", "benchmark")
elif isinstance(self, ToolGroupsRoutingTable):
- return ("Tools", "tool")
+ return ("ToolGroups", "tool_group")
else:
raise ValueError("Unknown routing table type")
diff --git a/llama_stack/distribution/routing_tables/toolgroups.py b/llama_stack/distribution/routing_tables/toolgroups.py
index cb73dc7c2..3f103ed22 100644
--- a/llama_stack/distribution/routing_tables/toolgroups.py
+++ b/llama_stack/distribution/routing_tables/toolgroups.py
@@ -7,11 +7,8 @@
from typing import Any
from llama_stack.apis.common.content_types import URL
-from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups, ToolHost
-from llama_stack.distribution.datatypes import (
- ToolGroupWithACL,
- ToolWithACL,
-)
+from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
+from llama_stack.distribution.datatypes import ToolGroupWithACL
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
@@ -20,11 +17,51 @@ logger = get_logger(name=__name__, category="core")
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
+ toolgroups_to_tools: dict[str, list[Tool]] = {}
+ tool_to_toolgroup: dict[str, str] = {}
+
+ # overridden
+ def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
+ # we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
+ # TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?
+ if routing_key in self.tool_to_toolgroup:
+ routing_key = self.tool_to_toolgroup[routing_key]
+ return super().get_provider_impl(routing_key, provider_id)
+
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
- tools = await self.get_all_with_type("tool")
if toolgroup_id:
- tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
- return ListToolsResponse(data=tools)
+ toolgroups = [await self.get_tool_group(toolgroup_id)]
+ else:
+ toolgroups = await self.get_all_with_type("tool_group")
+
+ all_tools = []
+ for toolgroup in toolgroups:
+ group_id = toolgroup.identifier
+ if group_id not in self.toolgroups_to_tools:
+ provider_impl = super().get_provider_impl(group_id, toolgroup.provider_id)
+ tooldefs_response = await provider_impl.list_runtime_tools(group_id, toolgroup.mcp_endpoint)
+
+ # TODO: kill this Tool vs ToolDef distinction
+ tooldefs = tooldefs_response.data
+ tools = []
+ for t in tooldefs:
+ tools.append(
+ Tool(
+ identifier=t.name,
+ toolgroup_id=group_id,
+ description=t.description or "",
+ parameters=t.parameters or [],
+ metadata=t.metadata,
+ provider_id=toolgroup.provider_id,
+ )
+ )
+
+ self.toolgroups_to_tools[group_id] = tools
+ for tool in tools:
+ self.tool_to_toolgroup[tool.identifier] = group_id
+ all_tools.extend(self.toolgroups_to_tools[group_id])
+
+ return ListToolsResponse(data=all_tools)
async def list_tool_groups(self) -> ListToolGroupsResponse:
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
@@ -36,7 +73,13 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
return tool_group
async def get_tool(self, tool_name: str) -> Tool:
- return await self.get_object_by_identifier("tool", tool_name)
+ if tool_name in self.tool_to_toolgroup:
+ toolgroup_id = self.tool_to_toolgroup[tool_name]
+ tools = self.toolgroups_to_tools[toolgroup_id]
+ for tool in tools:
+ if tool.identifier == tool_name:
+ return tool
+ raise ValueError(f"Tool '{tool_name}' not found")
async def register_tool_group(
self,
@@ -45,53 +88,20 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
mcp_endpoint: URL | None = None,
args: dict[str, Any] | None = None,
) -> None:
- tools = []
- tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
- tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
-
- for tool_def in tool_defs.data:
- tools.append(
- ToolWithACL(
- identifier=tool_def.name,
- toolgroup_id=toolgroup_id,
- description=tool_def.description or "",
- parameters=tool_def.parameters or [],
- provider_id=provider_id,
- provider_resource_id=tool_def.name,
- metadata=tool_def.metadata,
- tool_host=tool_host,
- )
- )
- for tool in tools:
- existing_tool = await self.get_tool(tool.identifier)
- # Compare existing and new object if one exists
- if existing_tool:
- existing_dict = existing_tool.model_dump()
- new_dict = tool.model_dump()
-
- if existing_dict != new_dict:
- raise ValueError(
- f"Object {tool.identifier} already exists in registry. Please use a different identifier."
- )
- await self.register_object(tool)
-
- await self.dist_registry.register(
- ToolGroupWithACL(
- identifier=toolgroup_id,
- provider_id=provider_id,
- provider_resource_id=toolgroup_id,
- mcp_endpoint=mcp_endpoint,
- args=args,
- )
+ toolgroup = ToolGroupWithACL(
+ identifier=toolgroup_id,
+ provider_id=provider_id,
+ provider_resource_id=toolgroup_id,
+ mcp_endpoint=mcp_endpoint,
+ args=args,
)
+ await self.register_object(toolgroup)
+ return toolgroup
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
tool_group = await self.get_tool_group(toolgroup_id)
if tool_group is None:
raise ValueError(f"Tool group {toolgroup_id} not found")
- tools = await self.list_tools(toolgroup_id)
- for tool in getattr(tools, "data", []):
- await self.unregister_object(tool)
await self.unregister_object(tool_group)
async def shutdown(self) -> None:
diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py
index a6b400136..0e84854c2 100644
--- a/llama_stack/distribution/store/registry.py
+++ b/llama_stack/distribution/store/registry.py
@@ -36,7 +36,7 @@ class DistributionRegistry(Protocol):
REGISTER_PREFIX = "distributions:registry"
-KEY_VERSION = "v8"
+KEY_VERSION = "v9"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index 3e9806f23..60b05545b 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -16,7 +16,7 @@ from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import Model
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.shields import Shield
-from llama_stack.apis.tools import Tool
+from llama_stack.apis.tools import ToolGroup
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.schema_utils import json_schema_type
@@ -74,10 +74,10 @@ class BenchmarksProtocolPrivate(Protocol):
async def register_benchmark(self, benchmark: Benchmark) -> None: ...
-class ToolsProtocolPrivate(Protocol):
- async def register_tool(self, tool: Tool) -> None: ...
+class ToolGroupsProtocolPrivate(Protocol):
+ async def register_toolgroup(self, toolgroup: ToolGroup) -> None: ...
- async def unregister_tool(self, tool_id: str) -> None: ...
+ async def unregister_toolgroup(self, toolgroup_id: str) -> None: ...
@json_schema_type
diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py
index fe16c76b8..c2d264c91 100644
--- a/llama_stack/providers/inline/tool_runtime/rag/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py
@@ -25,14 +25,14 @@ from llama_stack.apis.tools import (
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
- Tool,
ToolDef,
+ ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
-from llama_stack.providers.datatypes import ToolsProtocolPrivate
+from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import (
content_from_doc,
@@ -49,7 +49,7 @@ def make_random_string(length: int = 8):
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
-class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
+class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
def __init__(
self,
config: RagToolRuntimeConfig,
@@ -66,10 +66,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
async def shutdown(self):
pass
- async def register_tool(self, tool: Tool) -> None:
+ async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass
- async def unregister_tool(self, tool_id: str) -> None:
+ async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return
async def insert(
diff --git a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py
index 18bec463f..7e82cb6d4 100644
--- a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py
+++ b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py
@@ -12,19 +12,19 @@ import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
ListToolDefsResponse,
- Tool,
ToolDef,
+ ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.providers.datatypes import ToolsProtocolPrivate
+from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from .config import BingSearchToolConfig
-class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
+class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: BingSearchToolConfig):
self.config = config
self.url = "https://api.bing.microsoft.com/v7.0/search"
@@ -32,10 +32,10 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
async def initialize(self):
pass
- async def register_tool(self, tool: Tool) -> None:
+ async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass
- async def unregister_tool(self, tool_id: str) -> None:
+ async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return
def _get_api_key(self) -> str:
diff --git a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
index 355cb98b6..b96b9e59c 100644
--- a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
+++ b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
@@ -11,30 +11,30 @@ import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
ListToolDefsResponse,
- Tool,
ToolDef,
+ ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.models.llama.datatypes import BuiltinTool
-from llama_stack.providers.datatypes import ToolsProtocolPrivate
+from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from .config import BraveSearchToolConfig
-class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
+class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: BraveSearchToolConfig):
self.config = config
async def initialize(self):
pass
- async def register_tool(self, tool: Tool) -> None:
+ async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass
- async def unregister_tool(self, tool_id: str) -> None:
+ async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return
def _get_api_key(self) -> str:
diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py
index 3f0b9a188..9603bf97e 100644
--- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py
+++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py
@@ -11,12 +11,13 @@ from llama_stack.apis.common.content_types import URL
from llama_stack.apis.datatypes import Api
from llama_stack.apis.tools import (
ListToolDefsResponse,
+ ToolGroup,
ToolInvocationResult,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
-from llama_stack.providers.datatypes import ToolsProtocolPrivate
+from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools
from .config import MCPProviderConfig
@@ -24,13 +25,19 @@ from .config import MCPProviderConfig
logger = get_logger(__name__, category="tools")
-class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
+class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
self.config = config
async def initialize(self):
pass
+ async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
+ pass
+
+ async def unregister_toolgroup(self, toolgroup_id: str) -> None:
+ return
+
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse:
diff --git a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py
index 9d6fcd951..1fe91fd7f 100644
--- a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py
+++ b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py
@@ -12,29 +12,29 @@ import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
ListToolDefsResponse,
- Tool,
ToolDef,
+ ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.providers.datatypes import ToolsProtocolPrivate
+from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from .config import TavilySearchToolConfig
-class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
+class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: TavilySearchToolConfig):
self.config = config
async def initialize(self):
pass
- async def register_tool(self, tool: Tool) -> None:
+ async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass
- async def unregister_tool(self, tool_id: str) -> None:
+ async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return
def _get_api_key(self) -> str:
diff --git a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
index a3724e4b4..6e1d0f61d 100644
--- a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
+++ b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
@@ -12,19 +12,19 @@ import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
ListToolDefsResponse,
- Tool,
ToolDef,
+ ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.providers.datatypes import ToolsProtocolPrivate
+from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from .config import WolframAlphaToolConfig
-class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
+class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: WolframAlphaToolConfig):
self.config = config
self.url = "https://api.wolframalpha.com/v2/query"
@@ -32,10 +32,10 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
async def initialize(self):
pass
- async def register_tool(self, tool: Tool) -> None:
+ async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass
- async def unregister_tool(self, tool_id: str) -> None:
+ async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return
def _get_api_key(self) -> str:
diff --git a/tests/integration/tool_runtime/test_builtin_tools.py b/tests/integration/tool_runtime/test_builtin_tools.py
index 9edf3afa0..1acf06719 100644
--- a/tests/integration/tool_runtime/test_builtin_tools.py
+++ b/tests/integration/tool_runtime/test_builtin_tools.py
@@ -25,10 +25,12 @@ def test_web_search_tool(llama_stack_client, sample_search_query):
if "TAVILY_SEARCH_API_KEY" not in os.environ:
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
+ tools = llama_stack_client.tool_runtime.list_tools()
+ assert any(tool.identifier == "web_search" for tool in tools)
+
response = llama_stack_client.tool_runtime.invoke_tool(
tool_name="web_search", kwargs={"query": sample_search_query}
)
-
# Verify the response
assert response.content is not None
assert len(response.content) > 0
@@ -49,11 +51,12 @@ def test_wolfram_alpha_tool(llama_stack_client, sample_wolfram_alpha_query):
if "WOLFRAM_ALPHA_API_KEY" not in os.environ:
pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test")
+ tools = llama_stack_client.tool_runtime.list_tools()
+ assert any(tool.identifier == "wolfram_alpha" for tool in tools)
response = llama_stack_client.tool_runtime.invoke_tool(
tool_name="wolfram_alpha", kwargs={"query": sample_wolfram_alpha_query}
)
- print(response.content)
assert response.content is not None
assert len(response.content) > 0
assert isinstance(response.content, str)
diff --git a/tests/integration/tool_runtime/test_mcp.py b/tests/integration/tool_runtime/test_mcp.py
index dd8a6d823..28b2e43c1 100644
--- a/tests/integration/tool_runtime/test_mcp.py
+++ b/tests/integration/tool_runtime/test_mcp.py
@@ -31,13 +31,12 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
test_toolgroup_id = MCP_TOOLGROUP_ID
uri = mcp_server["server_url"]
- # registering itself should fail since it requires listing tools
- with pytest.raises(Exception, match="Unauthorized"):
- llama_stack_client.toolgroups.register(
- toolgroup_id=test_toolgroup_id,
- provider_id="model-context-protocol",
- mcp_endpoint=dict(uri=uri),
- )
+ # registering should not raise an error anymore even if you don't specify the auth token
+ llama_stack_client.toolgroups.register(
+ toolgroup_id=test_toolgroup_id,
+ provider_id="model-context-protocol",
+ mcp_endpoint=dict(uri=uri),
+ )
provider_data = {
"mcp_headers": {
@@ -50,18 +49,9 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
- try:
- llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id, extra_headers=auth_headers)
- except Exception as e:
- # An error is OK since the toolgroup may not exist
- print(f"Error unregistering toolgroup: {e}")
+ with pytest.raises(Exception, match="Unauthorized"):
+ llama_stack_client.tools.list()
- llama_stack_client.toolgroups.register(
- toolgroup_id=test_toolgroup_id,
- provider_id="model-context-protocol",
- mcp_endpoint=dict(uri=uri),
- extra_headers=auth_headers,
- )
response = llama_stack_client.tools.list(
toolgroup_id=test_toolgroup_id,
extra_headers=auth_headers,
diff --git a/tests/integration/tool_runtime/test_registration.py b/tests/integration/tool_runtime/test_registration.py
index b8cbd964a..0846f8c89 100644
--- a/tests/integration/tool_runtime/test_registration.py
+++ b/tests/integration/tool_runtime/test_registration.py
@@ -51,7 +51,5 @@ def test_register_and_unregister_toolgroup(llama_stack_client):
with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)
- # Verify tools are also unregistered
- unregister_tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
- assert isinstance(unregister_tools_list_response, list)
- assert not unregister_tools_list_response
+ with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
+ llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py
index b5db6854a..2a30fd0b8 100644
--- a/tests/unit/distribution/routers/test_routing_tables.py
+++ b/tests/unit/distribution/routers/test_routing_tables.py
@@ -15,7 +15,7 @@ from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataS
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models.models import Model, ModelType
from llama_stack.apis.shields.shields import Shield
-from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolParameter
+from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
@@ -101,11 +101,11 @@ class ToolGroupsImpl(Impl):
def __init__(self):
super().__init__(Api.tool_runtime)
- async def register_tool(self, tool):
- return tool
+ async def register_toolgroup(self, toolgroup: ToolGroup):
+ return toolgroup
- async def unregister_tool(self, tool_name: str):
- return tool_name
+ async def unregister_toolgroup(self, toolgroup_id: str):
+ return toolgroup_id
async def list_runtime_tools(self, toolgroup_id, mcp_endpoint):
return ListToolDefsResponse(