diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 99ae1c038..043e9467e 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -9555,9 +9555,6 @@ "toolgroup_id": { "type": "string" }, - "tool_host": { - "$ref": "#/components/schemas/ToolHost" - }, "description": { "type": "string" }, @@ -9599,21 +9596,11 @@ "provider_id", "type", "toolgroup_id", - "tool_host", "description", "parameters" ], "title": "Tool" }, - "ToolHost": { - "type": "string", - "enum": [ - "distribution", - "client", - "model_context_protocol" - ], - "title": "ToolHost" - }, "ToolGroup": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 4e4f09eb0..c7ec8db5f 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -6713,8 +6713,6 @@ components: default: tool toolgroup_id: type: string - tool_host: - $ref: '#/components/schemas/ToolHost' description: type: string parameters: @@ -6737,17 +6735,9 @@ components: - provider_id - type - toolgroup_id - - tool_host - description - parameters title: Tool - ToolHost: - type: string - enum: - - distribution - - client - - model_context_protocol - title: ToolHost ToolGroup: type: object properties: diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index 29649495c..0c8d47edf 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -27,18 +27,10 @@ class ToolParameter(BaseModel): default: Any | None = None -@json_schema_type -class ToolHost(Enum): - distribution = "distribution" - client = "client" - model_context_protocol = "model_context_protocol" - - @json_schema_type class Tool(Resource): type: Literal[ResourceType.tool] = ResourceType.tool toolgroup_id: str - tool_host: ToolHost description: str parameters: list[ToolParameter] metadata: dict[str, Any] | None = None diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 8b846d051..b7c7cb87f 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -47,7 +47,7 @@ from llama_stack.providers.datatypes import ( RemoteProviderSpec, ScoringFunctionsProtocolPrivate, ShieldsProtocolPrivate, - ToolsProtocolPrivate, + ToolGroupsProtocolPrivate, VectorDBsProtocolPrivate, ) @@ -93,7 +93,7 @@ def api_protocol_map_for_compliance_check() -> dict[Api, Any]: def additional_protocols_map() -> dict[Api, Any]: return { Api.inference: (ModelsProtocolPrivate, Models, Api.models), - Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups), + Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups), Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs), Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields), Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets), diff --git a/llama_stack/distribution/routers/tool_runtime.py b/llama_stack/distribution/routers/tool_runtime.py index 2d4734a2e..285843dbc 100644 --- a/llama_stack/distribution/routers/tool_runtime.py +++ b/llama_stack/distribution/routers/tool_runtime.py @@ -11,7 +11,7 @@ from llama_stack.apis.common.content_types import ( InterleavedContent, ) from llama_stack.apis.tools import ( - ListToolDefsResponse, + ListToolsResponse, RAGDocument, RAGQueryConfig, RAGQueryResult, @@ -19,7 +19,8 @@ from llama_stack.apis.tools import ( ToolRuntime, ) from llama_stack.log import get_logger -from llama_stack.providers.datatypes import RoutingTable + +from ..routing_tables.toolgroups import ToolGroupsRoutingTable logger = get_logger(name=__name__, category="core") @@ -28,7 +29,7 @@ class ToolRuntimeRouter(ToolRuntime): class RagToolImpl(RAGToolRuntime): def __init__( self, - routing_table: RoutingTable, + routing_table: ToolGroupsRoutingTable, ) -> None: logger.debug("Initializing ToolRuntimeRouter.RagToolImpl") self.routing_table = routing_table @@ -59,7 +60,7 @@ class ToolRuntimeRouter(ToolRuntime): def __init__( self, - routing_table: RoutingTable, + routing_table: ToolGroupsRoutingTable, ) -> None: logger.debug("Initializing ToolRuntimeRouter") self.routing_table = routing_table @@ -86,6 +87,6 @@ class ToolRuntimeRouter(ToolRuntime): async def list_runtime_tools( self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None - ) -> ListToolDefsResponse: + ) -> ListToolsResponse: logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}") - return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint) + return await self.routing_table.list_tools(tool_group_id) diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py index 95a92a5ba..8ec87ca50 100644 --- a/llama_stack/distribution/routing_tables/common.py +++ b/llama_stack/distribution/routing_tables/common.py @@ -46,7 +46,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable elif api == Api.eval: return await p.register_benchmark(obj) elif api == Api.tool_runtime: - return await p.register_tool(obj) + return await p.register_toolgroup(obj) else: raise ValueError(f"Unknown API {api} for registering object with provider") @@ -60,7 +60,7 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None: elif api == Api.datasetio: return await p.unregister_dataset(obj.identifier) elif api == Api.tool_runtime: - return await p.unregister_tool(obj.identifier) + return await p.unregister_toolgroup(obj.identifier) else: raise ValueError(f"Unregister not supported for {api}") @@ -136,7 +136,7 @@ class CommonRoutingTableImpl(RoutingTable): elif isinstance(self, BenchmarksRoutingTable): return ("Eval", "benchmark") elif isinstance(self, ToolGroupsRoutingTable): - return ("Tools", "tool") + return ("ToolGroups", "tool_group") else: raise ValueError("Unknown routing table type") diff --git a/llama_stack/distribution/routing_tables/toolgroups.py b/llama_stack/distribution/routing_tables/toolgroups.py index cb73dc7c2..3f103ed22 100644 --- a/llama_stack/distribution/routing_tables/toolgroups.py +++ b/llama_stack/distribution/routing_tables/toolgroups.py @@ -7,11 +7,8 @@ from typing import Any from llama_stack.apis.common.content_types import URL -from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups, ToolHost -from llama_stack.distribution.datatypes import ( - ToolGroupWithACL, - ToolWithACL, -) +from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups +from llama_stack.distribution.datatypes import ToolGroupWithACL from llama_stack.log import get_logger from .common import CommonRoutingTableImpl @@ -20,11 +17,51 @@ logger = get_logger(name=__name__, category="core") class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): + toolgroups_to_tools: dict[str, list[Tool]] = {} + tool_to_toolgroup: dict[str, str] = {} + + # overridden + def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any: + # we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id + # TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while? + if routing_key in self.tool_to_toolgroup: + routing_key = self.tool_to_toolgroup[routing_key] + return super().get_provider_impl(routing_key, provider_id) + async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse: - tools = await self.get_all_with_type("tool") if toolgroup_id: - tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id] - return ListToolsResponse(data=tools) + toolgroups = [await self.get_tool_group(toolgroup_id)] + else: + toolgroups = await self.get_all_with_type("tool_group") + + all_tools = [] + for toolgroup in toolgroups: + group_id = toolgroup.identifier + if group_id not in self.toolgroups_to_tools: + provider_impl = super().get_provider_impl(group_id, toolgroup.provider_id) + tooldefs_response = await provider_impl.list_runtime_tools(group_id, toolgroup.mcp_endpoint) + + # TODO: kill this Tool vs ToolDef distinction + tooldefs = tooldefs_response.data + tools = [] + for t in tooldefs: + tools.append( + Tool( + identifier=t.name, + toolgroup_id=group_id, + description=t.description or "", + parameters=t.parameters or [], + metadata=t.metadata, + provider_id=toolgroup.provider_id, + ) + ) + + self.toolgroups_to_tools[group_id] = tools + for tool in tools: + self.tool_to_toolgroup[tool.identifier] = group_id + all_tools.extend(self.toolgroups_to_tools[group_id]) + + return ListToolsResponse(data=all_tools) async def list_tool_groups(self) -> ListToolGroupsResponse: return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group")) @@ -36,7 +73,13 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): return tool_group async def get_tool(self, tool_name: str) -> Tool: - return await self.get_object_by_identifier("tool", tool_name) + if tool_name in self.tool_to_toolgroup: + toolgroup_id = self.tool_to_toolgroup[tool_name] + tools = self.toolgroups_to_tools[toolgroup_id] + for tool in tools: + if tool.identifier == tool_name: + return tool + raise ValueError(f"Tool '{tool_name}' not found") async def register_tool_group( self, @@ -45,53 +88,20 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): mcp_endpoint: URL | None = None, args: dict[str, Any] | None = None, ) -> None: - tools = [] - tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint) - tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution - - for tool_def in tool_defs.data: - tools.append( - ToolWithACL( - identifier=tool_def.name, - toolgroup_id=toolgroup_id, - description=tool_def.description or "", - parameters=tool_def.parameters or [], - provider_id=provider_id, - provider_resource_id=tool_def.name, - metadata=tool_def.metadata, - tool_host=tool_host, - ) - ) - for tool in tools: - existing_tool = await self.get_tool(tool.identifier) - # Compare existing and new object if one exists - if existing_tool: - existing_dict = existing_tool.model_dump() - new_dict = tool.model_dump() - - if existing_dict != new_dict: - raise ValueError( - f"Object {tool.identifier} already exists in registry. Please use a different identifier." - ) - await self.register_object(tool) - - await self.dist_registry.register( - ToolGroupWithACL( - identifier=toolgroup_id, - provider_id=provider_id, - provider_resource_id=toolgroup_id, - mcp_endpoint=mcp_endpoint, - args=args, - ) + toolgroup = ToolGroupWithACL( + identifier=toolgroup_id, + provider_id=provider_id, + provider_resource_id=toolgroup_id, + mcp_endpoint=mcp_endpoint, + args=args, ) + await self.register_object(toolgroup) + return toolgroup async def unregister_toolgroup(self, toolgroup_id: str) -> None: tool_group = await self.get_tool_group(toolgroup_id) if tool_group is None: raise ValueError(f"Tool group {toolgroup_id} not found") - tools = await self.list_tools(toolgroup_id) - for tool in getattr(tools, "data", []): - await self.unregister_object(tool) await self.unregister_object(tool_group) async def shutdown(self) -> None: diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index a6b400136..0e84854c2 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -36,7 +36,7 @@ class DistributionRegistry(Protocol): REGISTER_PREFIX = "distributions:registry" -KEY_VERSION = "v8" +KEY_VERSION = "v9" KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}" diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 3e9806f23..60b05545b 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -16,7 +16,7 @@ from llama_stack.apis.datatypes import Api from llama_stack.apis.models import Model from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.apis.shields import Shield -from llama_stack.apis.tools import Tool +from llama_stack.apis.tools import ToolGroup from llama_stack.apis.vector_dbs import VectorDB from llama_stack.schema_utils import json_schema_type @@ -74,10 +74,10 @@ class BenchmarksProtocolPrivate(Protocol): async def register_benchmark(self, benchmark: Benchmark) -> None: ... -class ToolsProtocolPrivate(Protocol): - async def register_tool(self, tool: Tool) -> None: ... +class ToolGroupsProtocolPrivate(Protocol): + async def register_toolgroup(self, toolgroup: ToolGroup) -> None: ... - async def unregister_tool(self, tool_id: str) -> None: ... + async def unregister_toolgroup(self, toolgroup_id: str) -> None: ... @json_schema_type diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index fe16c76b8..c2d264c91 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -25,14 +25,14 @@ from llama_stack.apis.tools import ( RAGQueryConfig, RAGQueryResult, RAGToolRuntime, - Tool, ToolDef, + ToolGroup, ToolInvocationResult, ToolParameter, ToolRuntime, ) from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO -from llama_stack.providers.datatypes import ToolsProtocolPrivate +from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str from llama_stack.providers.utils.memory.vector_store import ( content_from_doc, @@ -49,7 +49,7 @@ def make_random_string(length: int = 8): return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length)) -class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): +class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime): def __init__( self, config: RagToolRuntimeConfig, @@ -66,10 +66,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): async def shutdown(self): pass - async def register_tool(self, tool: Tool) -> None: + async def register_toolgroup(self, toolgroup: ToolGroup) -> None: pass - async def unregister_tool(self, tool_id: str) -> None: + async def unregister_toolgroup(self, toolgroup_id: str) -> None: return async def insert( diff --git a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py index 18bec463f..7e82cb6d4 100644 --- a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py +++ b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py @@ -12,19 +12,19 @@ import httpx from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( ListToolDefsResponse, - Tool, ToolDef, + ToolGroup, ToolInvocationResult, ToolParameter, ToolRuntime, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.providers.datatypes import ToolsProtocolPrivate +from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from .config import BingSearchToolConfig -class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): +class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): def __init__(self, config: BingSearchToolConfig): self.config = config self.url = "https://api.bing.microsoft.com/v7.0/search" @@ -32,10 +32,10 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP async def initialize(self): pass - async def register_tool(self, tool: Tool) -> None: + async def register_toolgroup(self, toolgroup: ToolGroup) -> None: pass - async def unregister_tool(self, tool_id: str) -> None: + async def unregister_toolgroup(self, toolgroup_id: str) -> None: return def _get_api_key(self) -> str: diff --git a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py index 355cb98b6..b96b9e59c 100644 --- a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py +++ b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py @@ -11,30 +11,30 @@ import httpx from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( ListToolDefsResponse, - Tool, ToolDef, + ToolGroup, ToolInvocationResult, ToolParameter, ToolRuntime, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.models.llama.datatypes import BuiltinTool -from llama_stack.providers.datatypes import ToolsProtocolPrivate +from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from .config import BraveSearchToolConfig -class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): +class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): def __init__(self, config: BraveSearchToolConfig): self.config = config async def initialize(self): pass - async def register_tool(self, tool: Tool) -> None: + async def register_toolgroup(self, toolgroup: ToolGroup) -> None: pass - async def unregister_tool(self, tool_id: str) -> None: + async def unregister_toolgroup(self, toolgroup_id: str) -> None: return def _get_api_key(self) -> str: diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index 3f0b9a188..9603bf97e 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -11,12 +11,13 @@ from llama_stack.apis.common.content_types import URL from llama_stack.apis.datatypes import Api from llama_stack.apis.tools import ( ListToolDefsResponse, + ToolGroup, ToolInvocationResult, ToolRuntime, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger -from llama_stack.providers.datatypes import ToolsProtocolPrivate +from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools from .config import MCPProviderConfig @@ -24,13 +25,19 @@ from .config import MCPProviderConfig logger = get_logger(__name__, category="tools") -class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): +class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]): self.config = config async def initialize(self): pass + async def register_toolgroup(self, toolgroup: ToolGroup) -> None: + pass + + async def unregister_toolgroup(self, toolgroup_id: str) -> None: + return + async def list_runtime_tools( self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None ) -> ListToolDefsResponse: diff --git a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py index 9d6fcd951..1fe91fd7f 100644 --- a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py +++ b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py @@ -12,29 +12,29 @@ import httpx from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( ListToolDefsResponse, - Tool, ToolDef, + ToolGroup, ToolInvocationResult, ToolParameter, ToolRuntime, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.providers.datatypes import ToolsProtocolPrivate +from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from .config import TavilySearchToolConfig -class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): +class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): def __init__(self, config: TavilySearchToolConfig): self.config = config async def initialize(self): pass - async def register_tool(self, tool: Tool) -> None: + async def register_toolgroup(self, toolgroup: ToolGroup) -> None: pass - async def unregister_tool(self, tool_id: str) -> None: + async def unregister_toolgroup(self, toolgroup_id: str) -> None: return def _get_api_key(self) -> str: diff --git a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py index a3724e4b4..6e1d0f61d 100644 --- a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py +++ b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py @@ -12,19 +12,19 @@ import httpx from llama_stack.apis.common.content_types import URL from llama_stack.apis.tools import ( ListToolDefsResponse, - Tool, ToolDef, + ToolGroup, ToolInvocationResult, ToolParameter, ToolRuntime, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.providers.datatypes import ToolsProtocolPrivate +from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from .config import WolframAlphaToolConfig -class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): +class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): def __init__(self, config: WolframAlphaToolConfig): self.config = config self.url = "https://api.wolframalpha.com/v2/query" @@ -32,10 +32,10 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques async def initialize(self): pass - async def register_tool(self, tool: Tool) -> None: + async def register_toolgroup(self, toolgroup: ToolGroup) -> None: pass - async def unregister_tool(self, tool_id: str) -> None: + async def unregister_toolgroup(self, toolgroup_id: str) -> None: return def _get_api_key(self) -> str: diff --git a/tests/integration/tool_runtime/test_builtin_tools.py b/tests/integration/tool_runtime/test_builtin_tools.py index 9edf3afa0..1acf06719 100644 --- a/tests/integration/tool_runtime/test_builtin_tools.py +++ b/tests/integration/tool_runtime/test_builtin_tools.py @@ -25,10 +25,12 @@ def test_web_search_tool(llama_stack_client, sample_search_query): if "TAVILY_SEARCH_API_KEY" not in os.environ: pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test") + tools = llama_stack_client.tool_runtime.list_tools() + assert any(tool.identifier == "web_search" for tool in tools) + response = llama_stack_client.tool_runtime.invoke_tool( tool_name="web_search", kwargs={"query": sample_search_query} ) - # Verify the response assert response.content is not None assert len(response.content) > 0 @@ -49,11 +51,12 @@ def test_wolfram_alpha_tool(llama_stack_client, sample_wolfram_alpha_query): if "WOLFRAM_ALPHA_API_KEY" not in os.environ: pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test") + tools = llama_stack_client.tool_runtime.list_tools() + assert any(tool.identifier == "wolfram_alpha" for tool in tools) response = llama_stack_client.tool_runtime.invoke_tool( tool_name="wolfram_alpha", kwargs={"query": sample_wolfram_alpha_query} ) - print(response.content) assert response.content is not None assert len(response.content) > 0 assert isinstance(response.content, str) diff --git a/tests/integration/tool_runtime/test_mcp.py b/tests/integration/tool_runtime/test_mcp.py index dd8a6d823..28b2e43c1 100644 --- a/tests/integration/tool_runtime/test_mcp.py +++ b/tests/integration/tool_runtime/test_mcp.py @@ -31,13 +31,12 @@ def test_mcp_invocation(llama_stack_client, mcp_server): test_toolgroup_id = MCP_TOOLGROUP_ID uri = mcp_server["server_url"] - # registering itself should fail since it requires listing tools - with pytest.raises(Exception, match="Unauthorized"): - llama_stack_client.toolgroups.register( - toolgroup_id=test_toolgroup_id, - provider_id="model-context-protocol", - mcp_endpoint=dict(uri=uri), - ) + # registering should not raise an error anymore even if you don't specify the auth token + llama_stack_client.toolgroups.register( + toolgroup_id=test_toolgroup_id, + provider_id="model-context-protocol", + mcp_endpoint=dict(uri=uri), + ) provider_data = { "mcp_headers": { @@ -50,18 +49,9 @@ def test_mcp_invocation(llama_stack_client, mcp_server): "X-LlamaStack-Provider-Data": json.dumps(provider_data), } - try: - llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id, extra_headers=auth_headers) - except Exception as e: - # An error is OK since the toolgroup may not exist - print(f"Error unregistering toolgroup: {e}") + with pytest.raises(Exception, match="Unauthorized"): + llama_stack_client.tools.list() - llama_stack_client.toolgroups.register( - toolgroup_id=test_toolgroup_id, - provider_id="model-context-protocol", - mcp_endpoint=dict(uri=uri), - extra_headers=auth_headers, - ) response = llama_stack_client.tools.list( toolgroup_id=test_toolgroup_id, extra_headers=auth_headers, diff --git a/tests/integration/tool_runtime/test_registration.py b/tests/integration/tool_runtime/test_registration.py index b8cbd964a..0846f8c89 100644 --- a/tests/integration/tool_runtime/test_registration.py +++ b/tests/integration/tool_runtime/test_registration.py @@ -51,7 +51,5 @@ def test_register_and_unregister_toolgroup(llama_stack_client): with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"): llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id) - # Verify tools are also unregistered - unregister_tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id) - assert isinstance(unregister_tools_list_response, list) - assert not unregister_tools_list_response + with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"): + llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id) diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index b5db6854a..2a30fd0b8 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -15,7 +15,7 @@ from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataS from llama_stack.apis.datatypes import Api from llama_stack.apis.models.models import Model, ModelType from llama_stack.apis.shields.shields import Shield -from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolParameter +from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter from llama_stack.apis.vector_dbs.vector_dbs import VectorDB from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable @@ -101,11 +101,11 @@ class ToolGroupsImpl(Impl): def __init__(self): super().__init__(Api.tool_runtime) - async def register_tool(self, tool): - return tool + async def register_toolgroup(self, toolgroup: ToolGroup): + return toolgroup - async def unregister_tool(self, tool_name: str): - return tool_name + async def unregister_toolgroup(self, toolgroup_id: str): + return toolgroup_id async def list_runtime_tools(self, toolgroup_id, mcp_endpoint): return ListToolDefsResponse(