From 5937d94da536c2ac6eccd192c944f7af9e642c17 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 22 May 2025 20:21:47 -0700 Subject: [PATCH] feat: enable MCP execution in Response implementation --- docs/_static/llama-stack-spec.html | 128 +++++++- docs/_static/llama-stack-spec.yaml | 81 ++++- llama_stack/apis/agents/openai_responses.py | 38 ++- .../agents/meta_reference/openai_responses.py | 298 +++++++++++++----- .../model_context_protocol.py | 81 +---- llama_stack/providers/utils/tools/mcp.py | 103 ++++++ tests/common/mcp.py | 116 +++++++ .../fixtures/test_cases/responses.yaml | 14 + .../openai_api/test_responses.py | 43 +++ 9 files changed, 728 insertions(+), 174 deletions(-) create mode 100644 llama_stack/providers/utils/tools/mcp.py create mode 100644 tests/common/mcp.py diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 7a6e82001..99ae1c038 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -7218,15 +7218,15 @@ "OpenAIResponseOutputMessageFunctionToolCall": { "type": "object", "properties": { - "arguments": { - "type": "string" - }, "call_id": { "type": "string" }, "name": { "type": "string" }, + "arguments": { + "type": "string" + }, "type": { "type": "string", "const": "function_call", @@ -7241,12 +7241,10 @@ }, "additionalProperties": false, "required": [ - "arguments", "call_id", "name", - "type", - "id", - "status" + "arguments", + "type" ], "title": "OpenAIResponseOutputMessageFunctionToolCall" }, @@ -7412,6 +7410,12 @@ }, { "$ref": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall" + }, + { + "$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPCall" + }, + { + "$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools" } ], "discriminator": { @@ -7419,10 +7423,118 @@ "mapping": { "message": "#/components/schemas/OpenAIResponseMessage", "web_search_call": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall", - "function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall" + "function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall", + "mcp_call": "#/components/schemas/OpenAIResponseOutputMessageMCPCall", + "mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools" } } }, + "OpenAIResponseOutputMessageMCPCall": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "type": { + "type": "string", + "const": "mcp_call", + "default": "mcp_call" + }, + "arguments": { + "type": "string" + }, + "name": { + "type": "string" + }, + "server_label": { + "type": "string" + }, + "error": { + "type": "string" + }, + "output": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "id", + "type", + "arguments", + "name", + "server_label" + ], + "title": "OpenAIResponseOutputMessageMCPCall" + }, + "OpenAIResponseOutputMessageMCPListTools": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "type": { + "type": "string", + "const": "mcp_list_tools", + "default": "mcp_list_tools" + }, + "server_label": { + "type": "string" + }, + "tools": { + "type": "array", + "items": { + "type": "object", + "properties": { + "input_schema": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + }, + "name": { + "type": "string" + }, + "description": { + 
"type": "string" + } + }, + "additionalProperties": false, + "required": [ + "input_schema", + "name" + ], + "title": "MCPListToolsTool" + } + } + }, + "additionalProperties": false, + "required": [ + "id", + "type", + "server_label", + "tools" + ], + "title": "OpenAIResponseOutputMessageMCPListTools" + }, "OpenAIResponseObjectStream": { "oneOf": [ { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 74c4852d4..4e4f09eb0 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -5075,12 +5075,12 @@ components: "OpenAIResponseOutputMessageFunctionToolCall": type: object properties: - arguments: - type: string call_id: type: string name: type: string + arguments: + type: string type: type: string const: function_call @@ -5091,12 +5091,10 @@ components: type: string additionalProperties: false required: - - arguments - call_id - name + - arguments - type - - id - - status title: >- OpenAIResponseOutputMessageFunctionToolCall "OpenAIResponseOutputMessageWebSearchToolCall": @@ -5214,12 +5212,85 @@ components: - $ref: '#/components/schemas/OpenAIResponseMessage' - $ref: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall' - $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall' + - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPCall' + - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools' discriminator: propertyName: type mapping: message: '#/components/schemas/OpenAIResponseMessage' web_search_call: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall' function_call: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall' + mcp_call: '#/components/schemas/OpenAIResponseOutputMessageMCPCall' + mcp_list_tools: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools' + OpenAIResponseOutputMessageMCPCall: + type: object + properties: + id: + type: string + type: + type: string + const: mcp_call + default: mcp_call + arguments: + type: string + name: + type: string + server_label: + type: string + error: + type: string + output: + type: string + additionalProperties: false + required: + - id + - type + - arguments + - name + - server_label + title: OpenAIResponseOutputMessageMCPCall + OpenAIResponseOutputMessageMCPListTools: + type: object + properties: + id: + type: string + type: + type: string + const: mcp_list_tools + default: mcp_list_tools + server_label: + type: string + tools: + type: array + items: + type: object + properties: + input_schema: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + name: + type: string + description: + type: string + additionalProperties: false + required: + - input_schema + - name + title: MCPListToolsTool + additionalProperties: false + required: + - id + - type + - server_label + - tools + title: OpenAIResponseOutputMessageMCPListTools OpenAIResponseObjectStream: oneOf: - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated' diff --git a/llama_stack/apis/agents/openai_responses.py b/llama_stack/apis/agents/openai_responses.py index 7740b5643..675e8f3ff 100644 --- a/llama_stack/apis/agents/openai_responses.py +++ b/llama_stack/apis/agents/openai_responses.py @@ -10,6 +10,9 @@ from pydantic import BaseModel, Field from llama_stack.schema_utils import json_schema_type, register_schema +# NOTE(ashwin): this file is literally a copy of the OpenAI responses API schema. 
We should probably +# take their YAML and generate this file automatically. Their YAML is available. + @json_schema_type class OpenAIResponseError(BaseModel): @@ -79,16 +82,45 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel): @json_schema_type class OpenAIResponseOutputMessageFunctionToolCall(BaseModel): - arguments: str call_id: str name: str + arguments: str type: Literal["function_call"] = "function_call" + id: str | None = None + status: str | None = None + + +@json_schema_type +class OpenAIResponseOutputMessageMCPCall(BaseModel): id: str - status: str + type: Literal["mcp_call"] = "mcp_call" + arguments: str + name: str + server_label: str + error: str | None = None + output: str | None = None + + +class MCPListToolsTool(BaseModel): + input_schema: dict[str, Any] + name: str + description: str | None = None + + +@json_schema_type +class OpenAIResponseOutputMessageMCPListTools(BaseModel): + id: str + type: Literal["mcp_list_tools"] = "mcp_list_tools" + server_label: str + tools: list[MCPListToolsTool] OpenAIResponseOutput = Annotated[ - OpenAIResponseMessage | OpenAIResponseOutputMessageWebSearchToolCall | OpenAIResponseOutputMessageFunctionToolCall, + OpenAIResponseMessage + | OpenAIResponseOutputMessageWebSearchToolCall + | OpenAIResponseOutputMessageFunctionToolCall + | OpenAIResponseOutputMessageMCPCall + | OpenAIResponseOutputMessageMCPListTools, Field(discriminator="type"), ] register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput") diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py index 9acda3b8c..a7c6d684e 100644 --- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py +++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py @@ -14,6 +14,7 @@ from pydantic import BaseModel from llama_stack.apis.agents import Order from llama_stack.apis.agents.openai_responses import ( + AllowedToolsFilter, ListOpenAIResponseInputItem, ListOpenAIResponseObject, OpenAIResponseInput, @@ -22,7 +23,7 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseInputMessageContentImage, OpenAIResponseInputMessageContentText, OpenAIResponseInputTool, - OpenAIResponseInputToolFunction, + OpenAIResponseInputToolMCP, OpenAIResponseMessage, OpenAIResponseObject, OpenAIResponseObjectStream, @@ -51,11 +52,12 @@ from llama_stack.apis.inference.inference import ( OpenAIToolMessageParam, OpenAIUserMessageParam, ) -from llama_stack.apis.tools.tools import ToolGroups, ToolInvocationResult, ToolRuntime +from llama_stack.apis.tools.tools import ToolGroups, ToolRuntime from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool from llama_stack.providers.utils.responses.responses_store import ResponsesStore +from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools logger = get_logger(name=__name__, category="openai_responses") @@ -168,6 +170,15 @@ class OpenAIResponsePreviousResponseWithInputItems(BaseModel): response: OpenAIResponseObject +class ChatCompletionContext(BaseModel): + model: str + messages: list[OpenAIMessageParam] + tools: list[ChatCompletionToolParam] | None = None + mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] + stream: bool + temperature: float | None + + class OpenAIResponsesImpl: def __init__( 
self, @@ -255,13 +266,32 @@ class OpenAIResponsesImpl: temperature: float | None = None, tools: list[OpenAIResponseInputTool] | None = None, ): + output_messages: list[OpenAIResponseOutput] = [] + stream = False if stream is None else stream + # Huge TODO: we need to run this in a loop, until morale improves + + # Create context to run "chat completion" input = await self._prepend_previous_response(input, previous_response_id) messages = await _convert_response_input_to_chat_messages(input) await self._prepend_instructions(messages, instructions) - chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None + chat_tools, mcp_tool_to_server, mcp_list_message = ( + await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None) + ) + if mcp_list_message: + output_messages.append(mcp_list_message) + ctx = ChatCompletionContext( + model=model, + messages=messages, + tools=chat_tools, + mcp_tool_to_server=mcp_tool_to_server, + stream=stream, + temperature=temperature, + ) + + # Run inference chat_response = await self.inference_api.openai_chat_completion( model=model, messages=messages, @@ -270,6 +300,7 @@ class OpenAIResponsesImpl: temperature=temperature, ) + # Collect output if stream: # TODO: refactor this into a separate method that handles streaming chat_response_id = "" @@ -328,11 +359,11 @@ class OpenAIResponsesImpl: # dump and reload to map to our pydantic types chat_response = OpenAIChatCompletion(**chat_response.model_dump()) - output_messages: list[OpenAIResponseOutput] = [] + # Execute tool calls if any for choice in chat_response.choices: if choice.message.tool_calls and tools: # Assume if the first tool is a function, all tools are functions - if isinstance(tools[0], OpenAIResponseInputToolFunction): + if tools[0].type == "function": for tool_call in choice.message.tool_calls: output_messages.append( OpenAIResponseOutputMessageFunctionToolCall( @@ -344,11 +375,12 @@ class OpenAIResponsesImpl: ) ) else: - output_messages.extend( - await self._execute_tool_and_return_final_output(model, stream, choice, messages, temperature) - ) + tool_messages = await self._execute_tool_and_return_final_output(choice, ctx) + output_messages.extend(tool_messages) else: output_messages.append(await _convert_chat_choice_to_response_message(choice)) + + # Create response object response = OpenAIResponseObject( created_at=chat_response.created, id=f"resp-{uuid.uuid4()}", @@ -359,9 +391,8 @@ class OpenAIResponsesImpl: ) logger.debug(f"OpenAI Responses response: {response}") + # Store response if requested if store: - # Store in kvstore - new_input_id = f"msg_{uuid.uuid4()}" if isinstance(input, str): # synthesize a message from the input string @@ -403,7 +434,36 @@ class OpenAIResponsesImpl: async def _convert_response_tools_to_chat_tools( self, tools: list[OpenAIResponseInputTool] - ) -> list[ChatCompletionToolParam]: + ) -> tuple[ + list[ChatCompletionToolParam], + dict[str, OpenAIResponseInputToolMCP], + OpenAIResponseOutput | None, + ]: + from llama_stack.apis.agents.openai_responses import ( + MCPListToolsTool, + OpenAIResponseOutputMessageMCPListTools, + ) + from llama_stack.apis.tools.tools import Tool + + mcp_tool_to_server = {} + + def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam: + tool_def = ToolDefinition( + tool_name=tool_name, + description=tool.description, + parameters={ + param.name: ToolParamDefinition( + param_type=param.parameter_type, + description=param.description, + required=param.required,
default=param.default, + ) + for param in tool.parameters + }, + ) + return convert_tooldef_to_openai_tool(tool_def) + + mcp_list_message = None chat_tools: list[ChatCompletionToolParam] = [] for input_tool in tools: # TODO: Handle other tool types @@ -412,91 +472,95 @@ class OpenAIResponsesImpl: elif input_tool.type == "web_search": tool_name = "web_search" tool = await self.tool_groups_api.get_tool(tool_name) - tool_def = ToolDefinition( - tool_name=tool_name, - description=tool.description, - parameters={ - param.name: ToolParamDefinition( - param_type=param.parameter_type, - description=param.description, - required=param.required, - default=param.default, - ) - for param in tool.parameters - }, + if not tool: + raise ValueError(f"Tool {tool_name} not found") + chat_tools.append(make_openai_tool(tool_name, tool)) + elif input_tool.type == "mcp": + always_allowed = None + never_allowed = None + if input_tool.allowed_tools: + if isinstance(input_tool.allowed_tools, list): + always_allowed = input_tool.allowed_tools + elif isinstance(input_tool.allowed_tools, AllowedToolsFilter): + always_allowed = input_tool.allowed_tools.always + never_allowed = input_tool.allowed_tools.never + + tool_defs = await list_mcp_tools( + endpoint=input_tool.server_url, + headers=convert_header_list_to_dict(input_tool.headers or []), ) - chat_tool = convert_tooldef_to_openai_tool(tool_def) - chat_tools.append(chat_tool) + + mcp_list_message = OpenAIResponseOutputMessageMCPListTools( + id=f"mcp_list_{uuid.uuid4()}", + status="completed", + server_label=input_tool.server_label, + tools=[], + ) + for t in tool_defs.data: + if never_allowed and t.name in never_allowed: + continue + if not always_allowed or t.name in always_allowed: + chat_tools.append(make_openai_tool(t.name, t)) + if t.name in mcp_tool_to_server: + raise ValueError(f"Duplicate tool name {t.name} found for server {input_tool.server_label}") + mcp_tool_to_server[t.name] = input_tool + mcp_list_message.tools.append( + MCPListToolsTool( + name=t.name, + description=t.description, + input_schema={ + "type": "object", + "properties": { + p.name: { + "type": p.parameter_type, + "description": p.description, + } + for p in t.parameters + }, + "required": [p.name for p in t.parameters if p.required], + }, + ) + ) else: raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}") - return chat_tools + return chat_tools, mcp_tool_to_server, mcp_list_message async def _execute_tool_and_return_final_output( self, - model_id: str, - stream: bool, choice: OpenAIChoice, - messages: list[OpenAIMessageParam], - temperature: float, + ctx: ChatCompletionContext, ) -> list[OpenAIResponseOutput]: output_messages: list[OpenAIResponseOutput] = [] - # If the choice is not an assistant message, we don't need to execute any tools if not isinstance(choice.message, OpenAIAssistantMessageParam): return output_messages - # If the assistant message doesn't have any tool calls, we don't need to execute any tools if not choice.message.tool_calls: return output_messages - # Copy the messages list to avoid mutating the original list - messages = messages.copy() + next_turn_messages = ctx.messages.copy() # Add the assistant message with tool_calls response to the messages list - messages.append(choice.message) + next_turn_messages.append(choice.message) for tool_call in choice.message.tool_calls: - tool_call_id = tool_call.id - function = tool_call.function - - # If for some reason the tool call doesn't have a function or id, we can't execute 
it - if not function or not tool_call_id: - continue - # TODO: telemetry spans for tool calls - result = await self._execute_tool_call(function) - - # Handle tool call failure - if not result: - output_messages.append( - OpenAIResponseOutputMessageWebSearchToolCall( - id=tool_call_id, - status="failed", - ) - ) - continue - - output_messages.append( - OpenAIResponseOutputMessageWebSearchToolCall( - id=tool_call_id, - status="completed", - ), - ) - - result_content = "" - # TODO: handle other result content types and lists - if isinstance(result.content, str): - result_content = result.content - messages.append(OpenAIToolMessageParam(content=result_content, tool_call_id=tool_call_id)) + tool_call_log, further_input = await self._execute_tool_call(tool_call, ctx) + if tool_call_log: + output_messages.append(tool_call_log) + if further_input: + next_turn_messages.append(further_input) tool_results_chat_response = await self.inference_api.openai_chat_completion( - model=model_id, - messages=messages, - stream=stream, - temperature=temperature, + model=ctx.model, + messages=next_turn_messages, + stream=ctx.stream, + temperature=ctx.temperature, ) - # type cast to appease mypy + # type cast to appease mypy: this is needed because we don't handle streaming properly :) tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response) + + # Huge TODO: these are NOT the final outputs, we must keep the loop going tool_final_outputs = [ await _convert_chat_choice_to_response_message(choice) for choice in tool_results_chat_response.choices ] @@ -506,15 +570,85 @@ class OpenAIResponsesImpl: async def _execute_tool_call( self, - function: OpenAIChatCompletionToolCallFunction, - ) -> ToolInvocationResult | None: - if not function.name: - return None - function_args = json.loads(function.arguments) if function.arguments else {} - logger.info(f"executing tool call: {function.name} with args: {function_args}") - result = await self.tool_runtime_api.invoke_tool( - tool_name=function.name, - kwargs=function_args, + tool_call: OpenAIChatCompletionToolCall, + ctx: ChatCompletionContext, + ) -> tuple[OpenAIResponseOutput | None, OpenAIMessageParam | None]: + from llama_stack.providers.utils.inference.prompt_adapter import ( + interleaved_content_as_str, ) - logger.debug(f"tool call {function.name} completed with result: {result}") - return result + + tool_call_id = tool_call.id + function = tool_call.function + + if not function or not tool_call_id or not function.name: + return None, None + + error_exc = None + try: + if function.name in ctx.mcp_tool_to_server: + mcp_tool = ctx.mcp_tool_to_server[function.name] + result = await invoke_mcp_tool( + endpoint=mcp_tool.server_url, + headers=convert_header_list_to_dict(mcp_tool.headers or []), + tool_name=function.name, + kwargs=json.loads(function.arguments) if function.arguments else {}, + ) + else: + result = await self.tool_runtime_api.invoke_tool( + tool_name=function.name, + kwargs=json.loads(function.arguments) if function.arguments else {}, + ) + except Exception as e: + error_exc = e + + if function.name in ctx.mcp_tool_to_server: + from llama_stack.apis.agents.openai_responses import OpenAIResponseOutputMessageMCPCall + + message = OpenAIResponseOutputMessageMCPCall( + id=tool_call_id, + arguments=function.arguments, + name=function.name, + server_label=ctx.mcp_tool_to_server[function.name].server_label, + ) + if error_exc: + message.error = str(error_exc) + elif (result.error_code and result.error_code > 0) or result.error_message: + 
message.error = f"Error (code {result.error_code}): {result.error_message}" + elif result.content: + message.output = interleaved_content_as_str(result.content) + else: + if function.name == "web_search": + message = OpenAIResponseOutputMessageWebSearchToolCall( + id=tool_call_id, + status="completed", + ) + if error_exc or (result.error_code and result.error_code > 0) or result.error_message: + message.status = "failed" + else: + raise ValueError(f"Unknown tool {function.name} called") + + input_message = None + if result.content: + if isinstance(result.content, str): + content = result.content + elif isinstance(result.content, list): + from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem + + content = [] + for item in result.content: + if isinstance(item, TextContentItem): + part = OpenAIChatCompletionContentPartTextParam(text=item.text) + elif isinstance(item, ImageContentItem): + if item.image.data: + url = f"data:image;base64,{item.image.data}" + else: + url = item.image.url + part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url)) + else: + raise ValueError(f"Unknown result content type: {type(item)}") + content.append(part) + else: + raise ValueError(f"Unknown result content type: {type(result.content)}") + input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id) + + return message, input_message diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index 340e90ca1..3f0b9a188 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -4,53 +4,26 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from contextlib import asynccontextmanager from typing import Any from urllib.parse import urlparse -import exceptiongroup -import httpx -from mcp import ClientSession -from mcp import types as mcp_types -from mcp.client.sse import sse_client - -from llama_stack.apis.common.content_types import URL, ImageContentItem, TextContentItem +from llama_stack.apis.common.content_types import URL from llama_stack.apis.datatypes import Api from llama_stack.apis.tools import ( ListToolDefsResponse, - ToolDef, ToolInvocationResult, - ToolParameter, ToolRuntime, ) -from llama_stack.distribution.datatypes import AuthenticationRequiredError from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger from llama_stack.providers.datatypes import ToolsProtocolPrivate +from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools from .config import MCPProviderConfig logger = get_logger(__name__, category="tools") -@asynccontextmanager -async def sse_client_wrapper(endpoint: str, headers: dict[str, str]): - try: - async with sse_client(endpoint, headers=headers) as streams: - async with ClientSession(*streams) as session: - await session.initialize() - yield session - except BaseException as e: - if isinstance(e, exceptiongroup.BaseExceptionGroup): - for exc in e.exceptions: - if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401: - raise AuthenticationRequiredError(exc) from exc - elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401: - raise AuthenticationRequiredError(e) from e - - raise - - class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]): self.config = config @@ -64,32 +37,8 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, Nee # this endpoint should be retrieved by getting the tool group right? 
if mcp_endpoint is None: raise ValueError("mcp_endpoint is required") headers = await self.get_headers_from_request(mcp_endpoint.uri) - tools = [] - async with sse_client_wrapper(mcp_endpoint.uri, headers) as session: - tools_result = await session.list_tools() - for tool in tools_result.tools: - parameters = [] - for param_name, param_schema in tool.inputSchema.get("properties", {}).items(): - parameters.append( - ToolParameter( - name=param_name, - parameter_type=param_schema.get("type", "string"), - description=param_schema.get("description", ""), - ) - ) - tools.append( - ToolDef( - name=tool.name, - description=tool.description, - parameters=parameters, - metadata={ - "endpoint": mcp_endpoint.uri, - }, - ) - ) - return ListToolDefsResponse(data=tools) + return await list_mcp_tools(mcp_endpoint.uri, headers) async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult: tool = await self.tool_store.get_tool(tool_name) @@ -100,23 +49,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, Nee raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL") headers = await self.get_headers_from_request(endpoint) - async with sse_client_wrapper(endpoint, headers) as session: - result = await session.call_tool(tool.identifier, kwargs) - - content = [] - for item in result.content: - if isinstance(item, mcp_types.TextContent): - content.append(TextContentItem(text=item.text)) - elif isinstance(item, mcp_types.ImageContent): - content.append(ImageContentItem(image=item.data)) - elif isinstance(item, mcp_types.EmbeddedResource): - logger.warning(f"EmbeddedResource is not supported: {item}") - else: - raise ValueError(f"Unknown content type: {type(item)}") - return ToolInvocationResult( - content=content, - error_code=1 if result.isError else 0, - ) + return await invoke_mcp_tool(endpoint, headers, tool_name, kwargs) async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]: def canonicalize_uri(uri: str) -> str: @@ -129,9 +62,5 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, Nee for uri, values in provider_data.mcp_headers.items(): if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri): continue - for entry in values: - parts = entry.split(":") - if len(parts) == 2: - k, v = parts - headers[k.strip()] = v.strip() + headers.update(convert_header_list_to_dict(values)) return headers diff --git a/llama_stack/providers/utils/tools/mcp.py b/llama_stack/providers/utils/tools/mcp.py new file mode 100644 index 000000000..afa4df766 --- /dev/null +++ b/llama_stack/providers/utils/tools/mcp.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree.
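+
+# A minimal usage sketch for the shared helpers defined below (illustrative only; the real
+# call sites are openai_responses.py and model_context_protocol.py, and the endpoint URL,
+# token, and tool name here are placeholders, not values from this patch):
+#
+#   headers = convert_header_list_to_dict(["Authorization: Bearer <token>"])
+#   tool_defs = await list_mcp_tools(endpoint="http://localhost:8000/sse", headers=headers)
+#   result = await invoke_mcp_tool(endpoint, headers, tool_name="greet_everyone", kwargs={"url": "..."})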
+ +from contextlib import asynccontextmanager +from typing import Any + +import exceptiongroup +import httpx +from mcp import ClientSession +from mcp import types as mcp_types +from mcp.client.sse import sse_client + +from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem +from llama_stack.apis.tools import ( + ListToolDefsResponse, + ToolDef, + ToolInvocationResult, + ToolParameter, +) +from llama_stack.distribution.datatypes import AuthenticationRequiredError +from llama_stack.log import get_logger + +logger = get_logger(__name__, category="tools") + + +@asynccontextmanager +async def sse_client_wrapper(endpoint: str, headers: dict[str, str]): + try: + async with sse_client(endpoint, headers=headers) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + yield session + except BaseException as e: + if isinstance(e, exceptiongroup.BaseExceptionGroup): + for exc in e.exceptions: + if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401: + raise AuthenticationRequiredError(exc) from exc + elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401: + raise AuthenticationRequiredError(e) from e + + raise + + +def convert_header_list_to_dict(header_list: list[str]) -> dict[str, str]: + headers = {} + for header in header_list: + parts = header.split(":") + if len(parts) == 2: + k, v = parts + headers[k.strip()] = v.strip() + return headers + + +async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse: + tools = [] + async with sse_client_wrapper(endpoint, headers) as session: + tools_result = await session.list_tools() + for tool in tools_result.tools: + parameters = [] + for param_name, param_schema in tool.inputSchema.get("properties", {}).items(): + parameters.append( + ToolParameter( + name=param_name, + parameter_type=param_schema.get("type", "string"), + description=param_schema.get("description", ""), + ) + ) + tools.append( + ToolDef( + name=tool.name, + description=tool.description, + parameters=parameters, + metadata={ + "endpoint": endpoint, + }, + ) + ) + return ListToolDefsResponse(data=tools) + + +async def invoke_mcp_tool( + endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any] +) -> ToolInvocationResult: + async with sse_client_wrapper(endpoint, headers) as session: + result = await session.call_tool(tool_name, kwargs) + + content: list[InterleavedContentItem] = [] + for item in result.content: + if isinstance(item, mcp_types.TextContent): + content.append(TextContentItem(text=item.text)) + elif isinstance(item, mcp_types.ImageContent): + content.append(ImageContentItem(image=item.data)) + elif isinstance(item, mcp_types.EmbeddedResource): + logger.warning(f"EmbeddedResource is not supported: {item}") + else: + raise ValueError(f"Unknown content type: {type(item)}") + return ToolInvocationResult( + content=content, + error_code=1 if result.isError else 0, + ) diff --git a/tests/common/mcp.py b/tests/common/mcp.py new file mode 100644 index 000000000..f602cbff2 --- /dev/null +++ b/tests/common/mcp.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
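+
+# Usage sketch (see tests/verifications/openai_api/test_responses.py for the real call site;
+# the token value below is only an example):
+#
+#   with make_mcp_server(required_auth_token="test-token") as server_info:
+#       server_url = server_info["server_url"]
+#       # pass server_url as the `server_url` of an `mcp` tool in a Responses request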
+ +# we want the mcp server to be authenticated OR not, depends +from contextlib import contextmanager + + +@contextmanager +def make_mcp_server(required_auth_token: str | None = None): + import threading + import time + + import httpx + import uvicorn + from mcp import types + from mcp.server.fastmcp import Context, FastMCP + from mcp.server.sse import SseServerTransport + from starlette.applications import Starlette + from starlette.responses import Response + from starlette.routing import Mount, Route + + server = FastMCP("FastMCP Test Server") + + @server.tool() + async def greet_everyone( + url: str, ctx: Context + ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: + return [types.TextContent(type="text", text="Hello, world!")] + + @server.tool() + async def get_boiling_point(liquid_name: str, celcius: bool = True) -> int: + """ + Returns the boiling point of a liquid in Celcius or Fahrenheit. + + :param liquid_name: The name of the liquid + :param celcius: Whether to return the boiling point in Celcius + :return: The boiling point of the liquid in Celcius or Fahrenheit + """ + if liquid_name.lower() == "polyjuice": + if celcius: + return -100 + else: + return -212 + else: + return -1 + + sse = SseServerTransport("/messages/") + + async def handle_sse(request): + from starlette.exceptions import HTTPException + + auth_header = request.headers.get("Authorization") + auth_token = None + if auth_header and auth_header.startswith("Bearer "): + auth_token = auth_header.split(" ")[1] + + if required_auth_token and auth_token != required_auth_token: + raise HTTPException(status_code=401, detail="Unauthorized") + + async with sse.connect_sse(request.scope, request.receive, request._send) as streams: + await server._mcp_server.run( + streams[0], + streams[1], + server._mcp_server.create_initialization_options(), + ) + return Response() + + app = Starlette( + routes=[ + Route("/sse", endpoint=handle_sse), + Mount("/messages/", app=sse.handle_post_message), + ], + ) + + def get_open_port(): + import socket + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("", 0)) + return sock.getsockname()[1] + + port = get_open_port() + + config = uvicorn.Config(app, host="0.0.0.0", port=port) + server_instance = uvicorn.Server(config) + app.state.uvicorn_server = server_instance + + def run_server(): + server_instance.run() + + # Start the server in a new thread + server_thread = threading.Thread(target=run_server, daemon=True) + server_thread.start() + + # Polling until the server is ready + timeout = 10 + start_time = time.time() + + server_url = f"http://localhost:{port}/sse" + while time.time() - start_time < timeout: + try: + response = httpx.get(server_url) + if response.status_code in [200, 401]: + break + except httpx.RequestError: + pass + time.sleep(0.1) + + yield {"server_url": server_url} + + # Tell server to exit + server_instance.should_exit = True + server_thread.join(timeout=5) diff --git a/tests/verifications/openai_api/fixtures/test_cases/responses.yaml b/tests/verifications/openai_api/fixtures/test_cases/responses.yaml index 262d82526..a50abef44 100644 --- a/tests/verifications/openai_api/fixtures/test_cases/responses.yaml +++ b/tests/verifications/openai_api/fixtures/test_cases/responses.yaml @@ -31,6 +31,20 @@ test_response_web_search: search_context_size: "low" output: "128" +test_response_mcp_tool: + test_name: test_response_mcp_tool + test_params: + case: + - case_id: "boiling_point_tool" + input: "What is the boiling point of 
polyjuice?" + tools: + - type: mcp + server_label: "localmcp" + server_url: "" + headers: + Authorization: "Bearer test-token" + output: "Hello, world!" + test_response_custom_tool: test_name: test_response_custom_tool test_params: diff --git a/tests/verifications/openai_api/test_responses.py b/tests/verifications/openai_api/test_responses.py index e279b9b38..b1c04e2f3 100644 --- a/tests/verifications/openai_api/test_responses.py +++ b/tests/verifications/openai_api/test_responses.py @@ -4,9 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import json import pytest +from tests.common.mcp import make_mcp_server from tests.verifications.openai_api.fixtures.fixtures import ( case_id_generator, get_base_test_name, @@ -124,6 +126,47 @@ def test_response_non_streaming_web_search(request, openai_client, model, provid assert case["output"].lower() in response.output_text.lower().strip() +@pytest.mark.parametrize( + "case", + responses_test_cases["test_response_mcp_tool"]["test_params"]["case"], + ids=case_id_generator, +) +def test_response_non_streaming_mcp_tool(request, openai_client, model, provider, verification_config, case): + test_name_base = get_base_test_name(request) + if should_skip_test(verification_config, provider, model, test_name_base): + pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.") + + with make_mcp_server() as mcp_server_info: + tools = case["tools"] + for tool in tools: + if tool["type"] == "mcp": + tool["server_url"] = mcp_server_info["server_url"] + + response = openai_client.responses.create( + model=model, + input=case["input"], + tools=tools, + stream=False, + ) + assert len(response.output) >= 3 + list_tools = response.output[0] + assert list_tools.type == "mcp_list_tools" + assert list_tools.server_label == "localmcp" + assert len(list_tools.tools) == 2 + assert {t["name"] for t in list_tools.tools} == {"get_boiling_point", "greet_everyone"} + + call = response.output[1] + assert call.type == "mcp_call" + assert call.name == "get_boiling_point" + assert json.loads(call.arguments) == {"liquid_name": "polyjuice", "celcius": True} + assert call.error is None + assert "-100" in call.output + + message = response.output[2] + text_content = message.content[0].text + assert "boiling point" in text_content.lower() + + @pytest.mark.parametrize( "case", responses_test_cases["test_response_custom_tool"]["test_params"]["case"],