From ce33d024438920908f2dfe6e3447f4991ef6725b Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sun, 25 May 2025 13:27:52 -0700 Subject: [PATCH] fix(tools): do not index tools, only index toolgroups (#2261) When registering a MCP endpoint, we cannot list tools (like we used to) since the MCP endpoint may be behind an auth wall. Registration can happen much sooner (via run.yaml). Instead, we do listing only when the _user_ actually calls listing. Furthermore, we cache the list in-memory in the server. Currently, the cache is not invalidated -- we may want to periodically re-list for MCP servers. Note that they must call `list_tools` before calling `invoke_tool` -- we use this critically. This will enable us to list MCP servers in run.yaml ## Test Plan Existing tests, updated tests accordingly. --- docs/_static/llama-stack-spec.html | 13 --- docs/_static/llama-stack-spec.yaml | 10 -- llama_stack/apis/tools/tools.py | 8 -- llama_stack/distribution/resolver.py | 4 +- .../distribution/routers/tool_runtime.py | 13 ++- .../distribution/routing_tables/common.py | 6 +- .../distribution/routing_tables/toolgroups.py | 110 ++++++++++-------- llama_stack/distribution/store/registry.py | 2 +- llama_stack/providers/datatypes.py | 8 +- .../inline/tool_runtime/rag/memory.py | 10 +- .../tool_runtime/bing_search/bing_search.py | 10 +- .../tool_runtime/brave_search/brave_search.py | 10 +- .../model_context_protocol.py | 11 +- .../tavily_search/tavily_search.py | 10 +- .../wolfram_alpha/wolfram_alpha.py | 10 +- .../tool_runtime/test_builtin_tools.py | 7 +- tests/integration/tool_runtime/test_mcp.py | 26 ++--- .../tool_runtime/test_registration.py | 6 +- .../routers/test_routing_tables.py | 10 +- 19 files changed, 131 insertions(+), 153 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 99ae1c038..043e9467e 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -9555,9 +9555,6 @@ "toolgroup_id": { "type": "string" }, - "tool_host": { - "$ref": "#/components/schemas/ToolHost" - }, "description": { "type": "string" }, @@ -9599,21 +9596,11 @@ "provider_id", "type", "toolgroup_id", - "tool_host", "description", "parameters" ], "title": "Tool" }, - "ToolHost": { - "type": "string", - "enum": [ - "distribution", - "client", - "model_context_protocol" - ], - "title": "ToolHost" - }, "ToolGroup": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 4e4f09eb0..c7ec8db5f 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -6713,8 +6713,6 @@ components: default: tool toolgroup_id: type: string - tool_host: - $ref: '#/components/schemas/ToolHost' description: type: string parameters: @@ -6737,17 +6735,9 @@ components: - provider_id - type - toolgroup_id - - tool_host - description - parameters title: Tool - ToolHost: - type: string - enum: - - distribution - - client - - model_context_protocol - title: ToolHost ToolGroup: type: object properties: diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index 29649495c..0c8d47edf 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -27,18 +27,10 @@ class ToolParameter(BaseModel): default: Any | None = None -@json_schema_type -class ToolHost(Enum): - distribution = "distribution" - client = "client" - model_context_protocol = "model_context_protocol" - - @json_schema_type class Tool(Resource): type: Literal[ResourceType.tool] = ResourceType.tool toolgroup_id: str - tool_host: ToolHost description: str parameters: list[ToolParameter] metadata: dict[str, Any] | None = None diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 8b846d051..b7c7cb87f 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -47,7 +47,7 @@ from llama_stack.providers.datatypes import ( RemoteProviderSpec, ScoringFunctionsProtocolPrivate, ShieldsProtocolPrivate, - ToolsProtocolPrivate, + ToolGroupsProtocolPrivate, VectorDBsProtocolPrivate, ) @@ -93,7 +93,7 @@ def api_protocol_map_for_compliance_check() -> dict[Api, Any]: def additional_protocols_map() -> dict[Api, Any]: return { Api.inference: (ModelsProtocolPrivate, Models, Api.models), - Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups), + Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups), Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs), Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields), Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets), diff --git a/llama_stack/distribution/routers/tool_runtime.py b/llama_stack/distribution/routers/tool_runtime.py index 2d4734a2e..285843dbc 100644 --- a/llama_stack/distribution/routers/tool_runtime.py +++ b/llama_stack/distribution/routers/tool_runtime.py @@ -11,7 +11,7 @@ from llama_stack.apis.common.content_types import ( InterleavedContent, ) from llama_stack.apis.tools import ( - ListToolDefsResponse, + ListToolsResponse, RAGDocument, RAGQueryConfig, RAGQueryResult, @@ -19,7 +19,8 @@ from llama_stack.apis.tools import ( ToolRuntime, ) from llama_stack.log import get_logger -from llama_stack.providers.datatypes import RoutingTable + +from ..routing_tables.toolgroups import ToolGroupsRoutingTable logger = get_logger(name=__name__, category="core") @@ -28,7 +29,7 @@ class ToolRuntimeRouter(ToolRuntime): class RagToolImpl(RAGToolRuntime): def __init__( self, - routing_table: RoutingTable, + routing_table: ToolGroupsRoutingTable, ) -> None: logger.debug("Initializing ToolRuntimeRouter.RagToolImpl") self.routing_table = routing_table @@ -59,7 +60,7 @@ class ToolRuntimeRouter(ToolRuntime): def __init__( self, - routing_table: RoutingTable, + routing_table: ToolGroupsRoutingTable, ) -> None: logger.debug("Initializing ToolRuntimeRouter") self.routing_table = routing_table @@ -86,6 +87,6 @@ class ToolRuntimeRouter(ToolRuntime): async def list_runtime_tools( self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None - ) -> ListToolDefsResponse: + ) -> ListToolsResponse: logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}") - return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint) + return await self.routing_table.list_tools(tool_group_id) diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py index 95a92a5ba..8ec87ca50 100644 --- a/llama_stack/distribution/routing_tables/common.py +++ b/llama_stack/distribution/routing_tables/common.py @@ -46,7 +46,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable elif api == Api.eval: return await p.register_benchmark(obj) elif api == Api.tool_runtime: - return await p.register_tool(obj) + return await p.register_toolgroup(obj) else: raise ValueError(f"Unknown API {api} for registering object with provider") @@ -60,7 +60,7 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None: elif api == Api.datasetio: return await p.unregister_dataset(obj.identifier) elif api == Api.tool_runtime: - return await p.unregister_tool(obj.identifier) + return await p.unregister_toolgroup(obj.identifier) else: raise ValueError(f"Unregister not supported for {api}") @@ -136,7 +136,7 @@ class CommonRoutingTableImpl(RoutingTable): elif isinstance(self, BenchmarksRoutingTable): return ("Eval", "benchmark") elif isinstance(self, ToolGroupsRoutingTable): - return ("Tools", "tool") + return ("ToolGroups", "tool_group") else: raise ValueError("Unknown routing table type") diff --git a/llama_stack/distribution/routing_tables/toolgroups.py b/llama_stack/distribution/routing_tables/toolgroups.py index cb73dc7c2..3f103ed22 100644 --- a/llama_stack/distribution/routing_tables/toolgroups.py +++ b/llama_stack/distribution/routing_tables/toolgroups.py @@ -7,11 +7,8 @@ from typing import Any from llama_stack.apis.common.content_types import URL -from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups, ToolHost -from llama_stack.distribution.datatypes import ( - ToolGroupWithACL, - ToolWithACL, -) +from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups +from llama_stack.distribution.datatypes import ToolGroupWithACL from llama_stack.log import get_logger from .common import CommonRoutingTableImpl @@ -20,11 +17,51 @@ logger = get_logger(name=__name__, category="core") class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): + toolgroups_to_tools: dict[str, list[Tool]] = {} + tool_to_toolgroup: dict[str, str] = {} + + # overridden + def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any: + # we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id + # TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while? + if routing_key in self.tool_to_toolgroup: + routing_key = self.tool_to_toolgroup[routing_key] + return super().get_provider_impl(routing_key, provider_id) + async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse: - tools = await self.get_all_with_type("tool") if toolgroup_id: - tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id] - return ListToolsResponse(data=tools) + toolgroups = [await self.get_tool_group(toolgroup_id)] + else: + toolgroups = await self.get_all_with_type("tool_group") + + all_tools = [] + for toolgroup in toolgroups: + group_id = toolgroup.identifier + if group_id not in self.toolgroups_to_tools: + provider_impl = super().get_provider_impl(group_id, toolgroup.provider_id) + tooldefs_response = await provider_impl.list_runtime_tools(group_id, toolgroup.mcp_endpoint) + + # TODO: kill this Tool vs ToolDef distinction + tooldefs = tooldefs_response.data + tools = [] + for t in tooldefs: + tools.append( + Tool( + identifier=t.name, + toolgroup_id=group_id, + description=t.description or "", + parameters=t.parameters or [], + metadata=t.metadata, + provider_id=toolgroup.provider_id, + ) + ) + + self.toolgroups_to_tools[group_id] = tools + for tool in tools: + self.tool_to_toolgroup[tool.identifier] = group_id + all_tools.extend(self.toolgroups_to_tools[group_id]) + + return ListToolsResponse(data=all_tools) async def list_tool_groups(self) -> ListToolGroupsResponse: return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group")) @@ -36,7 +73,13 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): return tool_group async def get_tool(self, tool_name: str) -> Tool: - return await self.get_object_by_identifier("tool", tool_name) + if tool_name in self.tool_to_toolgroup: + toolgroup_id = self.tool_to_toolgroup[tool_name] + tools = self.toolgroups_to_tools[toolgroup_id] + for tool in tools: + if tool.identifier == tool_name: + return tool + raise ValueError(f"Tool '{tool_name}' not found") async def register_tool_group( self, @@ -45,53 +88,20 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): mcp_endpoint: URL | None = None, args: dict[str, Any] | None = None, ) -> None: - tools = [] - tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint) - tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution - - for tool_def in tool_defs.data: - tools.append( - ToolWithACL( - identifier=tool_def.name, - toolgroup_id=toolgroup_id, - description=tool_def.description or "", - parameters=tool_def.parameters or [], - provider_id=provider_id, - provider_resource_id=tool_def.name, - metadata=tool_def.metadata, - tool_host=tool_host, - ) - ) - for tool in tools: - existing_tool = await self.get_tool(tool.identifier) - # Compare existing and new object if one exists - if existing_tool: - existing_dict = existing_tool.model_dump() - new_dict = tool.model_dump() - - if existing_dict != new_dict: - raise ValueError( - f"Object {tool.identifier} already exists in registry. Please use a different identifier." - ) - await self.register_object(tool) - - await self.dist_registry.register( - ToolGroupWithACL( - identifier=toolgroup_id, - provider_id=provider_id, - provider_resource_id=toolgroup_id, - mcp_endpoint=mcp_endpoint, - args=args, - ) + toolgroup = ToolGroupWithACL( + identifier=toolgroup_id, + provider_id=provider_id, + provider_resource_id=toolgroup_id, + mcp_endpoint=mcp_endpoint, + args=args, ) + await self.register_object(toolgroup) + return toolgroup async def unregister_toolgroup(self, toolgroup_id: str) -> None: tool_group = await self.get_tool_group(toolgroup_id) if tool_group is None: raise ValueError(f"Tool group {toolgroup_id} not found") - tools = await self.list_tools(toolgroup_id) - for tool in getattr(tools, "data", []): - await self.unregister_object(tool) await self.unregister_object(tool_group) async def shutdown(self) -> None: diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index a6b400136..0e84854c2 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -36,7 +36,7 @@ class DistributionRegistry(Protocol): REGISTER_PREFIX = "distributions:registry" -KEY_VERSION = "v8" +KEY_VERSION = "v9" KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}" diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 3e9806f23..60b05545b 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -16,7 +16,7 @@ from llama_stack.apis.datatypes import Api from llama_stack.apis.models import Model from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.apis.shields import Shield -from llama_stack.apis.tools import Tool +from llama_stack.apis.tools import ToolGroup from llama_stack.apis.vector_dbs import VectorDB from llama_stack.schema_utils import json_schema_type @@ -74,10 +74,10 @@ class BenchmarksProtocolPrivate(Protocol): async def register_benchmark(self, benchmark: Benchmark) -> None: ... -class ToolsProtocolPrivate(Protocol): - async def register_tool(self, tool: Tool) -> None: ... +class ToolGroupsProtocolPrivate(Protocol): + async def register_toolgroup(self, toolgroup: ToolGroup) -> None: ... - async def unregister_tool(self, tool_id: str) -> None: ... + async def unregister_toolgroup(self, toolgroup_id: str) -> None: ... @json_schema_type diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index fe16c76b8..c2d264c91 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -25,14 +25,14 @@ from llama_stack.apis.tools import ( RAGQueryConfig, RAGQueryResult, RAGToolRuntime, - Tool, ToolDef, + ToolGroup, ToolInvocationResult, ToolParameter, ToolRuntime, ) from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO -from llama_stack.providers.datatypes import ToolsProtocolPrivate +from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str from llama_stack.providers.utils.memory.vector_store import ( content_from_doc, @@ -49,7 +49,7 @@ def make_random_string(length: int = 8): return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length)) -class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): +class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime): def __init__( self, config: RagToolRuntimeConfig, @@ -66,10 +66,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): async def shutdown(self): pass - async def register_tool(self, tool: Tool) -> None: + async def register_toolgroup(self, toolgroup: ToolGroup) -> None: pass - async def unregister_tool(self, tool_id: str) -> None: + async def unregister_toolgroup(self, toolgroup_id: str) -> None: return async def insert( diff --git a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py index 18bec463f..7e82cb6d4 100644 --- a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py +++ b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py @@ -12,19 +12,19 @@ import httpx from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( ListToolDefsResponse, - Tool, ToolDef, + ToolGroup, ToolInvocationResult, ToolParameter, ToolRuntime, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.providers.datatypes import ToolsProtocolPrivate +from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from .config import BingSearchToolConfig -class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): +class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): def __init__(self, config: BingSearchToolConfig): self.config = config self.url = "https://api.bing.microsoft.com/v7.0/search" @@ -32,10 +32,10 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP async def initialize(self): pass - async def register_tool(self, tool: Tool) -> None: + async def register_toolgroup(self, toolgroup: ToolGroup) -> None: pass - async def unregister_tool(self, tool_id: str) -> None: + async def unregister_toolgroup(self, toolgroup_id: str) -> None: return def _get_api_key(self) -> str: diff --git a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py index 355cb98b6..b96b9e59c 100644 --- a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py +++ b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py @@ -11,30 +11,30 @@ import httpx from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( ListToolDefsResponse, - Tool, ToolDef, + ToolGroup, ToolInvocationResult, ToolParameter, ToolRuntime, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.models.llama.datatypes import BuiltinTool -from llama_stack.providers.datatypes import ToolsProtocolPrivate +from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from .config import BraveSearchToolConfig -class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): +class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): def __init__(self, config: BraveSearchToolConfig): self.config = config async def initialize(self): pass - async def register_tool(self, tool: Tool) -> None: + async def register_toolgroup(self, toolgroup: ToolGroup) -> None: pass - async def unregister_tool(self, tool_id: str) -> None: + async def unregister_toolgroup(self, toolgroup_id: str) -> None: return def _get_api_key(self) -> str: diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index 3f0b9a188..9603bf97e 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -11,12 +11,13 @@ from llama_stack.apis.common.content_types import URL from llama_stack.apis.datatypes import Api from llama_stack.apis.tools import ( ListToolDefsResponse, + ToolGroup, ToolInvocationResult, ToolRuntime, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger -from llama_stack.providers.datatypes import ToolsProtocolPrivate +from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools from .config import MCPProviderConfig @@ -24,13 +25,19 @@ from .config import MCPProviderConfig logger = get_logger(__name__, category="tools") -class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): +class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]): self.config = config async def initialize(self): pass + async def register_toolgroup(self, toolgroup: ToolGroup) -> None: + pass + + async def unregister_toolgroup(self, toolgroup_id: str) -> None: + return + async def list_runtime_tools( self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None ) -> ListToolDefsResponse: diff --git a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py index 9d6fcd951..1fe91fd7f 100644 --- a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py +++ b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py @@ -12,29 +12,29 @@ import httpx from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( ListToolDefsResponse, - Tool, ToolDef, + ToolGroup, ToolInvocationResult, ToolParameter, ToolRuntime, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.providers.datatypes import ToolsProtocolPrivate +from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from .config import TavilySearchToolConfig -class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): +class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): def __init__(self, config: TavilySearchToolConfig): self.config = config async def initialize(self): pass - async def register_tool(self, tool: Tool) -> None: + async def register_toolgroup(self, toolgroup: ToolGroup) -> None: pass - async def unregister_tool(self, tool_id: str) -> None: + async def unregister_toolgroup(self, toolgroup_id: str) -> None: return def _get_api_key(self) -> str: diff --git a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py index a3724e4b4..6e1d0f61d 100644 --- a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py +++ b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py @@ -12,19 +12,19 @@ import httpx from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( ListToolDefsResponse, - Tool, ToolDef, + ToolGroup, ToolInvocationResult, ToolParameter, ToolRuntime, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.providers.datatypes import ToolsProtocolPrivate +from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from .config import WolframAlphaToolConfig -class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): +class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): def __init__(self, config: WolframAlphaToolConfig): self.config = config self.url = "https://api.wolframalpha.com/v2/query" @@ -32,10 +32,10 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques async def initialize(self): pass - async def register_tool(self, tool: Tool) -> None: + async def register_toolgroup(self, toolgroup: ToolGroup) -> None: pass - async def unregister_tool(self, tool_id: str) -> None: + async def unregister_toolgroup(self, toolgroup_id: str) -> None: return def _get_api_key(self) -> str: diff --git a/tests/integration/tool_runtime/test_builtin_tools.py b/tests/integration/tool_runtime/test_builtin_tools.py index 9edf3afa0..1acf06719 100644 --- a/tests/integration/tool_runtime/test_builtin_tools.py +++ b/tests/integration/tool_runtime/test_builtin_tools.py @@ -25,10 +25,12 @@ def test_web_search_tool(llama_stack_client, sample_search_query): if "TAVILY_SEARCH_API_KEY" not in os.environ: pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test") + tools = llama_stack_client.tool_runtime.list_tools() + assert any(tool.identifier == "web_search" for tool in tools) + response = llama_stack_client.tool_runtime.invoke_tool( tool_name="web_search", kwargs={"query": sample_search_query} ) - # Verify the response assert response.content is not None assert len(response.content) > 0 @@ -49,11 +51,12 @@ def test_wolfram_alpha_tool(llama_stack_client, sample_wolfram_alpha_query): if "WOLFRAM_ALPHA_API_KEY" not in os.environ: pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test") + tools = llama_stack_client.tool_runtime.list_tools() + assert any(tool.identifier == "wolfram_alpha" for tool in tools) response = llama_stack_client.tool_runtime.invoke_tool( tool_name="wolfram_alpha", kwargs={"query": sample_wolfram_alpha_query} ) - print(response.content) assert response.content is not None assert len(response.content) > 0 assert isinstance(response.content, str) diff --git a/tests/integration/tool_runtime/test_mcp.py b/tests/integration/tool_runtime/test_mcp.py index dd8a6d823..28b2e43c1 100644 --- a/tests/integration/tool_runtime/test_mcp.py +++ b/tests/integration/tool_runtime/test_mcp.py @@ -31,13 +31,12 @@ def test_mcp_invocation(llama_stack_client, mcp_server): test_toolgroup_id = MCP_TOOLGROUP_ID uri = mcp_server["server_url"] - # registering itself should fail since it requires listing tools - with pytest.raises(Exception, match="Unauthorized"): - llama_stack_client.toolgroups.register( - toolgroup_id=test_toolgroup_id, - provider_id="model-context-protocol", - mcp_endpoint=dict(uri=uri), - ) + # registering should not raise an error anymore even if you don't specify the auth token + llama_stack_client.toolgroups.register( + toolgroup_id=test_toolgroup_id, + provider_id="model-context-protocol", + mcp_endpoint=dict(uri=uri), + ) provider_data = { "mcp_headers": { @@ -50,18 +49,9 @@ def test_mcp_invocation(llama_stack_client, mcp_server): "X-LlamaStack-Provider-Data": json.dumps(provider_data), } - try: - llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id, extra_headers=auth_headers) - except Exception as e: - # An error is OK since the toolgroup may not exist - print(f"Error unregistering toolgroup: {e}") + with pytest.raises(Exception, match="Unauthorized"): + llama_stack_client.tools.list() - llama_stack_client.toolgroups.register( - toolgroup_id=test_toolgroup_id, - provider_id="model-context-protocol", - mcp_endpoint=dict(uri=uri), - extra_headers=auth_headers, - ) response = llama_stack_client.tools.list( toolgroup_id=test_toolgroup_id, extra_headers=auth_headers, diff --git a/tests/integration/tool_runtime/test_registration.py b/tests/integration/tool_runtime/test_registration.py index b8cbd964a..0846f8c89 100644 --- a/tests/integration/tool_runtime/test_registration.py +++ b/tests/integration/tool_runtime/test_registration.py @@ -51,7 +51,5 @@ def test_register_and_unregister_toolgroup(llama_stack_client): with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"): llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id) - # Verify tools are also unregistered - unregister_tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id) - assert isinstance(unregister_tools_list_response, list) - assert not unregister_tools_list_response + with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"): + llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id) diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index b5db6854a..2a30fd0b8 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -15,7 +15,7 @@ from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataS from llama_stack.apis.datatypes import Api from llama_stack.apis.models.models import Model, ModelType from llama_stack.apis.shields.shields import Shield -from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolParameter +from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter from llama_stack.apis.vector_dbs.vector_dbs import VectorDB from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable @@ -101,11 +101,11 @@ class ToolGroupsImpl(Impl): def __init__(self): super().__init__(Api.tool_runtime) - async def register_tool(self, tool): - return tool + async def register_toolgroup(self, toolgroup: ToolGroup): + return toolgroup - async def unregister_tool(self, tool_name: str): - return tool_name + async def unregister_toolgroup(self, toolgroup_id: str): + return toolgroup_id async def list_runtime_tools(self, toolgroup_id, mcp_endpoint): return ListToolDefsResponse(