diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index 2f62b0ba1..29649495c 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -76,8 +76,8 @@ class ToolInvocationResult(BaseModel): class ToolStore(Protocol): - def get_tool(self, tool_name: str) -> Tool: ... - def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ... + async def get_tool(self, tool_name: str) -> Tool: ... + async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ... class ListToolGroupsResponse(BaseModel): diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index aeb2b997a..def7048c0 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -236,6 +236,10 @@ class AuthenticationConfig(BaseModel): ) +class AuthenticationRequiredError(Exception): + pass + + class QuotaPeriod(str, Enum): DAY = "day" diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 34e041cbe..21b49a975 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -261,9 +261,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): raise ValueError("Client not initialized") # Create headers with provider data if available - headers = {} + headers = options.headers or {} if self.provider_data: - headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data) + keys = ["X-LlamaStack-Provider-Data", "x-llamastack-provider-data"] + if all(key not in headers for key in keys): + headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data) # Use context manager for provider data with request_provider_data_context(headers): diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 7069390cf..d70f06691 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -28,7 +28,7 
@@ from fastapi.responses import JSONResponse, StreamingResponse from openai import BadRequestError from pydantic import BaseModel, ValidationError -from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig +from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.request_headers import ( PROVIDER_DATA_VAR, @@ -122,6 +122,8 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}") elif isinstance(exc, NotImplementedError): return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}") + elif isinstance(exc, AuthenticationRequiredError): + return HTTPException(status_code=401, detail=f"Authentication required: {str(exc)}") else: return HTTPException( status_code=500, diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index b9194810e..277914df2 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -80,8 +80,9 @@ def available_providers() -> list[ProviderSpec]: adapter=AdapterSpec( adapter_type="model-context-protocol", module="llama_stack.providers.remote.tool_runtime.model_context_protocol", - config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig", + config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig", pip_packages=["mcp"], + provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator", ), ), ] diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py index 
fb1f558e5..051a880a7 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py @@ -4,18 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from pydantic import BaseModel - -from .config import ModelContextProtocolConfig +from .config import MCPProviderConfig -class ModelContextProtocolToolProviderDataValidator(BaseModel): - api_key: str - - -async def get_adapter_impl(config: ModelContextProtocolConfig, _deps): +async def get_adapter_impl(config: MCPProviderConfig, _deps): from .model_context_protocol import ModelContextProtocolToolRuntimeImpl - impl = ModelContextProtocolToolRuntimeImpl(config) + impl = ModelContextProtocolToolRuntimeImpl(config, _deps) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py index d509074fc..d400159b2 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py @@ -9,7 +9,12 @@ from typing import Any from pydantic import BaseModel -class ModelContextProtocolConfig(BaseModel): +class MCPProviderDataValidator(BaseModel): + # mcp_endpoint => list of headers to send + mcp_headers: dict[str, list[str]] | None = None + + +class MCPProviderConfig(BaseModel): @classmethod def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: return {} diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index 142730e89..340e90ca1 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ 
b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -4,13 +4,18 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from contextlib import asynccontextmanager from typing import Any from urllib.parse import urlparse +import exceptiongroup +import httpx from mcp import ClientSession +from mcp import types as mcp_types from mcp.client.sse import sse_client -from llama_stack.apis.common.content_types import URL +from llama_stack.apis.common.content_types import URL, ImageContentItem, TextContentItem +from llama_stack.apis.datatypes import Api from llama_stack.apis.tools import ( ListToolDefsResponse, ToolDef, @@ -18,13 +23,36 @@ from llama_stack.apis.tools import ( ToolParameter, ToolRuntime, ) +from llama_stack.distribution.datatypes import AuthenticationRequiredError +from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ToolsProtocolPrivate -from .config import ModelContextProtocolConfig +from .config import MCPProviderConfig + +logger = get_logger(__name__, category="tools") -class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): - def __init__(self, config: ModelContextProtocolConfig): +@asynccontextmanager +async def sse_client_wrapper(endpoint: str, headers: dict[str, str]): + try: + async with sse_client(endpoint, headers=headers) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + yield session + except BaseException as e: + if isinstance(e, exceptiongroup.BaseExceptionGroup): + for exc in e.exceptions: + if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401: + raise AuthenticationRequiredError(exc) from exc + elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401: + raise AuthenticationRequiredError(e) from e + + raise + + +class 
ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): + def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]): self.config = config async def initialize(self): @@ -33,34 +61,34 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): async def list_runtime_tools( self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None ) -> ListToolDefsResponse: + # this endpoint should be retrieved by getting the tool group right? if mcp_endpoint is None: raise ValueError("mcp_endpoint is required") + headers = await self.get_headers_from_request(mcp_endpoint.uri) tools = [] - async with sse_client(mcp_endpoint.uri) as streams: - async with ClientSession(*streams) as session: - await session.initialize() - tools_result = await session.list_tools() - for tool in tools_result.tools: - parameters = [] - for param_name, param_schema in tool.inputSchema.get("properties", {}).items(): - parameters.append( - ToolParameter( - name=param_name, - parameter_type=param_schema.get("type", "string"), - description=param_schema.get("description", ""), - ) - ) - tools.append( - ToolDef( - name=tool.name, - description=tool.description, - parameters=parameters, - metadata={ - "endpoint": mcp_endpoint.uri, - }, + async with sse_client_wrapper(mcp_endpoint.uri, headers) as session: + tools_result = await session.list_tools() + for tool in tools_result.tools: + parameters = [] + for param_name, param_schema in tool.inputSchema.get("properties", {}).items(): + parameters.append( + ToolParameter( + name=param_name, + parameter_type=param_schema.get("type", "string"), + description=param_schema.get("description", ""), ) ) + tools.append( + ToolDef( + name=tool.name, + description=tool.description, + parameters=parameters, + metadata={ + "endpoint": mcp_endpoint.uri, + }, + ) + ) return ListToolDefsResponse(data=tools) async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> 
ToolInvocationResult: @@ -71,12 +99,39 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): if urlparse(endpoint).scheme not in ("http", "https"): raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL") - async with sse_client(endpoint) as streams: - async with ClientSession(*streams) as session: - await session.initialize() - result = await session.call_tool(tool.identifier, kwargs) + headers = await self.get_headers_from_request(endpoint) + async with sse_client_wrapper(endpoint, headers) as session: + result = await session.call_tool(tool.identifier, kwargs) + content = [] + for item in result.content: + if isinstance(item, mcp_types.TextContent): + content.append(TextContentItem(text=item.text)) + elif isinstance(item, mcp_types.ImageContent): + content.append(ImageContentItem(image=item.data)) + elif isinstance(item, mcp_types.EmbeddedResource): + logger.warning(f"EmbeddedResource is not supported: {item}") + else: + raise ValueError(f"Unknown content type: {type(item)}") return ToolInvocationResult( - content="\n".join([result.model_dump_json() for result in result.content]), + content=content, error_code=1 if result.isError else 0, ) + + async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]: + def canonicalize_uri(uri: str) -> str: + return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}" + + headers = {} + + provider_data = self.get_request_provider_data() + if provider_data and provider_data.mcp_headers: + for uri, values in provider_data.mcp_headers.items(): + if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri): + continue + for entry in values: + parts = entry.split(":", 1) + if len(parts) == 2: + k, v = parts + headers[k.strip()] = v.strip() + return headers diff --git a/tests/integration/tool_runtime/test_mcp.py b/tests/integration/tool_runtime/test_mcp.py new file mode 100644 index 000000000..e553c6a0b --- /dev/null +++ b/tests/integration/tool_runtime/test_mcp.py @@ 
-0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +import socket +import threading +import time + +import httpx +import mcp.types as types +import pytest +import uvicorn +from llama_stack_client import Agent +from mcp.server.fastmcp import Context, FastMCP +from mcp.server.sse import SseServerTransport +from starlette.applications import Starlette +from starlette.exceptions import HTTPException +from starlette.responses import Response +from starlette.routing import Mount, Route + +from llama_stack import LlamaStackAsLibraryClient +from llama_stack.distribution.datatypes import AuthenticationRequiredError + +AUTH_TOKEN = "test-token" + + +@pytest.fixture(scope="module") +def mcp_server(): + server = FastMCP("FastMCP Test Server") + + @server.tool() + async def greet_everyone( + url: str, ctx: Context + ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: + return [types.TextContent(type="text", text="Hello, world!")] + + sse = SseServerTransport("/messages/") + + async def handle_sse(request): + auth_header = request.headers.get("Authorization") + auth_token = None + if auth_header and auth_header.startswith("Bearer "): + auth_token = auth_header.split(" ")[1] + + if auth_token != AUTH_TOKEN: + raise HTTPException(status_code=401, detail="Unauthorized") + + async with sse.connect_sse(request.scope, request.receive, request._send) as streams: + await server._mcp_server.run( + streams[0], + streams[1], + server._mcp_server.create_initialization_options(), + ) + return Response() + + app = Starlette( + routes=[ + Route("/sse", endpoint=handle_sse), + Mount("/messages/", app=sse.handle_post_message), + ], + ) + + def get_open_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("", 0)) + return sock.getsockname()[1] + + port = 
get_open_port() + + config = uvicorn.Config(app, host="0.0.0.0", port=port) + server_instance = uvicorn.Server(config) + app.state.uvicorn_server = server_instance + + def run_server(): + server_instance.run() + + # Start the server in a new thread + server_thread = threading.Thread(target=run_server, daemon=True) + server_thread.start() + + # Polling until the server is ready + timeout = 10 + start_time = time.time() + + while time.time() - start_time < timeout: + try: + response = httpx.get(f"http://localhost:{port}/sse") + if response.status_code == 401: + break + except httpx.RequestError: + pass + time.sleep(0.1) + + yield port + + # Tell server to exit + server_instance.should_exit = True + server_thread.join(timeout=5) + + +def test_mcp_invocation(llama_stack_client, mcp_server): + port = mcp_server + test_toolgroup_id = "remote::mcptest" + + # registering itself should fail since it requires listing tools + with pytest.raises(Exception, match="Unauthorized"): + llama_stack_client.toolgroups.register( + toolgroup_id=test_toolgroup_id, + provider_id="model-context-protocol", + mcp_endpoint=dict(uri=f"http://localhost:{port}/sse"), + ) + + provider_data = { + "mcp_headers": { + f"http://localhost:{port}/sse": [ + f"Authorization: Bearer {AUTH_TOKEN}", + ], + }, + } + auth_headers = { + "X-LlamaStack-Provider-Data": json.dumps(provider_data), + } + + try: + llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id, extra_headers=auth_headers) + except Exception as e: + # An error is OK since the toolgroup may not exist + print(f"Error unregistering toolgroup: {e}") + + llama_stack_client.toolgroups.register( + toolgroup_id=test_toolgroup_id, + provider_id="model-context-protocol", + mcp_endpoint=dict(uri=f"http://localhost:{port}/sse"), + extra_headers=auth_headers, + ) + response = llama_stack_client.tools.list( + toolgroup_id=test_toolgroup_id, + extra_headers=auth_headers, + ) + assert len(response) == 1 + assert response[0].identifier == 
"greet_everyone" + assert response[0].type == "tool" + assert len(response[0].parameters) == 1 + p = response[0].parameters[0] + assert p.name == "url" + assert p.parameter_type == "string" + assert p.required + + response = llama_stack_client.tool_runtime.invoke_tool( + tool_name=response[0].identifier, + kwargs=dict(url="https://www.google.com"), + extra_headers=auth_headers, + ) + content = response.content + assert len(content) == 1 + assert content[0].type == "text" + assert content[0].text == "Hello, world!" + + models = llama_stack_client.models.list() + model_id = models[0].identifier + print(f"Using model: {model_id}") + agent = Agent( + client=llama_stack_client, + model=model_id, + instructions="You are a helpful assistant.", + tools=[test_toolgroup_id], + ) + session_id = agent.create_session("test-session") + response = agent.create_turn( + session_id=session_id, + messages=[ + { + "role": "user", + "content": "Yo. Use tools.", + } + ], + stream=False, + extra_headers=auth_headers, + ) + steps = response.steps + first = steps[0] + assert first.step_type == "inference" + assert len(first.api_model_response.tool_calls) == 1 + tool_call = first.api_model_response.tool_calls[0] + assert tool_call.tool_name == "greet_everyone" + + second = steps[1] + assert second.step_type == "tool_execution" + tool_response_content = second.tool_responses[0].content + assert len(tool_response_content) == 1 + assert tool_response_content[0].type == "text" + assert tool_response_content[0].text == "Hello, world!" + + third = steps[2] + assert third.step_type == "inference" + assert len(third.api_model_response.tool_calls) == 0 + + # when streaming, we currently don't check auth headers upfront and fail the request + # early. but we should at least be generating a 401 later in the process. + response = agent.create_turn( + session_id=session_id, + messages=[ + { + "role": "user", + "content": "Yo. 
Use tools.", + } + ], + stream=True, + ) + if isinstance(llama_stack_client, LlamaStackAsLibraryClient): + with pytest.raises(AuthenticationRequiredError): + for _ in response: + pass + else: + error_chunks = [chunk for chunk in response if "error" in chunk.model_dump()] + assert len(error_chunks) == 1 + chunk = error_chunks[0].model_dump() + assert "Unauthorized" in chunk["error"]["message"]