diff --git a/tests/integration/tool_runtime/test_registration.py b/tests/integration/tool_runtime/test_registration.py index 8f59ee6b5..e04b56652 100644 --- a/tests/integration/tool_runtime/test_registration.py +++ b/tests/integration/tool_runtime/test_registration.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import socket import threading import time @@ -48,8 +49,15 @@ def mcp_server(): ], ) + def get_open_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("", 0)) + return sock.getsockname()[1] + + port = get_open_port() + def run_server(): - uvicorn.run(app, host="0.0.0.0", port=8000) + uvicorn.run(app, host="0.0.0.0", port=port) # Start the server in a new thread server_thread = threading.Thread(target=run_server, daemon=True) @@ -61,54 +69,56 @@ def mcp_server(): while time.time() - start_time < timeout: try: - response = httpx.get("http://localhost:8000/sse") + response = httpx.get(f"http://localhost:{port}/sse") if response.status_code == 200: break except (httpx.RequestError, httpx.HTTPStatusError): pass time.sleep(0.1) - yield + yield port -def test_register_and_unregister_toolgroup(client_with_models, mcp_server): +def test_register_and_unregister_toolgroup(llama_stack_client, mcp_server): """ Integration test for registering and unregistering a toolgroup using the ToolGroups API. """ + port = mcp_server test_toolgroup_id = "remote::web-fetch" provider_id = "model-context-protocol" # Cleanup before running the test - test_toolgroup = client_with_models.toolgroups.get(toolgroup_id=test_toolgroup_id) - if test_toolgroup is not None: - client_with_models.toolgroups.unregister(toolgroup_id=test_toolgroup_id) + toolgroups = llama_stack_client.toolgroups.list() + for toolgroup in toolgroups: + if toolgroup.identifier == test_toolgroup_id: + llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id) # Register the toolgroup - client_with_models.toolgroups.register( + llama_stack_client.toolgroups.register( toolgroup_id=test_toolgroup_id, provider_id=provider_id, - mcp_endpoint=URL(uri="http://localhost:8000/sse"), + mcp_endpoint=URL(uri=f"http://localhost:{port}/sse"), ) # Verify registration - registered_toolgroup = client_with_models.toolgroups.get(toolgroup_id=test_toolgroup_id) + registered_toolgroup = llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id) assert registered_toolgroup is not None assert registered_toolgroup.identifier == test_toolgroup_id assert registered_toolgroup.provider_id == provider_id # Verify tools listing - tools_list_response = client_with_models.tools.list(toolgroup_id=test_toolgroup_id) + tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id) assert isinstance(tools_list_response, list) assert tools_list_response # Unregister the toolgroup - client_with_models.toolgroups.unregister(toolgroup_id=test_toolgroup_id) + llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id) - # Verify unregistration - unregistered_toolgroup = client_with_models.toolgroups.get(toolgroup_id=test_toolgroup_id) - assert unregistered_toolgroup is None + # Verify it is unregistered + with pytest.raises(ValueError, match=f"Tool group '{test_toolgroup_id}' not found"): + llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id) # Verify tools are also unregistered - unregister_tools_list_response = client_with_models.tools.list(toolgroup_id=test_toolgroup_id) + unregister_tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id) assert isinstance(unregister_tools_list_response, list) assert not unregister_tools_list_response