# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import re
import socket
import threading
import time

import httpx
import mcp.types as types
import pytest
import uvicorn
from llama_stack_client.types.shared_params.url import URL
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.routing import Mount, Route


@pytest.fixture(scope="module")
def mcp_server():
    """Run a FastMCP server exposing a ``fetch`` tool over SSE for this module.

    The server is served by uvicorn on a free port in a background thread.
    The fixture blocks until ``GET /sse`` answers with HTTP 200, yields the
    port number, and asks uvicorn to shut down once the module's tests finish.

    Yields:
        int: the TCP port the server listens on.

    Raises:
        pytest.fail.Exception: if the server thread dies or the server is not
            ready within the startup timeout, so tests never run against a
            dead endpoint.
    """
    server = FastMCP("FastMCP Test Server")

    @server.tool()
    async def fetch(url: str, ctx: Context) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
        headers = {"User-Agent": "MCP Test Server (github.com/modelcontextprotocol/python-sdk)"}
        async with httpx.AsyncClient(follow_redirects=True, headers=headers) as client:
            response = await client.get(url)
            response.raise_for_status()
            return [types.TextContent(type="text", text=response.text)]

    sse = SseServerTransport("/messages/")

    async def handle_sse(request):
        # SseServerTransport needs the raw ASGI send callable, which Starlette
        # only exposes as the private ``request._send``.
        async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
            await server._mcp_server.run(
                streams[0],
                streams[1],
                server._mcp_server.create_initialization_options(),
            )

    app = Starlette(
        debug=True,
        routes=[
            Route("/sse", endpoint=handle_sse),
            Mount("/messages/", app=sse.handle_post_message),
        ],
    )

    def get_open_port():
        # Let the OS pick a free port. The socket is closed before uvicorn
        # binds, so another process could take the port in between; this race
        # is accepted for a test fixture.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("", 0))
            return sock.getsockname()[1]

    port = get_open_port()

    # Drive uvicorn through a Server object (what uvicorn.run does internally)
    # so teardown can request a shutdown via ``should_exit``.
    uvicorn_server = uvicorn.Server(uvicorn.Config(app, host="0.0.0.0", port=port))
    server_thread = threading.Thread(target=uvicorn_server.run, daemon=True)
    server_thread.start()

    # Poll until the server is ready. /sse is an endless event stream, so only
    # the status line/headers are awaited: a plain httpx.get() would wait for
    # the body and run into its read timeout on every attempt.
    timeout = 10
    deadline = time.monotonic() + timeout
    while True:
        try:
            with httpx.stream("GET", f"http://localhost:{port}/sse", timeout=2.0) as response:
                if response.status_code == 200:
                    break
        except httpx.RequestError:
            pass  # not listening yet (connection refused, timeout, ...)
        if not server_thread.is_alive():
            pytest.fail(f"MCP test server thread exited before listening on port {port}")
        if time.monotonic() >= deadline:
            pytest.fail(f"MCP test server did not become ready on port {port} within {timeout}s")
        time.sleep(0.1)

    yield port

    # Ask uvicorn to exit. Open SSE streams may delay a graceful shutdown; the
    # thread is a daemon, so a bounded join cannot hang the test session.
    uvicorn_server.should_exit = True
    server_thread.join(timeout=5)


def test_register_and_unregister_toolgroup(llama_stack_client, mcp_server):
    """
    Integration test for registering and unregistering a toolgroup using the ToolGroups API.

    Registers the MCP test server as a toolgroup, checks that the toolgroup and
    its tools are visible, then unregisters it and checks that both are gone.
    """
    port = mcp_server
    test_toolgroup_id = "remote::web-fetch"
    provider_id = "model-context-protocol"

    # Remove a leftover registration from an earlier (possibly failed) run.
    if any(tg.identifier == test_toolgroup_id for tg in llama_stack_client.toolgroups.list()):
        llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id)

    # Register the toolgroup
    llama_stack_client.toolgroups.register(
        toolgroup_id=test_toolgroup_id,
        provider_id=provider_id,
        mcp_endpoint=URL(uri=f"http://localhost:{port}/sse"),
    )

    # Verify registration
    registered_toolgroup = llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)
    assert registered_toolgroup is not None
    assert registered_toolgroup.identifier == test_toolgroup_id
    assert registered_toolgroup.provider_id == provider_id

    # Verify tools listing
    tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
    assert isinstance(tools_list_response, list)
    assert tools_list_response

    # Unregister the toolgroup
    llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id)

    # Verify it is unregistered. The ID is escaped because ``match`` is a regex.
    with pytest.raises(ValueError, match=re.escape(f"Tool group '{test_toolgroup_id}' not found")):
        llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)

    # Verify tools are also unregistered
    unregister_tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
    assert isinstance(unregister_tools_list_response, list)
    assert not unregister_tools_list_response