llama-stack-mirror/tests/integration/tool_runtime/test_mcp.py
Ashwin Bharambe 51945f1e57
feat: accept MCP authorization headers for MCP toolgroups (#2230)
The most interesting MCP servers are those with an authorization wall in
front of them. This PR reuses the existing `provider_data` mechanism
(already used to pass provider API keys) to pass MCP access tokens -- in
fact, arbitrary headers in the style of the OpenAI Responses API -- from
the client through to the MCP server.

```
class MCPProviderDataValidator(BaseModel):
    # mcp_endpoint => list of headers to send
    mcp_headers: dict[str, list[str]] | None = None
```

Note how we must stuff the headers for all MCP endpoints into a single
"MCPProviderDataValidator". Unlike existing providers (e.g., Together
and Fireworks for inference) where we could name the provider api keys
clearly (`together_api_key`, `fireworks_api_key`), we cannot name these
keys for MCP. We have a single generic MCP provider which can serve
multiple "toolgroups". So we use a dict to combine all the headers for
all MCP endpoints you may want to use in an agentic call.


## Test Plan

See the added integration test for usage.
2025-05-23 08:52:18 -07:00

221 lines
6.9 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import socket
import threading
import time
import httpx
import mcp.types as types
import pytest
import uvicorn
from llama_stack_client import Agent
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.exceptions import HTTPException
from starlette.responses import Response
from starlette.routing import Mount, Route
from llama_stack import LlamaStackAsLibraryClient
from llama_stack.distribution.datatypes import AuthenticationRequiredError
# Bearer token the test MCP server's /sse endpoint requires in the Authorization header.
AUTH_TOKEN = "test-token"
@pytest.fixture(scope="module")
def mcp_server():
    """Run an auth-protected MCP server over SSE for the lifetime of the module.

    A FastMCP server exposing a single ``greet_everyone`` tool is wrapped in a
    Starlette app and served by uvicorn on a free port from a daemon thread.
    The ``/sse`` route answers 401 unless the request carries
    ``Authorization: Bearer <AUTH_TOKEN>``.

    Yields:
        The TCP port the server is listening on.

    Raises:
        RuntimeError: if the server does not start answering within the
            readiness timeout.
    """
    server = FastMCP("FastMCP Test Server")

    # NOTE: intentionally no docstring on the tool; the function is registered
    # with FastMCP as-is so the tool metadata seen by clients stays unchanged.
    @server.tool()
    async def greet_everyone(
        url: str, ctx: Context
    ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
        return [types.TextContent(type="text", text="Hello, world!")]

    sse = SseServerTransport("/messages/")

    async def handle_sse(request):
        # Reject the connection unless a matching bearer token is presented.
        auth_header = request.headers.get("Authorization")
        auth_token = None
        if auth_header and auth_header.startswith("Bearer "):
            auth_token = auth_header.split(" ")[1]
        if auth_token != AUTH_TOKEN:
            raise HTTPException(status_code=401, detail="Unauthorized")

        async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
            await server._mcp_server.run(
                streams[0],
                streams[1],
                server._mcp_server.create_initialization_options(),
            )
        return Response()

    app = Starlette(
        routes=[
            Route("/sse", endpoint=handle_sse),
            Mount("/messages/", app=sse.handle_post_message),
        ],
    )

    def get_open_port():
        # Binding to port 0 lets the OS pick a free port; the socket is closed
        # again on exit from the ``with`` so uvicorn can bind it afterwards.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("", 0))
            return sock.getsockname()[1]

    port = get_open_port()

    config = uvicorn.Config(app, host="0.0.0.0", port=port)
    server_instance = uvicorn.Server(config)
    app.state.uvicorn_server = server_instance

    def run_server():
        server_instance.run()

    # Start the server in a new thread
    server_thread = threading.Thread(target=run_server, daemon=True)
    server_thread.start()

    # Poll until the server is ready. An unauthenticated GET is expected to be
    # answered with 401, which proves the app is up and the auth check is live.
    timeout = 10
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            response = httpx.get(f"http://localhost:{port}/sse")
            if response.status_code == 401:
                break
        except httpx.RequestError:
            # Server is not accepting connections yet; retry.
            pass
        time.sleep(0.1)
    else:
        # Never became ready: stop the thread and fail loudly here instead of
        # handing a dead port to the tests.
        server_instance.should_exit = True
        server_thread.join(timeout=5)
        raise RuntimeError(f"MCP test server did not become ready on port {port} within {timeout}s")

    yield port

    # Tell server to exit
    server_instance.should_exit = True
    server_thread.join(timeout=5)
def test_mcp_invocation(llama_stack_client, mcp_server):
    """Exercise an auth-protected MCP toolgroup end to end.

    Covers: registration without credentials is rejected; registration, tool
    listing and direct tool invocation succeed when the bearer token is passed
    via provider data; a non-streaming agent turn calls the tool; a streaming
    agent turn without credentials surfaces an authorization failure.
    """
    toolgroup_id = "remote::mcptest"
    endpoint_uri = f"http://localhost:{mcp_server}/sse"

    # Registration lists the server's tools, so without credentials it must fail.
    with pytest.raises(Exception, match="Unauthorized"):
        llama_stack_client.toolgroups.register(
            toolgroup_id=toolgroup_id,
            provider_id="model-context-protocol",
            mcp_endpoint={"uri": endpoint_uri},
        )

    # Per-endpoint headers travel to the stack inside the provider-data header.
    provider_data_json = json.dumps(
        {
            "mcp_headers": {
                endpoint_uri: [
                    f"Authorization: Bearer {AUTH_TOKEN}",
                ],
            },
        }
    )
    headers_with_auth = {"X-LlamaStack-Provider-Data": provider_data_json}

    try:
        llama_stack_client.toolgroups.unregister(toolgroup_id=toolgroup_id, extra_headers=headers_with_auth)
    except Exception as e:
        # Best effort: the toolgroup may simply not exist yet.
        print(f"Error unregistering toolgroup: {e}")

    llama_stack_client.toolgroups.register(
        toolgroup_id=toolgroup_id,
        provider_id="model-context-protocol",
        mcp_endpoint={"uri": endpoint_uri},
        extra_headers=headers_with_auth,
    )

    # --- tool listing ---
    tools = llama_stack_client.tools.list(
        toolgroup_id=toolgroup_id,
        extra_headers=headers_with_auth,
    )
    assert len(tools) == 1
    tool = tools[0]
    assert tool.identifier == "greet_everyone"
    assert tool.type == "tool"
    assert len(tool.parameters) == 1
    param = tool.parameters[0]
    assert param.name == "url"
    assert param.parameter_type == "string"
    assert param.required

    # --- direct invocation ---
    invocation = llama_stack_client.tool_runtime.invoke_tool(
        tool_name=tool.identifier,
        kwargs={"url": "https://www.google.com"},
        extra_headers=headers_with_auth,
    )
    invocation_content = invocation.content
    assert len(invocation_content) == 1
    assert invocation_content[0].type == "text"
    assert invocation_content[0].text == "Hello, world!"

    # --- agent turn, non-streaming, with credentials ---
    available_models = llama_stack_client.models.list()
    model_id = available_models[0].identifier
    print(f"Using model: {model_id}")

    agent = Agent(
        client=llama_stack_client,
        model=model_id,
        instructions="You are a helpful assistant.",
        tools=[toolgroup_id],
    )
    session_id = agent.create_session("test-session")

    def user_prompt():
        # Fresh list per turn, mirroring two separate literals.
        return [
            {
                "role": "user",
                "content": "Yo. Use tools.",
            }
        ]

    turn = agent.create_turn(
        session_id=session_id,
        messages=user_prompt(),
        stream=False,
        extra_headers=headers_with_auth,
    )
    steps = turn.steps

    inference_step = steps[0]
    assert inference_step.step_type == "inference"
    assert len(inference_step.api_model_response.tool_calls) == 1
    requested_call = inference_step.api_model_response.tool_calls[0]
    assert requested_call.tool_name == "greet_everyone"

    execution_step = steps[1]
    assert execution_step.step_type == "tool_execution"
    executed_content = execution_step.tool_responses[0].content
    assert len(executed_content) == 1
    assert executed_content[0].type == "text"
    assert executed_content[0].text == "Hello, world!"

    final_step = steps[2]
    assert final_step.step_type == "inference"
    assert len(final_step.api_model_response.tool_calls) == 0

    # --- agent turn, streaming, WITHOUT credentials ---
    # Streaming requests are not rejected upfront for missing auth headers; the
    # authorization failure is expected to show up while the stream is consumed.
    stream = agent.create_turn(
        session_id=session_id,
        messages=user_prompt(),
        stream=True,
    )
    if isinstance(llama_stack_client, LlamaStackAsLibraryClient):
        with pytest.raises(AuthenticationRequiredError):
            for _ in stream:
                pass
    else:
        error_chunks = []
        for chunk in stream:
            if "error" in chunk.model_dump():
                error_chunks.append(chunk)
        assert len(error_chunks) == 1
        error_payload = error_chunks[0].model_dump()
        assert "Unauthorized" in error_payload["error"]["message"]