several fixes

This commit is contained in:
Ashwin Bharambe 2025-05-25 10:35:48 -07:00
parent bf8a73e09a
commit cddc1f3524
15 changed files with 95 additions and 83 deletions

View file

@ -25,10 +25,12 @@ def test_web_search_tool(llama_stack_client, sample_search_query):
if "TAVILY_SEARCH_API_KEY" not in os.environ:
pytest.skip("TAVILY_SEARCH_API_KEY not set, skipping test")
tools = llama_stack_client.tool_runtime.list_tools()
assert any(tool.identifier == "web_search" for tool in tools)
response = llama_stack_client.tool_runtime.invoke_tool(
tool_name="web_search", kwargs={"query": sample_search_query}
)
# Verify the response
assert response.content is not None
assert len(response.content) > 0
@ -49,11 +51,12 @@ def test_wolfram_alpha_tool(llama_stack_client, sample_wolfram_alpha_query):
if "WOLFRAM_ALPHA_API_KEY" not in os.environ:
pytest.skip("WOLFRAM_ALPHA_API_KEY not set, skipping test")
tools = llama_stack_client.tool_runtime.list_tools()
assert any(tool.identifier == "wolfram_alpha" for tool in tools)
response = llama_stack_client.tool_runtime.invoke_tool(
tool_name="wolfram_alpha", kwargs={"query": sample_wolfram_alpha_query}
)
print(response.content)
assert response.content is not None
assert len(response.content) > 0
assert isinstance(response.content, str)

View file

@ -31,13 +31,12 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
test_toolgroup_id = MCP_TOOLGROUP_ID
uri = mcp_server["server_url"]
# registering itself should fail since it requires listing tools
with pytest.raises(Exception, match="Unauthorized"):
llama_stack_client.toolgroups.register(
toolgroup_id=test_toolgroup_id,
provider_id="model-context-protocol",
mcp_endpoint=dict(uri=uri),
)
# registering should not raise an error anymore even if you don't specify the auth token
llama_stack_client.toolgroups.register(
toolgroup_id=test_toolgroup_id,
provider_id="model-context-protocol",
mcp_endpoint=dict(uri=uri),
)
provider_data = {
"mcp_headers": {
@ -50,18 +49,9 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
}
try:
llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id, extra_headers=auth_headers)
except Exception as e:
# An error is OK since the toolgroup may not exist
print(f"Error unregistering toolgroup: {e}")
with pytest.raises(Exception, match="Unauthorized"):
llama_stack_client.tools.list()
llama_stack_client.toolgroups.register(
toolgroup_id=test_toolgroup_id,
provider_id="model-context-protocol",
mcp_endpoint=dict(uri=uri),
extra_headers=auth_headers,
)
response = llama_stack_client.tools.list(
toolgroup_id=test_toolgroup_id,
extra_headers=auth_headers,

View file

@ -51,7 +51,5 @@ def test_register_and_unregister_toolgroup(llama_stack_client):
with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)
# Verify tools are also unregistered
unregister_tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
assert isinstance(unregister_tools_list_response, list)
assert not unregister_tools_list_response
with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)

View file

@ -15,7 +15,7 @@ from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataS
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models.models import Model, ModelType
from llama_stack.apis.shields.shields import Shield
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolParameter
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
@ -101,11 +101,11 @@ class ToolGroupsImpl(Impl):
def __init__(self):
super().__init__(Api.tool_runtime)
async def register_tool(self, tool):
return tool
async def register_toolgroup(self, toolgroup: ToolGroup):
return toolgroup
async def unregister_tool(self, tool_name: str):
return tool_name
async def unregister_toolgroup(self, toolgroup_id: str):
return toolgroup_id
async def list_runtime_tools(self, toolgroup_id, mcp_endpoint):
return ListToolDefsResponse(