Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
feat: accept MCP authorization headers for MCP toolgroups (#2230)
The most interesting MCP servers are those that sit behind an authorization wall. This PR uses the existing `provider_data` mechanism, which already passes provider API keys, to pass MCP access tokens (in fact, arbitrary headers in the style of the OpenAI Responses API) from the client through to the MCP server.

```python
class MCPProviderDataValidator(BaseModel):
    # mcp_endpoint => list of headers to send
    mcp_headers: dict[str, list[str]] | None = None
```

Note how we must stuff the headers for all MCP endpoints into a single `MCPProviderDataValidator`. Unlike existing providers (e.g., Together and Fireworks for inference), where we could name the provider API keys clearly (`together_api_key`, `fireworks_api_key`), we cannot name these keys for MCP: we have a single generic MCP provider which can serve multiple "toolgroups". So we use a dict to combine all the headers for all MCP endpoints you may want to use in an agentic call.

## Test Plan

See the added integration test for usage.
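For example, a client can scope headers to a particular MCP endpoint like this (a minimal sketch based on the integration test at the bottom of this diff; the endpoint URL, token, toolgroup id, and `client` handle are all placeholders):

```python
import json

# Placeholder values for illustration only.
mcp_endpoint = "http://localhost:8000/sse"
my_token = "test-token"

provider_data = {
    "mcp_headers": {
        # mcp_endpoint => list of "Header: value" strings sent to that server
        mcp_endpoint: [f"Authorization: Bearer {my_token}"],
    },
}
auth_headers = {"X-LlamaStack-Provider-Data": json.dumps(provider_data)}

# Any call that touches the MCP toolgroup can then carry the headers, e.g.:
# client.tools.list(toolgroup_id="remote::my-mcp-tools", extra_headers=auth_headers)
```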
Parent 2708312168, commit 51945f1e57.

9 changed files with 331 additions and 47 deletions.
```diff
@@ -76,8 +76,8 @@ class ToolInvocationResult(BaseModel):
 
 class ToolStore(Protocol):
-    def get_tool(self, tool_name: str) -> Tool: ...
-    def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
+    async def get_tool(self, tool_name: str) -> Tool: ...
+    async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
 
 
 class ListToolGroupsResponse(BaseModel):
```
```diff
@@ -236,6 +236,10 @@ class AuthenticationConfig(BaseModel):
     )
 
 
+class AuthenticationRequiredError(Exception):
+    pass
+
+
 class QuotaPeriod(str, Enum):
     DAY = "day"
```
```diff
@@ -261,9 +261,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             raise ValueError("Client not initialized")
 
         # Create headers with provider data if available
-        headers = {}
+        headers = options.headers or {}
         if self.provider_data:
-            headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)
+            keys = ["X-LlamaStack-Provider-Data", "x-llamastack-provider-data"]
+            if all(key not in headers for key in keys):
+                headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)
 
         # Use context manager for provider data
         with request_provider_data_context(headers):
```
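The guard above checks both casings of the provider-data header so that a header supplied explicitly on the request is never clobbered by the client-level `provider_data`. A standalone sketch of the behavior (the helper is hypothetical, for illustration):

```python
import json


def merge_provider_data(headers: dict[str, str], provider_data: dict | None) -> dict[str, str]:
    # Hypothetical standalone version of the guard in the library client.
    headers = dict(headers)  # do not mutate the caller's dict
    if provider_data:
        keys = ["X-LlamaStack-Provider-Data", "x-llamastack-provider-data"]
        if all(key not in headers for key in keys):
            headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data)
    return headers


# A per-request header wins over client-level provider_data:
merged = merge_provider_data({"x-llamastack-provider-data": "{}"}, {"mcp_headers": {}})
assert merged["x-llamastack-provider-data"] == "{}"
```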
```diff
@@ -28,7 +28,7 @@ from fastapi.responses import JSONResponse, StreamingResponse
 from openai import BadRequestError
 from pydantic import BaseModel, ValidationError
 
-from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig
+from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.request_headers import (
     PROVIDER_DATA_VAR,
```
```diff
@@ -122,6 +122,8 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationError:
         return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
     elif isinstance(exc, NotImplementedError):
         return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")
+    elif isinstance(exc, AuthenticationRequiredError):
+        return HTTPException(status_code=401, detail=f"Authentication required: {str(exc)}")
     else:
         return HTTPException(
             status_code=500,
```
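With `AuthenticationRequiredError` defined in the distribution datatypes and mapped here, a 401 from an upstream MCP server now reaches the stack's own clients as a 401 rather than a generic 500.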
```diff
@@ -80,8 +80,9 @@ def available_providers() -> list[ProviderSpec]:
         adapter=AdapterSpec(
             adapter_type="model-context-protocol",
             module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
-            config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig",
+            config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig",
             pip_packages=["mcp"],
+            provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator",
         ),
     ),
 ]
```
```diff
@@ -4,18 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from pydantic import BaseModel
-
-from .config import ModelContextProtocolConfig
-
-
-class ModelContextProtocolToolProviderDataValidator(BaseModel):
-    api_key: str
+from .config import MCPProviderConfig
 
 
-async def get_adapter_impl(config: ModelContextProtocolConfig, _deps):
+async def get_adapter_impl(config: MCPProviderConfig, _deps):
     from .model_context_protocol import ModelContextProtocolToolRuntimeImpl
 
-    impl = ModelContextProtocolToolRuntimeImpl(config)
+    impl = ModelContextProtocolToolRuntimeImpl(config, _deps)
     await impl.initialize()
     return impl
```
```diff
@@ -9,7 +9,12 @@ from typing import Any
 from pydantic import BaseModel
 
 
-class ModelContextProtocolConfig(BaseModel):
+class MCPProviderDataValidator(BaseModel):
+    # mcp_endpoint => list of headers to send
+    mcp_headers: dict[str, list[str]] | None = None
+
+
+class MCPProviderConfig(BaseModel):
     @classmethod
     def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
         return {}
```
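Because the registry entry above points `provider_data_validator` at `MCPProviderDataValidator`, the JSON carried in the `X-LlamaStack-Provider-Data` request header is validated into this model before the provider reads it. Roughly (a hedged sketch that ignores the stack's actual plumbing):

```python
import json

from pydantic import BaseModel


class MCPProviderDataValidator(BaseModel):
    # mcp_endpoint => list of headers to send
    mcp_headers: dict[str, list[str]] | None = None


# Hypothetical header payload, matching the shape the client sends.
raw = json.dumps({"mcp_headers": {"http://localhost:8000/sse": ["Authorization: Bearer t"]}})
validated = MCPProviderDataValidator.model_validate_json(raw)
assert validated.mcp_headers is not None
```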
```diff
@@ -4,13 +4,18 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from contextlib import asynccontextmanager
 from typing import Any
 from urllib.parse import urlparse
 
+import exceptiongroup
+import httpx
 from mcp import ClientSession
+from mcp import types as mcp_types
 from mcp.client.sse import sse_client
 
-from llama_stack.apis.common.content_types import URL
+from llama_stack.apis.common.content_types import URL, ImageContentItem, TextContentItem
+from llama_stack.apis.datatypes import Api
 from llama_stack.apis.tools import (
     ListToolDefsResponse,
     ToolDef,
```
```diff
@@ -18,13 +23,36 @@ from llama_stack.apis.tools import (
     ToolParameter,
     ToolRuntime,
 )
+from llama_stack.distribution.datatypes import AuthenticationRequiredError
+from llama_stack.distribution.request_headers import NeedsRequestProviderData
+from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import ToolsProtocolPrivate
 
-from .config import ModelContextProtocolConfig
+from .config import MCPProviderConfig
 
+logger = get_logger(__name__, category="tools")
+
 
-class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
-    def __init__(self, config: ModelContextProtocolConfig):
+@asynccontextmanager
+async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
+    try:
+        async with sse_client(endpoint, headers=headers) as streams:
+            async with ClientSession(*streams) as session:
+                await session.initialize()
+                yield session
+    except BaseException as e:
+        if isinstance(e, exceptiongroup.BaseExceptionGroup):
+            for exc in e.exceptions:
+                if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401:
+                    raise AuthenticationRequiredError(exc) from exc
+        elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401:
+            raise AuthenticationRequiredError(e) from e
+
+        raise
+
+
+class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
+    def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
         self.config = config
 
     async def initialize(self):
```
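A note on the `BaseExceptionGroup` handling in `sse_client_wrapper`: the MCP SSE client runs its reader and writer inside a task group, so an `httpx.HTTPStatusError` raised while connecting can surface wrapped in an exception group (via the `exceptiongroup` backport on Python < 3.11) rather than bare. A standalone sketch of the unwrapping check, assuming the backport is installed:

```python
import exceptiongroup
import httpx


def is_unauthorized(e: BaseException) -> bool:
    # Mirror of the wrapper's check: a 401 may arrive bare, or wrapped in a
    # group raised by the SSE client's internal task group.
    if isinstance(e, exceptiongroup.BaseExceptionGroup):
        return any(
            isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401
            for exc in e.exceptions
        )
    return isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401
```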
```diff
@@ -33,34 +61,34 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
     async def list_runtime_tools(
         self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
     ) -> ListToolDefsResponse:
         # this endpoint should be retrieved by getting the tool group right?
         if mcp_endpoint is None:
             raise ValueError("mcp_endpoint is required")
 
+        headers = await self.get_headers_from_request(mcp_endpoint.uri)
         tools = []
-        async with sse_client(mcp_endpoint.uri) as streams:
-            async with ClientSession(*streams) as session:
-                await session.initialize()
-                tools_result = await session.list_tools()
-                for tool in tools_result.tools:
-                    parameters = []
-                    for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
-                        parameters.append(
-                            ToolParameter(
-                                name=param_name,
-                                parameter_type=param_schema.get("type", "string"),
-                                description=param_schema.get("description", ""),
-                            )
-                        )
-                    tools.append(
-                        ToolDef(
-                            name=tool.name,
-                            description=tool.description,
-                            parameters=parameters,
-                            metadata={
-                                "endpoint": mcp_endpoint.uri,
-                            },
-                        )
-                    )
+        async with sse_client_wrapper(mcp_endpoint.uri, headers) as session:
+            tools_result = await session.list_tools()
+            for tool in tools_result.tools:
+                parameters = []
+                for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
+                    parameters.append(
+                        ToolParameter(
+                            name=param_name,
+                            parameter_type=param_schema.get("type", "string"),
+                            description=param_schema.get("description", ""),
+                        )
+                    )
+                tools.append(
+                    ToolDef(
+                        name=tool.name,
+                        description=tool.description,
+                        parameters=parameters,
+                        metadata={
+                            "endpoint": mcp_endpoint.uri,
+                        },
+                    )
+                )
         return ListToolDefsResponse(data=tools)
 
     async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
```
```diff
@@ -71,12 +99,39 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         if urlparse(endpoint).scheme not in ("http", "https"):
             raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
 
-        async with sse_client(endpoint) as streams:
-            async with ClientSession(*streams) as session:
-                await session.initialize()
-                result = await session.call_tool(tool.identifier, kwargs)
+        headers = await self.get_headers_from_request(endpoint)
+        async with sse_client_wrapper(endpoint, headers) as session:
+            result = await session.call_tool(tool.identifier, kwargs)
 
+        content = []
+        for item in result.content:
+            if isinstance(item, mcp_types.TextContent):
+                content.append(TextContentItem(text=item.text))
+            elif isinstance(item, mcp_types.ImageContent):
+                content.append(ImageContentItem(image=item.data))
+            elif isinstance(item, mcp_types.EmbeddedResource):
+                logger.warning(f"EmbeddedResource is not supported: {item}")
+            else:
+                raise ValueError(f"Unknown content type: {type(item)}")
         return ToolInvocationResult(
-            content="\n".join([result.model_dump_json() for result in result.content]),
+            content=content,
             error_code=1 if result.isError else 0,
         )
+
+    async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
+        def canonicalize_uri(uri: str) -> str:
+            return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}"
+
+        headers = {}
+
+        provider_data = self.get_request_provider_data()
+        if provider_data and provider_data.mcp_headers:
+            for uri, values in provider_data.mcp_headers.items():
+                if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
+                    continue
+                for entry in values:
+                    parts = entry.split(":")
+                    if len(parts) == 2:
+                        k, v = parts
+                        headers[k.strip()] = v.strip()
+        return headers
```
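Two details of `get_headers_from_request` worth calling out: endpoint URIs are compared on netloc and path only (scheme and query string are ignored), and each configured entry must be a single `"Name: value"` string that splits on `:` into exactly two parts, so an entry whose value itself contains a colon is silently skipped. A quick illustration with placeholder values:

```python
from urllib.parse import urlparse


def canonicalize_uri(uri: str) -> str:
    # Same shape as the provider's helper: netloc + path, ignoring scheme/query.
    return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}"


# Scheme and query differences do not prevent a match:
assert canonicalize_uri("http://localhost:8000/sse") == canonicalize_uri("https://localhost:8000/sse?x=1")

# "Authorization: Bearer test-token" splits into exactly two parts:
parts = "Authorization: Bearer test-token".split(":")
assert len(parts) == 2 and parts[1].strip() == "Bearer test-token"
```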
tests/integration/tool_runtime/test_mcp.py (new file, 221 lines):
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import socket
import threading
import time

import httpx
import mcp.types as types
import pytest
import uvicorn
from llama_stack_client import Agent
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.exceptions import HTTPException
from starlette.responses import Response
from starlette.routing import Mount, Route

from llama_stack import LlamaStackAsLibraryClient
from llama_stack.distribution.datatypes import AuthenticationRequiredError

AUTH_TOKEN = "test-token"


@pytest.fixture(scope="module")
def mcp_server():
    server = FastMCP("FastMCP Test Server")

    @server.tool()
    async def greet_everyone(
        url: str, ctx: Context
    ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
        return [types.TextContent(type="text", text="Hello, world!")]

    sse = SseServerTransport("/messages/")

    async def handle_sse(request):
        auth_header = request.headers.get("Authorization")
        auth_token = None
        if auth_header and auth_header.startswith("Bearer "):
            auth_token = auth_header.split(" ")[1]

        if auth_token != AUTH_TOKEN:
            raise HTTPException(status_code=401, detail="Unauthorized")

        async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
            await server._mcp_server.run(
                streams[0],
                streams[1],
                server._mcp_server.create_initialization_options(),
            )
        return Response()

    app = Starlette(
        routes=[
            Route("/sse", endpoint=handle_sse),
            Mount("/messages/", app=sse.handle_post_message),
        ],
    )

    def get_open_port():
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("", 0))
            return sock.getsockname()[1]

    port = get_open_port()

    config = uvicorn.Config(app, host="0.0.0.0", port=port)
    server_instance = uvicorn.Server(config)
    app.state.uvicorn_server = server_instance

    def run_server():
        server_instance.run()

    # Start the server in a new thread
    server_thread = threading.Thread(target=run_server, daemon=True)
    server_thread.start()

    # Polling until the server is ready
    timeout = 10
    start_time = time.time()

    while time.time() - start_time < timeout:
        try:
            response = httpx.get(f"http://localhost:{port}/sse")
            if response.status_code == 401:
                break
        except httpx.RequestError:
            pass
        time.sleep(0.1)

    yield port

    # Tell server to exit
    server_instance.should_exit = True
    server_thread.join(timeout=5)


def test_mcp_invocation(llama_stack_client, mcp_server):
    port = mcp_server
    test_toolgroup_id = "remote::mcptest"

    # registering itself should fail since it requires listing tools
    with pytest.raises(Exception, match="Unauthorized"):
        llama_stack_client.toolgroups.register(
            toolgroup_id=test_toolgroup_id,
            provider_id="model-context-protocol",
            mcp_endpoint=dict(uri=f"http://localhost:{port}/sse"),
        )

    provider_data = {
        "mcp_headers": {
            f"http://localhost:{port}/sse": [
                f"Authorization: Bearer {AUTH_TOKEN}",
            ],
        },
    }
    auth_headers = {
        "X-LlamaStack-Provider-Data": json.dumps(provider_data),
    }

    try:
        llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id, extra_headers=auth_headers)
    except Exception as e:
        # An error is OK since the toolgroup may not exist
        print(f"Error unregistering toolgroup: {e}")

    llama_stack_client.toolgroups.register(
        toolgroup_id=test_toolgroup_id,
        provider_id="model-context-protocol",
        mcp_endpoint=dict(uri=f"http://localhost:{port}/sse"),
        extra_headers=auth_headers,
    )
    response = llama_stack_client.tools.list(
        toolgroup_id=test_toolgroup_id,
        extra_headers=auth_headers,
    )
    assert len(response) == 1
    assert response[0].identifier == "greet_everyone"
    assert response[0].type == "tool"
    assert len(response[0].parameters) == 1
    p = response[0].parameters[0]
    assert p.name == "url"
    assert p.parameter_type == "string"
    assert p.required

    response = llama_stack_client.tool_runtime.invoke_tool(
        tool_name=response[0].identifier,
        kwargs=dict(url="https://www.google.com"),
        extra_headers=auth_headers,
    )
    content = response.content
    assert len(content) == 1
    assert content[0].type == "text"
    assert content[0].text == "Hello, world!"

    models = llama_stack_client.models.list()
    model_id = models[0].identifier
    print(f"Using model: {model_id}")
    agent = Agent(
        client=llama_stack_client,
        model=model_id,
        instructions="You are a helpful assistant.",
        tools=[test_toolgroup_id],
    )
    session_id = agent.create_session("test-session")
    response = agent.create_turn(
        session_id=session_id,
        messages=[
            {
                "role": "user",
                "content": "Yo. Use tools.",
            }
        ],
        stream=False,
        extra_headers=auth_headers,
    )
    steps = response.steps
    first = steps[0]
    assert first.step_type == "inference"
    assert len(first.api_model_response.tool_calls) == 1
    tool_call = first.api_model_response.tool_calls[0]
    assert tool_call.tool_name == "greet_everyone"

    second = steps[1]
    assert second.step_type == "tool_execution"
    tool_response_content = second.tool_responses[0].content
    assert len(tool_response_content) == 1
    assert tool_response_content[0].type == "text"
    assert tool_response_content[0].text == "Hello, world!"

    third = steps[2]
    assert third.step_type == "inference"
    assert len(third.api_model_response.tool_calls) == 0

    # when streaming, we currently don't check auth headers upfront and fail the request
    # early. but we should at least be generating a 401 later in the process.
    response = agent.create_turn(
        session_id=session_id,
        messages=[
            {
                "role": "user",
                "content": "Yo. Use tools.",
            }
        ],
        stream=True,
    )
    if isinstance(llama_stack_client, LlamaStackAsLibraryClient):
        with pytest.raises(AuthenticationRequiredError):
            for _ in response:
                pass
    else:
        error_chunks = [chunk for chunk in response if "error" in chunk.model_dump()]
        assert len(error_chunks) == 1
        chunk = error_chunks[0].model_dump()
        assert "Unauthorized" in chunk["error"]["message"]
```