From a5132b4857df89c2454d805d8b42dede3ade0b11 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sat, 24 May 2025 12:02:28 -0700 Subject: [PATCH] more test fixes --- tests/common/mcp.py | 10 +- tests/integration/tool_runtime/test_mcp.py | 120 +++------------ .../tool_runtime/test_registration.py | 144 +++++------------- 3 files changed, 66 insertions(+), 208 deletions(-) diff --git a/tests/common/mcp.py b/tests/common/mcp.py index b66653ace..fd7040c6c 100644 --- a/tests/common/mcp.py +++ b/tests/common/mcp.py @@ -7,6 +7,11 @@ # we want the mcp server to be authenticated OR not, depends from contextlib import contextmanager +# Unfortunately the toolgroup id must be tied to the tool names because the registry +# indexes on both toolgroups and tools independently (and not jointly). That really +# needs to be fixed. +MCP_TOOLGROUP_ID = "mcp::localmcp" + @contextmanager def make_mcp_server(required_auth_token: str | None = None): @@ -22,7 +27,7 @@ def make_mcp_server(required_auth_token: str | None = None): from starlette.responses import Response from starlette.routing import Mount, Route - server = FastMCP("FastMCP Test Server") + server = FastMCP("FastMCP Test Server", log_level="WARNING") @server.tool() async def greet_everyone( @@ -84,7 +89,8 @@ def make_mcp_server(required_auth_token: str | None = None): port = get_open_port() - config = uvicorn.Config(app, host="0.0.0.0", port=port) + # make uvicorn logs be less verbose + config = uvicorn.Config(app, host="0.0.0.0", port=port, log_level="warning") server_instance = uvicorn.Server(config) app.state.uvicorn_server = server_instance diff --git a/tests/integration/tool_runtime/test_mcp.py b/tests/integration/tool_runtime/test_mcp.py index 9dd767b2f..dd8a6d823 100644 --- a/tests/integration/tool_runtime/test_mcp.py +++ b/tests/integration/tool_runtime/test_mcp.py @@ -5,120 +5,43 @@ # the root directory of this source tree. import json -import socket -import threading -import time -import httpx -import mcp.types as types import pytest -import uvicorn from llama_stack_client import Agent -from mcp.server.fastmcp import Context, FastMCP -from mcp.server.sse import SseServerTransport -from starlette.applications import Starlette -from starlette.exceptions import HTTPException -from starlette.responses import Response -from starlette.routing import Mount, Route from llama_stack import LlamaStackAsLibraryClient +from llama_stack.apis.models import ModelType from llama_stack.distribution.datatypes import AuthenticationRequiredError AUTH_TOKEN = "test-token" +from tests.common.mcp import MCP_TOOLGROUP_ID, make_mcp_server -@pytest.fixture(scope="module") + +@pytest.fixture(scope="function") def mcp_server(): - server = FastMCP("FastMCP Test Server") - - @server.tool() - async def greet_everyone( - url: str, ctx: Context - ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: - return [types.TextContent(type="text", text="Hello, world!")] - - sse = SseServerTransport("/messages/") - - async def handle_sse(request): - auth_header = request.headers.get("Authorization") - auth_token = None - if auth_header and auth_header.startswith("Bearer "): - auth_token = auth_header.split(" ")[1] - - if auth_token != AUTH_TOKEN: - raise HTTPException(status_code=401, detail="Unauthorized") - - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: - await server._mcp_server.run( - streams[0], - streams[1], - server._mcp_server.create_initialization_options(), - ) - return Response() - - app = Starlette( - routes=[ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse.handle_post_message), - ], - ) - - def get_open_port(): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.bind(("", 0)) - return sock.getsockname()[1] - - port = get_open_port() - - config = uvicorn.Config(app, host="0.0.0.0", port=port) - server_instance = uvicorn.Server(config) - app.state.uvicorn_server = server_instance - - def run_server(): - server_instance.run() - - # Start the server in a new thread - server_thread = threading.Thread(target=run_server, daemon=True) - server_thread.start() - - # Polling until the server is ready - timeout = 10 - start_time = time.time() - - while time.time() - start_time < timeout: - try: - response = httpx.get(f"http://localhost:{port}/sse") - if response.status_code == 401: - break - except httpx.RequestError: - pass - time.sleep(0.1) - - yield port - - # Tell server to exit - server_instance.should_exit = True - server_thread.join(timeout=5) + with make_mcp_server(required_auth_token=AUTH_TOKEN) as mcp_server_info: + yield mcp_server_info def test_mcp_invocation(llama_stack_client, mcp_server): if not isinstance(llama_stack_client, LlamaStackAsLibraryClient): pytest.skip("The local MCP server only reliably reachable from library client.") - port = mcp_server - test_toolgroup_id = "remote::mcptest" + test_toolgroup_id = MCP_TOOLGROUP_ID + uri = mcp_server["server_url"] # registering itself should fail since it requires listing tools with pytest.raises(Exception, match="Unauthorized"): llama_stack_client.toolgroups.register( toolgroup_id=test_toolgroup_id, provider_id="model-context-protocol", - mcp_endpoint=dict(uri=f"http://localhost:{port}/sse"), + mcp_endpoint=dict(uri=uri), ) provider_data = { "mcp_headers": { - f"http://localhost:{port}/sse": [ + uri: [ f"Authorization: Bearer {AUTH_TOKEN}", ], }, @@ -136,24 +59,18 @@ def test_mcp_invocation(llama_stack_client, mcp_server): llama_stack_client.toolgroups.register( toolgroup_id=test_toolgroup_id, provider_id="model-context-protocol", - mcp_endpoint=dict(uri=f"http://localhost:{port}/sse"), + mcp_endpoint=dict(uri=uri), extra_headers=auth_headers, ) response = llama_stack_client.tools.list( toolgroup_id=test_toolgroup_id, extra_headers=auth_headers, ) - assert len(response) == 1 - assert response[0].identifier == "greet_everyone" - assert response[0].type == "tool" - assert len(response[0].parameters) == 1 - p = response[0].parameters[0] - assert p.name == "url" - assert p.parameter_type == "string" - assert p.required + assert len(response) == 2 + assert {t.identifier for t in response} == {"greet_everyone", "get_boiling_point"} response = llama_stack_client.tool_runtime.invoke_tool( - tool_name=response[0].identifier, + tool_name="greet_everyone", kwargs=dict(url="https://www.google.com"), extra_headers=auth_headers, ) @@ -162,7 +79,9 @@ def test_mcp_invocation(llama_stack_client, mcp_server): assert content[0].type == "text" assert content[0].text == "Hello, world!" - models = llama_stack_client.models.list() + models = [ + m for m in llama_stack_client.models.list() if m.model_type == ModelType.llm and "guard" not in m.identifier + ] model_id = models[0].identifier print(f"Using model: {model_id}") agent = Agent( @@ -177,7 +96,7 @@ def test_mcp_invocation(llama_stack_client, mcp_server): messages=[ { "role": "user", - "content": "Yo. Use tools.", + "content": "Say hi to the world. Use tools to do so.", } ], stream=False, @@ -199,7 +118,6 @@ def test_mcp_invocation(llama_stack_client, mcp_server): third = steps[2] assert third.step_type == "inference" - assert len(third.api_model_response.tool_calls) == 0 # when streaming, we currently don't check auth headers upfront and fail the request # early. but we should at least be generating a 401 later in the process. @@ -208,7 +126,7 @@ def test_mcp_invocation(llama_stack_client, mcp_server): messages=[ { "role": "user", - "content": "Yo. Use tools.", + "content": "What is the boiling point of polyjuice? Use tools to answer.", } ], stream=True, diff --git a/tests/integration/tool_runtime/test_registration.py b/tests/integration/tool_runtime/test_registration.py index b36237d05..b8cbd964a 100644 --- a/tests/integration/tool_runtime/test_registration.py +++ b/tests/integration/tool_runtime/test_registration.py @@ -4,120 +4,54 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import socket -import threading -import time - -import httpx -import mcp.types as types import pytest -import uvicorn -from mcp.server.fastmcp import Context, FastMCP -from mcp.server.sse import SseServerTransport -from starlette.applications import Starlette -from starlette.routing import Mount, Route + +from llama_stack import LlamaStackAsLibraryClient +from tests.common.mcp import MCP_TOOLGROUP_ID, make_mcp_server -@pytest.fixture(scope="module") -def mcp_server(): - server = FastMCP("FastMCP Test Server") +def test_register_and_unregister_toolgroup(llama_stack_client): + # TODO: make this work for http client also but you need to ensure + # the MCP server is reachable from llama stack server + if not isinstance(llama_stack_client, LlamaStackAsLibraryClient): + pytest.skip("The local MCP server only reliably reachable from library client.") - @server.tool() - async def fetch(url: str, ctx: Context) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: - headers = {"User-Agent": "MCP Test Server (github.com/modelcontextprotocol/python-sdk)"} - async with httpx.AsyncClient(follow_redirects=True, headers=headers) as client: - response = await client.get(url) - response.raise_for_status() - return [types.TextContent(type="text", text=response.text)] - - sse = SseServerTransport("/messages/") - - async def handle_sse(request): - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: - await server._mcp_server.run( - streams[0], - streams[1], - server._mcp_server.create_initialization_options(), - ) - - app = Starlette( - debug=True, - routes=[ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse.handle_post_message), - ], - ) - - def get_open_port(): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.bind(("", 0)) - return sock.getsockname()[1] - - port = get_open_port() - - def run_server(): - uvicorn.run(app, host="0.0.0.0", port=port) - - # Start the server in a new thread - server_thread = threading.Thread(target=run_server, daemon=True) - server_thread.start() - - # Polling until the server is ready - timeout = 10 - start_time = time.time() - - while time.time() - start_time < timeout: - try: - response = httpx.get(f"http://localhost:{port}/sse") - if response.status_code == 200: - break - except (httpx.RequestError, httpx.HTTPStatusError): - pass - time.sleep(0.1) - - yield port - - -def test_register_and_unregister_toolgroup(llama_stack_client, mcp_server): - """ - Integration test for registering and unregistering a toolgroup using the ToolGroups API. - """ - port = mcp_server - test_toolgroup_id = "remote::web-fetch" + test_toolgroup_id = MCP_TOOLGROUP_ID provider_id = "model-context-protocol" - # Cleanup before running the test - toolgroups = llama_stack_client.toolgroups.list() - for toolgroup in toolgroups: - if toolgroup.identifier == test_toolgroup_id: - llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id) + with make_mcp_server() as mcp_server_info: + # Cleanup before running the test + toolgroups = llama_stack_client.toolgroups.list() + for toolgroup in toolgroups: + if toolgroup.identifier == test_toolgroup_id: + llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id) - # Register the toolgroup - llama_stack_client.toolgroups.register( - toolgroup_id=test_toolgroup_id, - provider_id=provider_id, - mcp_endpoint=dict(uri=f"http://localhost:{port}/sse"), - ) + # Register the toolgroup + llama_stack_client.toolgroups.register( + toolgroup_id=test_toolgroup_id, + provider_id=provider_id, + mcp_endpoint=dict(uri=mcp_server_info["server_url"]), + ) - # Verify registration - registered_toolgroup = llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id) - assert registered_toolgroup is not None - assert registered_toolgroup.identifier == test_toolgroup_id - assert registered_toolgroup.provider_id == provider_id + # Verify registration + registered_toolgroup = llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id) + assert registered_toolgroup is not None + assert registered_toolgroup.identifier == test_toolgroup_id + assert registered_toolgroup.provider_id == provider_id - # Verify tools listing - tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id) - assert isinstance(tools_list_response, list) - assert tools_list_response + # Verify tools listing + tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id) + assert isinstance(tools_list_response, list) + assert tools_list_response - # Unregister the toolgroup - llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id) + # Unregister the toolgroup + llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id) - # Verify it is unregistered - with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"): - llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id) + # Verify it is unregistered + with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"): + llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id) - # Verify tools are also unregistered - unregister_tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id) - assert isinstance(unregister_tools_list_response, list) - assert not unregister_tools_list_response + # Verify tools are also unregistered + unregister_tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id) + assert isinstance(unregister_tools_list_response, list) + assert not unregister_tools_list_response