mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
# What does this PR do? Adds a user-facing `authorization` parameter to MCP tool definitions that lets users explicitly configure credentials per MCP server, addressing GitHub Issue #4034 in a secure manner. ## Test Plan tests/integration/responses/test_mcp_authentication.py --------- Co-authored-by: Omar Abdelwahab <omara@fb.com> Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
229 lines
9.1 KiB
Python
229 lines
9.1 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from collections.abc import AsyncGenerator
|
|
from contextlib import asynccontextmanager
|
|
from enum import Enum
|
|
from typing import Any, cast
|
|
|
|
import httpx
|
|
from mcp import ClientSession, McpError
|
|
from mcp import types as mcp_types
|
|
from mcp.client.sse import sse_client
|
|
from mcp.client.streamable_http import streamablehttp_client
|
|
|
|
from llama_stack.core.datatypes import AuthenticationRequiredError
|
|
from llama_stack.log import get_logger
|
|
from llama_stack.providers.utils.tools.ttl_dict import TTLDict
|
|
from llama_stack_api import (
|
|
ImageContentItem,
|
|
InterleavedContentItem,
|
|
ListToolDefsResponse,
|
|
TextContentItem,
|
|
ToolDef,
|
|
ToolInvocationResult,
|
|
_URLOrData,
|
|
)
|
|
|
|
# Module-level logger shared by the MCP helpers below, created with category="tools".
logger = get_logger(__name__, category="tools")
|
|
|
|
|
|
def prepare_mcp_headers(base_headers: dict[str, str] | None, authorization: str | None) -> dict[str, str]:
|
|
"""
|
|
Prepare headers for MCP requests with authorization support.
|
|
|
|
Args:
|
|
base_headers: Base headers dictionary (can be None)
|
|
authorization: OAuth access token (without "Bearer " prefix)
|
|
|
|
Returns:
|
|
Headers dictionary with Authorization header if token provided
|
|
|
|
Raises:
|
|
ValueError: If Authorization header is specified in the headers dict (security risk)
|
|
"""
|
|
headers = dict(base_headers or {})
|
|
|
|
# Security check: reject any Authorization header in the headers dict
|
|
# Users must use the authorization parameter instead to avoid security risks
|
|
existing_keys_lower = {k.lower() for k in headers.keys()}
|
|
if "authorization" in existing_keys_lower:
|
|
raise ValueError(
|
|
"For security reasons, Authorization header cannot be passed via 'headers'. "
|
|
"Please use the 'authorization' parameter instead."
|
|
)
|
|
|
|
# Add Authorization header if token provided
|
|
if authorization:
|
|
# OAuth access token - add "Bearer " prefix
|
|
headers["Authorization"] = f"Bearer {authorization}"
|
|
|
|
return headers
|
|
|
|
|
|
# endpoint URL -> transport that last initialized a session successfully (written by
# client_wrapper). Entries are kept with a TTL of one hour (ttl_seconds=3600).
protocol_cache = TTLDict(ttl_seconds=60 * 60)
|
|
|
|
|
|
class MCPProtocol(Enum):
    """Transport used to talk to an MCP server endpoint.

    ``UNKNOWN`` is the default returned when no transport is cached for an endpoint.
    ``STREAMABLE_HTTP`` and ``SSE`` are the two transports that ``client_wrapper`` tries.
    """

    UNKNOWN = 0
    STREAMABLE_HTTP = 1
    SSE = 2


