fix(tools): do not index tools, only index toolgroups (#2261)

When registering an MCP endpoint, we can no longer list its tools (as we used to),
since the MCP endpoint may be behind an auth wall. Registration can also
happen much sooner (via run.yaml).
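
As a minimal sketch of the new registration flow from a client (the server address, toolgroup id, and MCP URI below are placeholders; the call shape follows the MCP integration test updated in this commit), registration now succeeds even when the MCP server requires auth:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed server address

# Registration only records the toolgroup; no tool listing (and no MCP auth token)
# happens at this point.
client.toolgroups.register(
    toolgroup_id="mcp::example",                         # hypothetical toolgroup id
    provider_id="model-context-protocol",
    mcp_endpoint=dict(uri="http://localhost:8000/sse"),  # hypothetical MCP endpoint
)
```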

Instead, we list tools only when the _user_ actually requests a listing.
Furthermore, we cache the list in memory in the server. Currently, the
cache is never invalidated -- we may want to periodically re-list MCP
servers. Note that clients must call `list_tools` before calling
`invoke_tool` -- we rely on this ordering.
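
A minimal sketch of the resulting list-then-invoke flow against an MCP server that expects a bearer token. The header plumbing mirrors the MCP integration test below; the exact `mcp_headers` shape, the token, the toolgroup id, and the tool arguments are assumptions:

```python
import json

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed server address

# Per-request MCP headers travel via provider data, as in the MCP integration test.
# Assumed shape: endpoint URI -> list of "Header: value" strings.
auth_headers = {
    "X-LlamaStack-Provider-Data": json.dumps(
        {"mcp_headers": {"http://localhost:8000/sse": ["Authorization: Bearer <token>"]}}
    ),
}

# Listing is lazy: this call fetches tool definitions from the MCP server and fills
# the server-side toolgroup -> tools cache (plus the tool -> toolgroup mapping).
tools = client.tools.list(toolgroup_id="mcp::example", extra_headers=auth_headers)

# invoke_tool resolves its provider through the cached tool -> toolgroup mapping,
# which is why list_tools must run before invoke_tool.
result = client.tool_runtime.invoke_tool(
    tool_name=tools[0].identifier,
    kwargs={"query": "hello"},       # placeholder arguments for a hypothetical tool
    extra_headers=auth_headers,      # assumed to be accepted here as on the other calls
)
print(result.content)
```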

This will enable us to declare MCP servers in run.yaml.

## Test Plan

Existing tests, updated accordingly.
Author: Ashwin Bharambe
Date: 2025-05-25 13:27:52 -07:00 (committed via GitHub)
Commit: ce33d02443 (parent 5a422e236c)
19 changed files with 131 additions and 153 deletions


@@ -9555,9 +9555,6 @@
     "toolgroup_id": {
       "type": "string"
     },
-    "tool_host": {
-      "$ref": "#/components/schemas/ToolHost"
-    },
     "description": {
       "type": "string"
     },
@@ -9599,21 +9596,11 @@
       "provider_id",
       "type",
       "toolgroup_id",
-      "tool_host",
       "description",
       "parameters"
     ],
     "title": "Tool"
   },
-  "ToolHost": {
-    "type": "string",
-    "enum": [
-      "distribution",
-      "client",
-      "model_context_protocol"
-    ],
-    "title": "ToolHost"
-  },
   "ToolGroup": {
     "type": "object",
     "properties": {


@@ -6713,8 +6713,6 @@ components:
         default: tool
       toolgroup_id:
         type: string
-      tool_host:
-        $ref: '#/components/schemas/ToolHost'
       description:
         type: string
       parameters:
@@ -6737,17 +6735,9 @@ components:
       - provider_id
       - type
       - toolgroup_id
-      - tool_host
       - description
       - parameters
       title: Tool
-    ToolHost:
-      type: string
-      enum:
-        - distribution
-        - client
-        - model_context_protocol
-      title: ToolHost
     ToolGroup:
       type: object
       properties:


@@ -27,18 +27,10 @@ class ToolParameter(BaseModel):
     default: Any | None = None


-@json_schema_type
-class ToolHost(Enum):
-    distribution = "distribution"
-    client = "client"
-    model_context_protocol = "model_context_protocol"
-
-
 @json_schema_type
 class Tool(Resource):
     type: Literal[ResourceType.tool] = ResourceType.tool
     toolgroup_id: str
-    tool_host: ToolHost
     description: str
     parameters: list[ToolParameter]
     metadata: dict[str, Any] | None = None


@@ -47,7 +47,7 @@ from llama_stack.providers.datatypes import (
     RemoteProviderSpec,
     ScoringFunctionsProtocolPrivate,
     ShieldsProtocolPrivate,
-    ToolsProtocolPrivate,
+    ToolGroupsProtocolPrivate,
     VectorDBsProtocolPrivate,
 )
@@ -93,7 +93,7 @@ def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
 def additional_protocols_map() -> dict[Api, Any]:
     return {
         Api.inference: (ModelsProtocolPrivate, Models, Api.models),
-        Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups),
+        Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
         Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
         Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
         Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),


@@ -11,7 +11,7 @@ from llama_stack.apis.common.content_types import (
     InterleavedContent,
 )
 from llama_stack.apis.tools import (
-    ListToolDefsResponse,
+    ListToolsResponse,
     RAGDocument,
     RAGQueryConfig,
     RAGQueryResult,
@@ -19,7 +19,8 @@ from llama_stack.apis.tools import (
     ToolRuntime,
 )
 from llama_stack.log import get_logger
-from llama_stack.providers.datatypes import RoutingTable
+
+from ..routing_tables.toolgroups import ToolGroupsRoutingTable

 logger = get_logger(name=__name__, category="core")
@@ -28,7 +29,7 @@ class ToolRuntimeRouter(ToolRuntime):
     class RagToolImpl(RAGToolRuntime):
         def __init__(
             self,
-            routing_table: RoutingTable,
+            routing_table: ToolGroupsRoutingTable,
         ) -> None:
             logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
             self.routing_table = routing_table
@@ -59,7 +60,7 @@
     def __init__(
         self,
-        routing_table: RoutingTable,
+        routing_table: ToolGroupsRoutingTable,
     ) -> None:
         logger.debug("Initializing ToolRuntimeRouter")
         self.routing_table = routing_table
@@ -86,6 +87,6 @@
     async def list_runtime_tools(
         self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
-    ) -> ListToolDefsResponse:
+    ) -> ListToolsResponse:
         logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
-        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
+        return await self.routing_table.list_tools(tool_group_id)


@@ -46,7 +46,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routabl
     elif api == Api.eval:
         return await p.register_benchmark(obj)
     elif api == Api.tool_runtime:
-        return await p.register_tool(obj)
+        return await p.register_toolgroup(obj)
     else:
         raise ValueError(f"Unknown API {api} for registering object with provider")
@@ -60,7 +60,7 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
     elif api == Api.datasetio:
         return await p.unregister_dataset(obj.identifier)
     elif api == Api.tool_runtime:
-        return await p.unregister_tool(obj.identifier)
+        return await p.unregister_toolgroup(obj.identifier)
     else:
         raise ValueError(f"Unregister not supported for {api}")
@@ -136,7 +136,7 @@ class CommonRoutingTableImpl(RoutingTable):
         elif isinstance(self, BenchmarksRoutingTable):
             return ("Eval", "benchmark")
         elif isinstance(self, ToolGroupsRoutingTable):
-            return ("Tools", "tool")
+            return ("ToolGroups", "tool_group")
         else:
             raise ValueError("Unknown routing table type")


@@ -7,11 +7,8 @@
 from typing import Any

 from llama_stack.apis.common.content_types import URL
-from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups, ToolHost
-from llama_stack.distribution.datatypes import (
-    ToolGroupWithACL,
-    ToolWithACL,
-)
+from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
+from llama_stack.distribution.datatypes import ToolGroupWithACL
 from llama_stack.log import get_logger

 from .common import CommonRoutingTableImpl
@@ -20,11 +17,51 @@ logger = get_logger(name=__name__, category="core")

 class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
+    toolgroups_to_tools: dict[str, list[Tool]] = {}
+    tool_to_toolgroup: dict[str, str] = {}
+
+    # overridden
+    def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
+        # we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
+        # TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?
+        if routing_key in self.tool_to_toolgroup:
+            routing_key = self.tool_to_toolgroup[routing_key]
+        return super().get_provider_impl(routing_key, provider_id)
+
     async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
-        tools = await self.get_all_with_type("tool")
         if toolgroup_id:
-            tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
-        return ListToolsResponse(data=tools)
+            toolgroups = [await self.get_tool_group(toolgroup_id)]
+        else:
+            toolgroups = await self.get_all_with_type("tool_group")
+
+        all_tools = []
+        for toolgroup in toolgroups:
+            group_id = toolgroup.identifier
+            if group_id not in self.toolgroups_to_tools:
+                provider_impl = super().get_provider_impl(group_id, toolgroup.provider_id)
+                tooldefs_response = await provider_impl.list_runtime_tools(group_id, toolgroup.mcp_endpoint)
+
+                # TODO: kill this Tool vs ToolDef distinction
+                tooldefs = tooldefs_response.data
+                tools = []
+                for t in tooldefs:
+                    tools.append(
+                        Tool(
+                            identifier=t.name,
+                            toolgroup_id=group_id,
+                            description=t.description or "",
+                            parameters=t.parameters or [],
+                            metadata=t.metadata,
+                            provider_id=toolgroup.provider_id,
+                        )
+                    )
+                self.toolgroups_to_tools[group_id] = tools
+                for tool in tools:
+                    self.tool_to_toolgroup[tool.identifier] = group_id
+
+            all_tools.extend(self.toolgroups_to_tools[group_id])
+
+        return ListToolsResponse(data=all_tools)

     async def list_tool_groups(self) -> ListToolGroupsResponse:
         return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
@@ -36,7 +73,13 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
         return tool_group

     async def get_tool(self, tool_name: str) -> Tool:
-        return await self.get_object_by_identifier("tool", tool_name)
+        if tool_name in self.tool_to_toolgroup:
+            toolgroup_id = self.tool_to_toolgroup[tool_name]
+            tools = self.toolgroups_to_tools[toolgroup_id]
+            for tool in tools:
+                if tool.identifier == tool_name:
+                    return tool
+        raise ValueError(f"Tool '{tool_name}' not found")

     async def register_tool_group(
         self,
@@ -45,53 +88,20 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
         mcp_endpoint: URL | None = None,
         args: dict[str, Any] | None = None,
     ) -> None:
-        tools = []
-        tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
-        tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
-
-        for tool_def in tool_defs.data:
-            tools.append(
-                ToolWithACL(
-                    identifier=tool_def.name,
-                    toolgroup_id=toolgroup_id,
-                    description=tool_def.description or "",
-                    parameters=tool_def.parameters or [],
-                    provider_id=provider_id,
-                    provider_resource_id=tool_def.name,
-                    metadata=tool_def.metadata,
-                    tool_host=tool_host,
-                )
-            )
-        for tool in tools:
-            existing_tool = await self.get_tool(tool.identifier)
-            # Compare existing and new object if one exists
-            if existing_tool:
-                existing_dict = existing_tool.model_dump()
-                new_dict = tool.model_dump()
-                if existing_dict != new_dict:
-                    raise ValueError(
-                        f"Object {tool.identifier} already exists in registry. Please use a different identifier."
-                    )
-            await self.register_object(tool)
-
-        await self.dist_registry.register(
-            ToolGroupWithACL(
-                identifier=toolgroup_id,
-                provider_id=provider_id,
-                provider_resource_id=toolgroup_id,
-                mcp_endpoint=mcp_endpoint,
-                args=args,
-            )
+        toolgroup = ToolGroupWithACL(
+            identifier=toolgroup_id,
+            provider_id=provider_id,
+            provider_resource_id=toolgroup_id,
+            mcp_endpoint=mcp_endpoint,
+            args=args,
         )
+        await self.register_object(toolgroup)
+        return toolgroup

     async def unregister_toolgroup(self, toolgroup_id: str) -> None:
         tool_group = await self.get_tool_group(toolgroup_id)
         if tool_group is None:
             raise ValueError(f"Tool group {toolgroup_id} not found")
-        tools = await self.list_tools(toolgroup_id)
-        for tool in getattr(tools, "data", []):
-            await self.unregister_object(tool)
         await self.unregister_object(tool_group)

     async def shutdown(self) -> None:


@@ -36,7 +36,7 @@ class DistributionRegistry(Protocol):
 REGISTER_PREFIX = "distributions:registry"
-KEY_VERSION = "v8"
+KEY_VERSION = "v9"
 KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"


@@ -16,7 +16,7 @@ from llama_stack.apis.datatypes import Api
 from llama_stack.apis.models import Model
 from llama_stack.apis.scoring_functions import ScoringFn
 from llama_stack.apis.shields import Shield
-from llama_stack.apis.tools import Tool
+from llama_stack.apis.tools import ToolGroup
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.schema_utils import json_schema_type
@@ -74,10 +74,10 @@ class BenchmarksProtocolPrivate(Protocol):
     async def register_benchmark(self, benchmark: Benchmark) -> None: ...

-class ToolsProtocolPrivate(Protocol):
-    async def register_tool(self, tool: Tool) -> None: ...
+class ToolGroupsProtocolPrivate(Protocol):
+    async def register_toolgroup(self, toolgroup: ToolGroup) -> None: ...

-    async def unregister_tool(self, tool_id: str) -> None: ...
+    async def unregister_toolgroup(self, toolgroup_id: str) -> None: ...

 @json_schema_type


@@ -25,14 +25,14 @@ from llama_stack.apis.tools import (
     RAGQueryConfig,
     RAGQueryResult,
     RAGToolRuntime,
-    Tool,
     ToolDef,
+    ToolGroup,
     ToolInvocationResult,
     ToolParameter,
     ToolRuntime,
 )
 from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
-from llama_stack.providers.datatypes import ToolsProtocolPrivate
+from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
 from llama_stack.providers.utils.memory.vector_store import (
     content_from_doc,
@@ -49,7 +49,7 @@ def make_random_string(length: int = 8):
     return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))

-class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
+class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
     def __init__(
         self,
         config: RagToolRuntimeConfig,
@@ -66,10 +66,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
     async def shutdown(self):
         pass

-    async def register_tool(self, tool: Tool) -> None:
+    async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
         pass

-    async def unregister_tool(self, tool_id: str) -> None:
+    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
         return

     async def insert(


@@ -12,19 +12,19 @@ import httpx
 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.tools import (
     ListToolDefsResponse,
-    Tool,
     ToolDef,
+    ToolGroup,
     ToolInvocationResult,
     ToolParameter,
     ToolRuntime,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.providers.datatypes import ToolsProtocolPrivate
+from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate

 from .config import BingSearchToolConfig

-class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
+class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
     def __init__(self, config: BingSearchToolConfig):
         self.config = config
         self.url = "https://api.bing.microsoft.com/v7.0/search"
@@ -32,10 +32,10 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
     async def initialize(self):
         pass

-    async def register_tool(self, tool: Tool) -> None:
+    async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
         pass

-    async def unregister_tool(self, tool_id: str) -> None:
+    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
         return

     def _get_api_key(self) -> str:


@@ -11,30 +11,30 @@ import httpx
 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.tools import (
     ListToolDefsResponse,
-    Tool,
     ToolDef,
+    ToolGroup,
     ToolInvocationResult,
     ToolParameter,
     ToolRuntime,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.models.llama.datatypes import BuiltinTool
-from llama_stack.providers.datatypes import ToolsProtocolPrivate
+from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate

 from .config import BraveSearchToolConfig

-class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
+class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
     def __init__(self, config: BraveSearchToolConfig):
         self.config = config

     async def initialize(self):
         pass

-    async def register_tool(self, tool: Tool) -> None:
+    async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
         pass

-    async def unregister_tool(self, tool_id: str) -> None:
+    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
         return

     def _get_api_key(self) -> str:


@@ -11,12 +11,13 @@ from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.datatypes import Api
 from llama_stack.apis.tools import (
     ListToolDefsResponse,
+    ToolGroup,
     ToolInvocationResult,
     ToolRuntime,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
-from llama_stack.providers.datatypes import ToolsProtocolPrivate
+from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
 from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools

 from .config import MCPProviderConfig
@@ -24,13 +25,19 @@ from .config import MCPProviderConfig
 logger = get_logger(__name__, category="tools")

-class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
+class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
     def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
         self.config = config

     async def initialize(self):
         pass

+    async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
+        pass
+
+    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
+        return
+
     async def list_runtime_tools(
         self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
     ) -> ListToolDefsResponse:


@@ -12,29 +12,29 @@ import httpx
 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.tools import (
     ListToolDefsResponse,
-    Tool,
     ToolDef,
+    ToolGroup,
     ToolInvocationResult,
     ToolParameter,
     ToolRuntime,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.providers.datatypes import ToolsProtocolPrivate
+from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate

 from .config import TavilySearchToolConfig

-class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
+class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
     def __init__(self, config: TavilySearchToolConfig):
         self.config = config

     async def initialize(self):
         pass

-    async def register_tool(self, tool: Tool) -> None:
+    async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
         pass

-    async def unregister_tool(self, tool_id: str) -> None:
+    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
         return

     def _get_api_key(self) -> str:


@@ -12,19 +12,19 @@ import httpx
 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.tools import (
     ListToolDefsResponse,
-    Tool,
     ToolDef,
+    ToolGroup,
     ToolInvocationResult,
     ToolParameter,
     ToolRuntime,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.providers.datatypes import ToolsProtocolPrivate
+from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate

 from .config import WolframAlphaToolConfig

-class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
+class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
     def __init__(self, config: WolframAlphaToolConfig):
         self.config = config
         self.url = "https://api.wolframalpha.com/v2/query"
@@ -32,10 +32,10 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
     async def initialize(self):
         pass

-    async def register_tool(self, tool: Tool) -> None:
+    async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
         pass

-    async def unregister_tool(self, tool_id: str) -> None:
+    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
         return

     def _get_api_key(self) -> str:


@@ -25,10 +25,12 @@ def test_web_search_tool(llama_stack_client, sample_search_query):
     if "TAVILY_SEARCH_API_KEY" not in os.environ:
         pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")

+    tools = llama_stack_client.tool_runtime.list_tools()
+    assert any(tool.identifier == "web_search" for tool in tools)
     response = llama_stack_client.tool_runtime.invoke_tool(
         tool_name="web_search", kwargs={"query": sample_search_query}
     )

     # Verify the response
     assert response.content is not None
     assert len(response.content) > 0
@@ -49,11 +51,12 @@ def test_wolfram_alpha_tool(llama_stack_client, sample_wolfram_alpha_query):
     if "WOLFRAM_ALPHA_API_KEY" not in os.environ:
         pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test")

+    tools = llama_stack_client.tool_runtime.list_tools()
+    assert any(tool.identifier == "wolfram_alpha" for tool in tools)
     response = llama_stack_client.tool_runtime.invoke_tool(
         tool_name="wolfram_alpha", kwargs={"query": sample_wolfram_alpha_query}
     )
-    print(response.content)
     assert response.content is not None
     assert len(response.content) > 0
     assert isinstance(response.content, str)


@@ -31,13 +31,12 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
     test_toolgroup_id = MCP_TOOLGROUP_ID
     uri = mcp_server["server_url"]

-    # registering itself should fail since it requires listing tools
-    with pytest.raises(Exception, match="Unauthorized"):
-        llama_stack_client.toolgroups.register(
-            toolgroup_id=test_toolgroup_id,
-            provider_id="model-context-protocol",
-            mcp_endpoint=dict(uri=uri),
-        )
+    # registering should not raise an error anymore even if you don't specify the auth token
+    llama_stack_client.toolgroups.register(
+        toolgroup_id=test_toolgroup_id,
+        provider_id="model-context-protocol",
+        mcp_endpoint=dict(uri=uri),
+    )

     provider_data = {
         "mcp_headers": {
@@ -50,18 +49,9 @@
         "X-LlamaStack-Provider-Data": json.dumps(provider_data),
     }

-    try:
-        llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id, extra_headers=auth_headers)
-    except Exception as e:
-        # An error is OK since the toolgroup may not exist
-        print(f"Error unregistering toolgroup: {e}")
-
-    llama_stack_client.toolgroups.register(
-        toolgroup_id=test_toolgroup_id,
-        provider_id="model-context-protocol",
-        mcp_endpoint=dict(uri=uri),
-        extra_headers=auth_headers,
-    )
+    with pytest.raises(Exception, match="Unauthorized"):
+        llama_stack_client.tools.list()

     response = llama_stack_client.tools.list(
         toolgroup_id=test_toolgroup_id,
         extra_headers=auth_headers,


@@ -51,7 +51,5 @@ def test_register_and_unregister_toolgroup(llama_stack_client):
     with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
         llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)

-    # Verify tools are also unregistered
-    unregister_tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
-    assert isinstance(unregister_tools_list_response, list)
-    assert not unregister_tools_list_response
+    with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
+        llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)


@@ -15,7 +15,7 @@ from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataS
 from llama_stack.apis.datatypes import Api
 from llama_stack.apis.models.models import Model, ModelType
 from llama_stack.apis.shields.shields import Shield
-from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolParameter
+from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
 from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
 from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
 from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
@@ -101,11 +101,11 @@ class ToolGroupsImpl(Impl):
     def __init__(self):
         super().__init__(Api.tool_runtime)

-    async def register_tool(self, tool):
-        return tool
+    async def register_toolgroup(self, toolgroup: ToolGroup):
+        return toolgroup

-    async def unregister_tool(self, tool_name: str):
-        return tool_name
+    async def unregister_toolgroup(self, toolgroup_id: str):
+        return toolgroup_id

     async def list_runtime_tools(self, toolgroup_id, mcp_endpoint):
         return ListToolDefsResponse(