forked from phoenix-oss/llama-stack-mirror
fix(tools): do not index tools, only index toolgroups (#2261)
When registering a MCP endpoint, we cannot list tools (like we used to) since the MCP endpoint may be behind an auth wall. Registration can happen much sooner (via run.yaml). Instead, we do listing only when the _user_ actually calls listing. Furthermore, we cache the list in-memory in the server. Currently, the cache is not invalidated -- we may want to periodically re-list for MCP servers. Note that they must call `list_tools` before calling `invoke_tool` -- we use this critically. This will enable us to list MCP servers in run.yaml ## Test Plan Existing tests, updated tests accordingly.
This commit is contained in:
parent
5a422e236c
commit
ce33d02443
19 changed files with 131 additions and 153 deletions
|
@ -25,10 +25,12 @@ def test_web_search_tool(llama_stack_client, sample_search_query):
|
|||
if "TAVILY_SEARCH_API_KEY" not in os.environ:
|
||||
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
|
||||
|
||||
tools = llama_stack_client.tool_runtime.list_tools()
|
||||
assert any(tool.identifier == "web_search" for tool in tools)
|
||||
|
||||
response = llama_stack_client.tool_runtime.invoke_tool(
|
||||
tool_name="web_search", kwargs={"query": sample_search_query}
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
assert response.content is not None
|
||||
assert len(response.content) > 0
|
||||
|
@ -49,11 +51,12 @@ def test_wolfram_alpha_tool(llama_stack_client, sample_wolfram_alpha_query):
|
|||
if "WOLFRAM_ALPHA_API_KEY" not in os.environ:
|
||||
pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test")
|
||||
|
||||
tools = llama_stack_client.tool_runtime.list_tools()
|
||||
assert any(tool.identifier == "wolfram_alpha" for tool in tools)
|
||||
response = llama_stack_client.tool_runtime.invoke_tool(
|
||||
tool_name="wolfram_alpha", kwargs={"query": sample_wolfram_alpha_query}
|
||||
)
|
||||
|
||||
print(response.content)
|
||||
assert response.content is not None
|
||||
assert len(response.content) > 0
|
||||
assert isinstance(response.content, str)
|
||||
|
|
|
@ -31,13 +31,12 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
|
|||
test_toolgroup_id = MCP_TOOLGROUP_ID
|
||||
uri = mcp_server["server_url"]
|
||||
|
||||
# registering itself should fail since it requires listing tools
|
||||
with pytest.raises(Exception, match="Unauthorized"):
|
||||
llama_stack_client.toolgroups.register(
|
||||
toolgroup_id=test_toolgroup_id,
|
||||
provider_id="model-context-protocol",
|
||||
mcp_endpoint=dict(uri=uri),
|
||||
)
|
||||
# registering should not raise an error anymore even if you don't specify the auth token
|
||||
llama_stack_client.toolgroups.register(
|
||||
toolgroup_id=test_toolgroup_id,
|
||||
provider_id="model-context-protocol",
|
||||
mcp_endpoint=dict(uri=uri),
|
||||
)
|
||||
|
||||
provider_data = {
|
||||
"mcp_headers": {
|
||||
|
@ -50,18 +49,9 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
|
|||
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
|
||||
}
|
||||
|
||||
try:
|
||||
llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id, extra_headers=auth_headers)
|
||||
except Exception as e:
|
||||
# An error is OK since the toolgroup may not exist
|
||||
print(f"Error unregistering toolgroup: {e}")
|
||||
with pytest.raises(Exception, match="Unauthorized"):
|
||||
llama_stack_client.tools.list()
|
||||
|
||||
llama_stack_client.toolgroups.register(
|
||||
toolgroup_id=test_toolgroup_id,
|
||||
provider_id="model-context-protocol",
|
||||
mcp_endpoint=dict(uri=uri),
|
||||
extra_headers=auth_headers,
|
||||
)
|
||||
response = llama_stack_client.tools.list(
|
||||
toolgroup_id=test_toolgroup_id,
|
||||
extra_headers=auth_headers,
|
||||
|
|
|
@ -51,7 +51,5 @@ def test_register_and_unregister_toolgroup(llama_stack_client):
|
|||
with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
|
||||
llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)
|
||||
|
||||
# Verify tools are also unregistered
|
||||
unregister_tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
|
||||
assert isinstance(unregister_tools_list_response, list)
|
||||
assert not unregister_tools_list_response
|
||||
with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
|
||||
llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
|
||||
|
|
|
@ -15,7 +15,7 @@ from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataS
|
|||
from llama_stack.apis.datatypes import Api
|
||||
from llama_stack.apis.models.models import Model, ModelType
|
||||
from llama_stack.apis.shields.shields import Shield
|
||||
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolParameter
|
||||
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
|
||||
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
|
||||
from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
|
||||
from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
|
||||
|
@ -101,11 +101,11 @@ class ToolGroupsImpl(Impl):
|
|||
def __init__(self):
|
||||
super().__init__(Api.tool_runtime)
|
||||
|
||||
async def register_tool(self, tool):
|
||||
return tool
|
||||
async def register_toolgroup(self, toolgroup: ToolGroup):
|
||||
return toolgroup
|
||||
|
||||
async def unregister_tool(self, tool_name: str):
|
||||
return tool_name
|
||||
async def unregister_toolgroup(self, toolgroup_id: str):
|
||||
return toolgroup_id
|
||||
|
||||
async def list_runtime_tools(self, toolgroup_id, mcp_endpoint):
|
||||
return ListToolDefsResponse(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue