# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import json import socket import threading import time import httpx import mcp.types as types import pytest import uvicorn from llama_stack_client import Agent from mcp.server.fastmcp import Context, FastMCP from mcp.server.sse import SseServerTransport from starlette.applications import Starlette from starlette.exceptions import HTTPException from starlette.responses import Response from starlette.routing import Mount, Route AUTH_TOKEN = "test-token" @pytest.fixture(scope="module") def mcp_server(): server = FastMCP("FastMCP Test Server") @server.tool() async def greet_everyone( url: str, ctx: Context ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: return [types.TextContent(type="text", text="Hello, world!")] sse = SseServerTransport("/messages/") async def handle_sse(request): auth_header = request.headers.get("Authorization") auth_token = None if auth_header and auth_header.startswith("Bearer "): auth_token = auth_header.split(" ")[1] if auth_token != AUTH_TOKEN: raise HTTPException(status_code=401, detail="Unauthorized") async with sse.connect_sse(request.scope, request.receive, request._send) as streams: await server._mcp_server.run( streams[0], streams[1], server._mcp_server.create_initialization_options(), ) return Response() app = Starlette( routes=[ Route("/sse", endpoint=handle_sse), Mount("/messages/", app=sse.handle_post_message), ], ) def get_open_port(): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: sock.bind(("", 0)) return sock.getsockname()[1] port = get_open_port() config = uvicorn.Config(app, host="0.0.0.0", port=port) server_instance = uvicorn.Server(config) app.state.uvicorn_server = server_instance def run_server(): server_instance.run() # Start the server in a new thread server_thread = threading.Thread(target=run_server, daemon=True) server_thread.start() # Polling until the server is ready timeout = 10 start_time = time.time() while time.time() - start_time < timeout: try: response = httpx.get(f"http://localhost:{port}/sse") if response.status_code == 401: break except httpx.RequestError: pass time.sleep(0.1) yield port # Tell server to exit server_instance.should_exit = True server_thread.join(timeout=5) def test_mcp_invocation(llama_stack_client, mcp_server): port = mcp_server test_toolgroup_id = "remote::mcptest" # registering itself should fail since it requires listing tools with pytest.raises(Exception, match="Unauthorized"): llama_stack_client.toolgroups.register( toolgroup_id=test_toolgroup_id, provider_id="model-context-protocol", mcp_endpoint=dict(uri=f"http://localhost:{port}/sse"), ) provider_data = { "mcp_headers": { f"http://localhost:{port}/sse": [ f"Authorization: Bearer {AUTH_TOKEN}", ], }, } auth_headers = { "X-LlamaStack-Provider-Data": json.dumps(provider_data), } llama_stack_client.toolgroups.register( toolgroup_id=test_toolgroup_id, provider_id="model-context-protocol", mcp_endpoint=dict(uri=f"http://localhost:{port}/sse"), extra_headers=auth_headers, ) response = llama_stack_client.tools.list( toolgroup_id=test_toolgroup_id, extra_headers=auth_headers, ) assert len(response) == 1 assert response[0].identifier == "greet_everyone" assert response[0].type == "tool" assert len(response[0].parameters) == 1 p = response[0].parameters[0] assert p.name == "url" assert p.parameter_type == "string" assert p.required response = llama_stack_client.tool_runtime.invoke_tool( tool_name=response[0].identifier, kwargs=dict(url="https://www.google.com"), extra_headers=auth_headers, ) content = response.content assert len(content) == 1 assert content[0].type == "text" assert content[0].text == "Hello, world!" models = llama_stack_client.models.list() model_id = models[0].identifier print(f"Using model: {model_id}") agent = Agent( client=llama_stack_client, model=model_id, instructions="You are a helpful assistant.", tools=[test_toolgroup_id], ) session_id = agent.create_session("test-session") response = agent.create_turn( session_id=session_id, messages=[ { "role": "user", "content": "Yo. Use tools.", } ], stream=False, extra_headers=auth_headers, ) steps = response.steps first = steps[0] assert first.step_type == "inference" assert len(first.api_model_response.tool_calls) == 1 tool_call = first.api_model_response.tool_calls[0] assert tool_call.tool_name == "greet_everyone" second = steps[1] assert second.step_type == "tool_execution" tool_response_content = second.tool_responses[0].content assert len(tool_response_content) == 1 assert tool_response_content[0].type == "text" assert tool_response_content[0].text == "Hello, world!" third = steps[2] assert third.step_type == "inference" assert len(third.api_model_response.tool_calls) == 0