fix: MCP authorization parameter implementation (#4052)

# What does this PR do?
Adding a user-facing `authorization ` parameter to MCP tool definitions
that allows users to explicitly configure credentials per MCP server,
addressing GitHub Issue #4034 in a secure manner.


## Test Plan
tests/integration/responses/test_mcp_authentication.py

---------

Co-authored-by: Omar Abdelwahab <omara@fb.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
This commit is contained in:
Omar Abdelwahab 2025-11-14 08:54:42 -08:00 committed by GitHub
parent dc49ad3f89
commit eb545034ab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
34 changed files with 5205 additions and 62 deletions

View file

@ -30,6 +30,40 @@ from llama_stack_api import (
logger = get_logger(__name__, category="tools")
def prepare_mcp_headers(base_headers: dict[str, str] | None, authorization: str | None) -> dict[str, str]:
"""
Prepare headers for MCP requests with authorization support.
Args:
base_headers: Base headers dictionary (can be None)
authorization: OAuth access token (without "Bearer " prefix)
Returns:
Headers dictionary with Authorization header if token provided
Raises:
ValueError: If Authorization header is specified in the headers dict (security risk)
"""
headers = dict(base_headers or {})
# Security check: reject any Authorization header in the headers dict
# Users must use the authorization parameter instead to avoid security risks
existing_keys_lower = {k.lower() for k in headers.keys()}
if "authorization" in existing_keys_lower:
raise ValueError(
"For security reasons, Authorization header cannot be passed via 'headers'. "
"Please use the 'authorization' parameter instead."
)
# Add Authorization header if token provided
if authorization:
# OAuth access token - add "Bearer " prefix
headers["Authorization"] = f"Bearer {authorization}"
return headers
protocol_cache = TTLDict(ttl_seconds=3600)
@ -112,9 +146,29 @@ async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerat
raise
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
async def list_mcp_tools(
endpoint: str,
headers: dict[str, str] | None = None,
authorization: str | None = None,
) -> ListToolDefsResponse:
"""List tools available from an MCP server.
Args:
endpoint: MCP server endpoint URL
headers: Optional base headers to include
authorization: Optional OAuth access token (just the token, not "Bearer <token>")
Returns:
List of tool definitions from the MCP server
Raises:
ValueError: If Authorization is found in the headers parameter
"""
# Prepare headers with authorization handling
final_headers = prepare_mcp_headers(headers, authorization)
tools = []
async with client_wrapper(endpoint, headers) as session:
async with client_wrapper(endpoint, final_headers) as session:
tools_result = await session.list_tools()
for tool in tools_result.tools:
tools.append(
@ -132,9 +186,31 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs
async def invoke_mcp_tool(
endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any]
endpoint: str,
tool_name: str,
kwargs: dict[str, Any],
headers: dict[str, str] | None = None,
authorization: str | None = None,
) -> ToolInvocationResult:
async with client_wrapper(endpoint, headers) as session:
"""Invoke an MCP tool with the given arguments.
Args:
endpoint: MCP server endpoint URL
tool_name: Name of the tool to invoke
kwargs: Tool invocation arguments
headers: Optional base headers to include
authorization: Optional OAuth access token (just the token, not "Bearer <token>")
Returns:
Tool invocation result with content and error information
Raises:
ValueError: If Authorization header is found in the headers parameter
"""
# Prepare headers with authorization handling
final_headers = prepare_mcp_headers(headers, authorization)
async with client_wrapper(endpoint, final_headers) as session:
result = await session.call_tool(tool_name, kwargs)
content: list[InterleavedContentItem] = []