# llama-stack-mirror/tests/integration/tool_runtime/test_mcp.py
# 2025-05-23 08:41:55 -07:00
#
# 221 lines
# 6.9 KiB
# Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import socket
import threading
import time
import httpx
import mcp.types as types
import pytest
import uvicorn
from llama_stack_client import Agent
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.exceptions import HTTPException
from starlette.responses import Response
from starlette.routing import Mount, Route
from llama_stack import LlamaStackAsLibraryClient
from llama_stack.distribution.datatypes import AuthenticationRequiredError
# Bearer token accepted by the test MCP server's /sse endpoint; requests that
# do not present exactly this token are rejected with a 401.
AUTH_TOKEN = "test-token"
@pytest.fixture(scope="module")
def mcp_server():
    """Run a throwaway MCP server over SSE and yield the port it listens on.

    The server exposes a single ``greet_everyone`` tool. Its ``/sse`` endpoint
    answers 401 unless the request carries ``Authorization: Bearer <AUTH_TOKEN>``.
    It is served by uvicorn from a daemon thread and asked to exit on teardown.

    Yields:
        The TCP port the server is bound to.

    Raises:
        RuntimeError: If the server does not start answering within the
            startup timeout. Without this the tests would run against a dead
            port and fail with unrelated-looking errors.
    """
    server = FastMCP("FastMCP Test Server")

    @server.tool()
    async def greet_everyone(
        url: str, ctx: Context
    ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
        return [types.TextContent(type="text", text="Hello, world!")]

    sse = SseServerTransport("/messages/")

    async def handle_sse(request):
        # Only "Bearer <AUTH_TOKEN>" is accepted; anything else is a 401.
        auth_header = request.headers.get("Authorization")
        auth_token = None
        if auth_header and auth_header.startswith("Bearer "):
            auth_token = auth_header.split(" ")[1]
        if auth_token != AUTH_TOKEN:
            raise HTTPException(status_code=401, detail="Unauthorized")

        async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
            await server._mcp_server.run(
                streams[0],
                streams[1],
                server._mcp_server.create_initialization_options(),
            )
        return Response()

    app = Starlette(
        routes=[
            Route("/sse", endpoint=handle_sse),
            Mount("/messages/", app=sse.handle_post_message),
        ],
    )

    def get_open_port():
        # Bind to port 0 so the OS picks a free port, then release it for uvicorn.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("", 0))
            return sock.getsockname()[1]

    port = get_open_port()
    config = uvicorn.Config(app, host="0.0.0.0", port=port)
    server_instance = uvicorn.Server(config)
    app.state.uvicorn_server = server_instance

    def run_server():
        server_instance.run()

    def stop_server():
        # Tell server to exit and give the thread a bounded time to finish.
        server_instance.should_exit = True
        server_thread.join(timeout=5)

    # Start the server in a new thread
    server_thread = threading.Thread(target=run_server, daemon=True)
    server_thread.start()

    # Poll until the server is ready. An unauthenticated probe is expected to
    # get a 401, which proves the app is up and routing requests.
    timeout = 10
    # Monotonic clock: the deadline must not move if the wall clock is adjusted.
    deadline = time.monotonic() + timeout
    ready = False
    while time.monotonic() < deadline:
        try:
            response = httpx.get(f"http://localhost:{port}/sse")
        except httpx.RequestError:
            # Not accepting connections yet; retry below.
            pass
        else:
            if response.status_code == 401:
                ready = True
                break
        time.sleep(0.1)

    if not ready:
        stop_server()
        raise RuntimeError(f"MCP test server on port {port} did not become ready within {timeout} seconds")

    try:
        yield port
    finally:
        stop_server()
def test_mcp_invocation(llama_stack_client, mcp_server):
    """Exercise an auth-protected MCP toolgroup end to end.

    Covers: registration without credentials is rejected, registration with
    credentials succeeds, the tool is listed and can be invoked directly, an
    agent turn calls the tool, and a streaming turn without credentials
    surfaces an authentication failure.
    """
    toolgroup = "remote::mcptest"
    sse_uri = f"http://localhost:{mcp_server}/sse"

    # Registration lists the server's tools, so without credentials it must fail.
    with pytest.raises(Exception, match="Unauthorized"):
        llama_stack_client.toolgroups.register(
            toolgroup_id=toolgroup,
            provider_id="model-context-protocol",
            mcp_endpoint={"uri": sse_uri},
        )

    # Per-endpoint MCP headers travel in the provider-data request header.
    provider_data = {"mcp_headers": {sse_uri: [f"Authorization: Bearer {AUTH_TOKEN}"]}}
    auth_headers = {"X-LlamaStack-Provider-Data": json.dumps(provider_data)}

    # Best-effort cleanup of a previous run; the toolgroup may not exist.
    try:
        llama_stack_client.toolgroups.unregister(toolgroup_id=toolgroup, extra_headers=auth_headers)
    except Exception as e:
        print(f"Error unregistering toolgroup: {e}")

    llama_stack_client.toolgroups.register(
        toolgroup_id=toolgroup,
        provider_id="model-context-protocol",
        mcp_endpoint={"uri": sse_uri},
        extra_headers=auth_headers,
    )

    # The toolgroup exposes exactly one tool with one required string parameter.
    tools = llama_stack_client.tools.list(toolgroup_id=toolgroup, extra_headers=auth_headers)
    assert len(tools) == 1
    tool = tools[0]
    assert tool.identifier == "greet_everyone"
    assert tool.type == "tool"
    assert len(tool.parameters) == 1
    param = tool.parameters[0]
    assert param.name == "url"
    assert param.parameter_type == "string"
    assert param.required

    # Direct invocation through the tool runtime.
    invocation = llama_stack_client.tool_runtime.invoke_tool(
        tool_name=tool.identifier,
        kwargs={"url": "https://www.google.com"},
        extra_headers=auth_headers,
    )
    invoked_content = invocation.content
    assert len(invoked_content) == 1
    assert invoked_content[0].type == "text"
    assert invoked_content[0].text == "Hello, world!"

    # Drive the same tool through an agent using the first available model.
    model_id = llama_stack_client.models.list()[0].identifier
    print(f"Using model: {model_id}")
    agent = Agent(
        client=llama_stack_client,
        model=model_id,
        instructions="You are a helpful assistant.",
        tools=[toolgroup],
    )
    session_id = agent.create_session("test-session")

    turn = agent.create_turn(
        session_id=session_id,
        messages=[{"role": "user", "content": "Yo. Use tools."}],
        stream=False,
        extra_headers=auth_headers,
    )
    steps = turn.steps

    # Step 1: the model decides to call the tool.
    inference_step = steps[0]
    assert inference_step.step_type == "inference"
    assert len(inference_step.api_model_response.tool_calls) == 1
    assert inference_step.api_model_response.tool_calls[0].tool_name == "greet_everyone"

    # Step 2: the tool runs and returns its greeting.
    execution_step = steps[1]
    assert execution_step.step_type == "tool_execution"
    tool_output = execution_step.tool_responses[0].content
    assert len(tool_output) == 1
    assert tool_output[0].type == "text"
    assert tool_output[0].text == "Hello, world!"

    # Step 3: the model answers without requesting further tool calls.
    final_step = steps[2]
    assert final_step.step_type == "inference"
    assert len(final_step.api_model_response.tool_calls) == 0

    # When streaming, auth headers are currently not checked upfront, so the
    # request is not failed early. An authentication failure must still show
    # up later in the process.
    stream = agent.create_turn(
        session_id=session_id,
        messages=[{"role": "user", "content": "Yo. Use tools."}],
        stream=True,
    )
    if not isinstance(llama_stack_client, LlamaStackAsLibraryClient):
        # Remote client: the failure arrives as a single error chunk in the stream.
        dumped_chunks = [chunk.model_dump() for chunk in stream]
        errors = [dumped for dumped in dumped_chunks if "error" in dumped]
        assert len(errors) == 1
        assert "Unauthorized" in errors[0]["error"]["message"]
    else:
        # Library client: the failure is raised while consuming the stream.
        with pytest.raises(AuthenticationRequiredError):
            for _ in stream:
                pass