mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
The most interesting MCP servers are those with an authorization wall in front of them. This PR reuses the existing `provider_data` mechanism (used to pass provider API keys) to pass MCP access tokens — in fact, arbitrary headers, in the style of the OpenAI Responses API — from the client through to the MCP server.

```python
class MCPProviderDataValidator(BaseModel):
    # mcp_endpoint => list of headers to send
    mcp_headers: dict[str, list[str]] | None = None
```

Note that we must put the headers for all MCP endpoints into a single `MCPProviderDataValidator`. Existing providers (e.g., Together and Fireworks for inference) can give their API keys clear names (`together_api_key`, `fireworks_api_key`). We cannot name these keys per MCP server, because a single generic MCP provider serves multiple "toolgroups". So we use a dict that combines the headers for every MCP endpoint you may want to use in an agentic call.

## Test Plan

See the added integration test for usage.
221 lines
6.9 KiB
Python
221 lines
6.9 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import json
|
|
import socket
|
|
import threading
|
|
import time
|
|
|
|
import httpx
|
|
import mcp.types as types
|
|
import pytest
|
|
import uvicorn
|
|
from llama_stack_client import Agent
|
|
from mcp.server.fastmcp import Context, FastMCP
|
|
from mcp.server.sse import SseServerTransport
|
|
from starlette.applications import Starlette
|
|
from starlette.exceptions import HTTPException
|
|
from starlette.responses import Response
|
|
from starlette.routing import Mount, Route
|
|
|
|
from llama_stack import LlamaStackAsLibraryClient
|
|
from llama_stack.distribution.datatypes import AuthenticationRequiredError
|
|
|
|
# Bearer token accepted by the test MCP server's /sse endpoint; a missing or
# different token makes the handler raise HTTP 401 "Unauthorized".
AUTH_TOKEN = "test-token"
|
|
|
|
|
|
@pytest.fixture(scope="module")
def mcp_server():
    """Run a bearer-token-protected MCP server over SSE for the whole module.

    The server exposes a single tool, ``greet_everyone``. Any SSE connection
    whose ``Authorization`` header is not ``Bearer <AUTH_TOKEN>`` is rejected
    with HTTP 401.

    Yields:
        int: the local TCP port the server listens on.

    Raises:
        RuntimeError: if the server does not answer an unauthenticated probe
            with 401 before the startup timeout expires.
    """
    server = FastMCP("FastMCP Test Server")

    @server.tool()
    async def greet_everyone(
        url: str, ctx: Context
    ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
        # ``url`` is intentionally unused. The test checks that it is
        # advertised as a required string parameter.
        return [types.TextContent(type="text", text="Hello, world!")]

    sse = SseServerTransport("/messages/")

    async def handle_sse(request):
        # Accept only "Authorization: Bearer <AUTH_TOKEN>"; anything else is a 401.
        auth_header = request.headers.get("Authorization")
        auth_token = None
        if auth_header and auth_header.startswith("Bearer "):
            auth_token = auth_header.split(" ")[1]

        if auth_token != AUTH_TOKEN:
            raise HTTPException(status_code=401, detail="Unauthorized")

        # connect_sse takes the raw ASGI (scope, receive, send) triple;
        # ``request._send`` is a private Starlette attribute.
        async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
            await server._mcp_server.run(
                streams[0],
                streams[1],
                server._mcp_server.create_initialization_options(),
            )
        return Response()

    app = Starlette(
        routes=[
            Route("/sse", endpoint=handle_sse),
            Mount("/messages/", app=sse.handle_post_message),
        ],
    )

    def get_open_port():
        """Ask the OS for a currently free TCP port (the socket is closed before returning)."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("", 0))
            return sock.getsockname()[1]

    port = get_open_port()

    config = uvicorn.Config(app, host="0.0.0.0", port=port)
    server_instance = uvicorn.Server(config)
    app.state.uvicorn_server = server_instance

    def run_server():
        server_instance.run()

    # Run uvicorn in a daemon thread so the blocking ``run()`` does not stall the test session.
    server_thread = threading.Thread(target=run_server, daemon=True)
    server_thread.start()

    # Poll until the server answers. An unauthenticated GET to /sse returning
    # 401 shows both that uvicorn accepts connections and that the auth check
    # is active. Use a monotonic clock so wall-clock jumps cannot distort the deadline.
    timeout = 10
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            response = httpx.get(f"http://localhost:{port}/sse", timeout=1.0)
            if response.status_code == 401:
                break
        except httpx.RequestError:
            pass
        time.sleep(0.1)
    else:
        # Previously this fell through silently and yielded a dead port. Fail
        # fast instead, and stop the thread first: pytest does not run the
        # post-yield teardown when setup raises before the yield.
        server_instance.should_exit = True
        server_thread.join(timeout=5)
        raise RuntimeError(f"MCP test server did not become ready on port {port} within {timeout}s")

    yield port

    # Tell server to exit
    server_instance.should_exit = True
    server_thread.join(timeout=5)
|
|
|
|
|
|
def test_mcp_invocation(llama_stack_client, mcp_server):
    """Check that MCP auth headers flow from the client to the MCP server via provider data.

    Covers:
    - toolgroup registration, without and then with headers;
    - tool listing;
    - direct tool invocation;
    - an agent turn, non-streaming with headers and streaming without them.
    """
    port = mcp_server
    test_toolgroup_id = "remote::mcptest"
    mcp_uri = f"http://localhost:{port}/sse"

    # registering itself should fail since it requires listing tools
    with pytest.raises(Exception, match="Unauthorized"):
        llama_stack_client.toolgroups.register(
            toolgroup_id=test_toolgroup_id,
            provider_id="model-context-protocol",
            mcp_endpoint=dict(uri=mcp_uri),
        )

    # Headers are keyed by MCP endpoint URI. Each value is a list of
    # "Name: value" header strings forwarded to that endpoint.
    provider_data = {
        "mcp_headers": {
            mcp_uri: [
                f"Authorization: Bearer {AUTH_TOKEN}",
            ],
        },
    }
    auth_headers = {
        "X-LlamaStack-Provider-Data": json.dumps(provider_data),
    }

    try:
        llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id, extra_headers=auth_headers)
    except Exception as e:
        # An error is OK since the toolgroup may not exist
        print(f"Error unregistering toolgroup: {e}")

    llama_stack_client.toolgroups.register(
        toolgroup_id=test_toolgroup_id,
        provider_id="model-context-protocol",
        mcp_endpoint=dict(uri=mcp_uri),
        extra_headers=auth_headers,
    )

    tools = llama_stack_client.tools.list(
        toolgroup_id=test_toolgroup_id,
        extra_headers=auth_headers,
    )
    assert len(tools) == 1
    assert tools[0].identifier == "greet_everyone"
    assert tools[0].type == "tool"
    assert len(tools[0].parameters) == 1
    p = tools[0].parameters[0]
    assert p.name == "url"
    assert p.parameter_type == "string"
    assert p.required

    tool_result = llama_stack_client.tool_runtime.invoke_tool(
        tool_name=tools[0].identifier,
        kwargs=dict(url="https://www.google.com"),
        extra_headers=auth_headers,
    )
    content = tool_result.content
    assert len(content) == 1
    assert content[0].type == "text"
    assert content[0].text == "Hello, world!"

    models = llama_stack_client.models.list()
    # NOTE(review): this assumes the first listed model is an LLM. If the
    # stack also registers embedding models, confirm which one comes first.
    model_id = models[0].identifier
    print(f"Using model: {model_id}")
    agent = Agent(
        client=llama_stack_client,
        model=model_id,
        instructions="You are a helpful assistant.",
        tools=[test_toolgroup_id],
    )
    session_id = agent.create_session("test-session")
    turn = agent.create_turn(
        session_id=session_id,
        messages=[
            {
                "role": "user",
                "content": "Yo. Use tools.",
            }
        ],
        stream=False,
        extra_headers=auth_headers,
    )
    steps = turn.steps
    # Expected shape: inference (tool call) -> tool_execution -> inference (final answer).
    # Check the length first so a short turn fails with a readable message, not an IndexError.
    assert len(steps) >= 3, f"expected at least 3 steps, got {[s.step_type for s in steps]}"

    first = steps[0]
    assert first.step_type == "inference"
    assert len(first.api_model_response.tool_calls) == 1
    tool_call = first.api_model_response.tool_calls[0]
    assert tool_call.tool_name == "greet_everyone"

    second = steps[1]
    assert second.step_type == "tool_execution"
    tool_response_content = second.tool_responses[0].content
    assert len(tool_response_content) == 1
    assert tool_response_content[0].type == "text"
    assert tool_response_content[0].text == "Hello, world!"

    third = steps[2]
    assert third.step_type == "inference"
    assert len(third.api_model_response.tool_calls) == 0

    # when streaming, we currently don't check auth headers upfront and fail the request
    # early. but we should at least be generating a 401 later in the process.
    # (No extra_headers here on purpose: this turn is expected to fail auth.)
    stream = agent.create_turn(
        session_id=session_id,
        messages=[
            {
                "role": "user",
                "content": "Yo. Use tools.",
            }
        ],
        stream=True,
    )
    if isinstance(llama_stack_client, LlamaStackAsLibraryClient):
        # The in-process client surfaces the failure as an exception while iterating.
        with pytest.raises(AuthenticationRequiredError):
            for _ in stream:
                pass
    else:
        # Over HTTP, the failure arrives as an error chunk in the stream.
        error_chunks = [chunk for chunk in stream if "error" in chunk.model_dump()]
        assert len(error_chunks) == 1
        chunk = error_chunks[0].model_dump()
        assert "Unauthorized" in chunk["error"]["message"]
|