fix(tools): do not index tools, only index toolgroups (#2261)

When registering a MCP endpoint, we cannot list tools (like we used to)
since the MCP endpoint may be behind an auth wall. Registration can
happen much sooner (via run.yaml).

Instead, we do listing only when the _user_ actually calls listing.
Furthermore, we cache the list in-memory in the server. Currently, the
cache is not invalidated -- we may want to periodically re-list for MCP
servers. Note that they must call `list_tools` before calling
`invoke_tool` -- we use this critically.

This will enable us to list MCP servers in run.yaml

## Test Plan

Existing tests, updated tests accordingly.
This commit is contained in:
Ashwin Bharambe 2025-05-25 13:27:52 -07:00 committed by GitHub
parent 5a422e236c
commit ce33d02443
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 131 additions and 153 deletions

View file

@ -25,10 +25,12 @@ def test_web_search_tool(llama_stack_client, sample_search_query):
if "TAVILY_SEARCH_API_KEY" not in os.environ:
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
tools = llama_stack_client.tool_runtime.list_tools()
assert any(tool.identifier == "web_search" for tool in tools)
response = llama_stack_client.tool_runtime.invoke_tool(
tool_name="web_search", kwargs={"query": sample_search_query}
)
# Verify the response
assert response.content is not None
assert len(response.content) > 0
@ -49,11 +51,12 @@ def test_wolfram_alpha_tool(llama_stack_client, sample_wolfram_alpha_query):
if "WOLFRAM_ALPHA_API_KEY" not in os.environ:
pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test")
tools = llama_stack_client.tool_runtime.list_tools()
assert any(tool.identifier == "wolfram_alpha" for tool in tools)
response = llama_stack_client.tool_runtime.invoke_tool(
tool_name="wolfram_alpha", kwargs={"query": sample_wolfram_alpha_query}
)
print(response.content)
assert response.content is not None
assert len(response.content) > 0
assert isinstance(response.content, str)

View file

@ -31,13 +31,12 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
test_toolgroup_id = MCP_TOOLGROUP_ID
uri = mcp_server["server_url"]
# registering itself should fail since it requires listing tools
with pytest.raises(Exception, match="Unauthorized"):
llama_stack_client.toolgroups.register(
toolgroup_id=test_toolgroup_id,
provider_id="model-context-protocol",
mcp_endpoint=dict(uri=uri),
)
# registering should not raise an error anymore even if you don't specify the auth token
llama_stack_client.toolgroups.register(
toolgroup_id=test_toolgroup_id,
provider_id="model-context-protocol",
mcp_endpoint=dict(uri=uri),
)
provider_data = {
"mcp_headers": {
@ -50,18 +49,9 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
try:
llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id, extra_headers=auth_headers)
except Exception as e:
# An error is OK since the toolgroup may not exist
print(f"Error unregistering toolgroup: {e}")
with pytest.raises(Exception, match="Unauthorized"):
llama_stack_client.tools.list()
llama_stack_client.toolgroups.register(
toolgroup_id=test_toolgroup_id,
provider_id="model-context-protocol",
mcp_endpoint=dict(uri=uri),
extra_headers=auth_headers,
)
response = llama_stack_client.tools.list(
toolgroup_id=test_toolgroup_id,
extra_headers=auth_headers,

View file

@ -51,7 +51,5 @@ def test_register_and_unregister_toolgroup(llama_stack_client):
with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)
# Verify tools are also unregistered
unregister_tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
assert isinstance(unregister_tools_list_response, list)
assert not unregister_tools_list_response
with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)

View file

@ -15,7 +15,7 @@ from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataS
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models.models import Model, ModelType
from llama_stack.apis.shields.shields import Shield
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolParameter
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
@ -101,11 +101,11 @@ class ToolGroupsImpl(Impl):
def __init__(self):
super().__init__(Api.tool_runtime)
async def register_tool(self, tool):
return tool
async def register_toolgroup(self, toolgroup: ToolGroup):
return toolgroup
async def unregister_tool(self, tool_name: str):
return tool_name
async def unregister_toolgroup(self, toolgroup_id: str):
return toolgroup_id
async def list_runtime_tools(self, toolgroup_id, mcp_endpoint):
return ListToolDefsResponse(