# Backward-compatible alias for the original (misspelled) name. Existing references to
# ``MCPProtol`` keep working and resolve to the very same enum class and members.
MCPProtol = MCPProtocol
|
|
|
|
|
|
@asynccontextmanager
|
|
async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
|
|
# we use a ttl'd dict to cache the happy path protocol for each endpoint
|
|
# but, we always fall back to trying the other protocol if we cannot initialize the session
|
|
connection_strategies = [MCPProtol.STREAMABLE_HTTP, MCPProtol.SSE]
|
|
mcp_protocol = protocol_cache.get(endpoint, default=MCPProtol.UNKNOWN)
|
|
if mcp_protocol == MCPProtol.SSE:
|
|
connection_strategies = [MCPProtol.SSE, MCPProtol.STREAMABLE_HTTP]
|
|
|
|
for i, strategy in enumerate(connection_strategies):
|
|
try:
|
|
client = streamablehttp_client
|
|
if strategy == MCPProtol.SSE:
|
|
# sse_client and streamablehttp_client have different signatures, but both
|
|
# are called the same way here, so we cast to Any to avoid type errors
|
|
client = cast(Any, sse_client)
|
|
async with client(endpoint, headers=headers) as client_streams:
|
|
async with ClientSession(read_stream=client_streams[0], write_stream=client_streams[1]) as session:
|
|
await session.initialize()
|
|
protocol_cache[endpoint] = strategy
|
|
yield session
|
|
return
|
|
except* httpx.HTTPStatusError as eg:
|
|
for exc in eg.exceptions:
|
|
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
|
|
# so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
|
|
# `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
|
|
err = cast(httpx.HTTPStatusError, exc)
|
|
if err.response.status_code == 401:
|
|
raise AuthenticationRequiredError(exc) from exc
|
|
if i == len(connection_strategies) - 1:
|
|
raise
|
|
except* httpx.ConnectError as eg:
|
|
# Connection refused, server down, network unreachable
|
|
if i == len(connection_strategies) - 1:
|
|
error_msg = f"Failed to connect to MCP server at {endpoint}: Connection refused"
|
|
logger.error(f"MCP connection error: {error_msg}")
|
|
raise ConnectionError(error_msg) from eg
|
|
else:
|
|
logger.warning(
|
|
f"failed to connect to MCP server at {endpoint} via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
|
|
)
|
|
except* httpx.TimeoutException as eg:
|
|
# Request timeout, server too slow
|
|
if i == len(connection_strategies) - 1:
|
|
error_msg = f"MCP server at {endpoint} timed out"
|
|
logger.error(f"MCP timeout error: {error_msg}")
|
|
raise TimeoutError(error_msg) from eg
|
|
else:
|
|
logger.warning(
|
|
f"MCP server at {endpoint} timed out via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
|
|
)
|
|
except* httpx.RequestError as eg:
|
|
# DNS resolution failures, network errors, invalid URLs
|
|
if i == len(connection_strategies) - 1:
|
|
# Get the first exception's message for the error string
|
|
exc_msg = str(eg.exceptions[0]) if eg.exceptions else "Unknown error"
|
|
error_msg = f"Network error connecting to MCP server at {endpoint}: {exc_msg}"
|
|
logger.error(f"MCP network error: {error_msg}")
|
|
raise ConnectionError(error_msg) from eg
|
|
else:
|
|
logger.warning(
|
|
f"network error connecting to MCP server at {endpoint} via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
|
|
)
|
|
except* McpError:
|
|
if i < len(connection_strategies) - 1:
|
|
logger.warning(
|
|
f"failed to connect via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
|
|
)
|
|
else:
|
|
raise
|
|
|
|
|
|
async def list_mcp_tools(
    endpoint: str,
    headers: dict[str, str] | None = None,
    authorization: str | None = None,
) -> ListToolDefsResponse:
    """Fetch the tools advertised by an MCP server.

    Args:
        endpoint: URL of the MCP server.
        headers: Extra request headers to send (must not include Authorization).
        authorization: OAuth access token on its own, without the "Bearer " prefix.

    Returns:
        A ``ListToolDefsResponse`` with one ``ToolDef`` per server tool. Each ToolDef
        records the originating endpoint under ``metadata["endpoint"]``.

    Raises:
        ValueError: If ``headers`` contains an Authorization entry.
        Connection, timeout and authentication failures surface as raised by
        ``client_wrapper``.
    """
    request_headers = prepare_mcp_headers(headers, authorization)

    async with client_wrapper(endpoint, request_headers) as session:
        listing = await session.list_tools()
        # Map each MCP tool onto a ToolDef while the session is still open.
        definitions = [
            ToolDef(
                name=mcp_tool.name,
                description=mcp_tool.description,
                input_schema=mcp_tool.inputSchema,
                output_schema=getattr(mcp_tool, "outputSchema", None),
                metadata={"endpoint": endpoint},
            )
            for mcp_tool in listing.tools
        ]

    return ListToolDefsResponse(data=definitions)
|
|
|
|
|
|
async def invoke_mcp_tool(
|
|
endpoint: str,
|
|
tool_name: str,
|
|
kwargs: dict[str, Any],
|
|
headers: dict[str, str] | None = None,
|
|
authorization: str | None = None,
|
|
) -> ToolInvocationResult:
|
|
"""Invoke an MCP tool with the given arguments.
|
|
|
|
Args:
|
|
endpoint: MCP server endpoint URL
|
|
tool_name: Name of the tool to invoke
|
|
kwargs: Tool invocation arguments
|
|
headers: Optional base headers to include
|
|
authorization: Optional OAuth access token (just the token, not "Bearer <token>")
|
|
|
|
Returns:
|
|
Tool invocation result with content and error information
|
|
|
|
Raises:
|
|
ValueError: If Authorization header is found in the headers parameter
|
|
"""
|
|
# Prepare headers with authorization handling
|
|
final_headers = prepare_mcp_headers(headers, authorization)
|
|
|
|
async with client_wrapper(endpoint, final_headers) as session:
|
|
result = await session.call_tool(tool_name, kwargs)
|
|
|
|
content: list[InterleavedContentItem] = []
|
|
for item in result.content:
|
|
if isinstance(item, mcp_types.TextContent):
|
|
content.append(TextContentItem(text=item.text))
|
|
elif isinstance(item, mcp_types.ImageContent):
|
|
content.append(ImageContentItem(image=_URLOrData(data=item.data)))
|
|
elif isinstance(item, mcp_types.EmbeddedResource):
|
|
logger.warning(f"EmbeddedResource is not supported: {item}")
|
|
else:
|
|
raise ValueError(f"Unknown content type: {type(item)}")
|
|
return ToolInvocationResult(
|
|
content=content,
|
|
error_code=1 if result.isError else 0,
|
|
)
|