diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 7a6e82001..99ae1c038 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -7218,15 +7218,15 @@
"OpenAIResponseOutputMessageFunctionToolCall": {
"type": "object",
"properties": {
- "arguments": {
- "type": "string"
- },
"call_id": {
"type": "string"
},
"name": {
"type": "string"
},
+ "arguments": {
+ "type": "string"
+ },
"type": {
"type": "string",
"const": "function_call",
@@ -7241,12 +7241,10 @@
},
"additionalProperties": false,
"required": [
- "arguments",
"call_id",
"name",
- "type",
- "id",
- "status"
+ "arguments",
+ "type"
],
"title": "OpenAIResponseOutputMessageFunctionToolCall"
},
@@ -7412,6 +7410,12 @@
},
{
"$ref": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall"
+ },
+ {
+ "$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPCall"
+ },
+ {
+ "$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
}
],
"discriminator": {
@@ -7419,10 +7423,118 @@
"mapping": {
"message": "#/components/schemas/OpenAIResponseMessage",
"web_search_call": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall",
- "function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall"
+ "function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall",
+ "mcp_call": "#/components/schemas/OpenAIResponseOutputMessageMCPCall",
+ "mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
}
}
},
+ "OpenAIResponseOutputMessageMCPCall": {
+ "type": "object",
+ "properties": {
+ "id": {
+ "type": "string"
+ },
+ "type": {
+ "type": "string",
+ "const": "mcp_call",
+ "default": "mcp_call"
+ },
+ "arguments": {
+ "type": "string"
+ },
+ "name": {
+ "type": "string"
+ },
+ "server_label": {
+ "type": "string"
+ },
+ "error": {
+ "type": "string"
+ },
+ "output": {
+ "type": "string"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "id",
+ "type",
+ "arguments",
+ "name",
+ "server_label"
+ ],
+ "title": "OpenAIResponseOutputMessageMCPCall"
+ },
+ "OpenAIResponseOutputMessageMCPListTools": {
+ "type": "object",
+ "properties": {
+ "id": {
+ "type": "string"
+ },
+ "type": {
+ "type": "string",
+ "const": "mcp_list_tools",
+ "default": "mcp_list_tools"
+ },
+ "server_label": {
+ "type": "string"
+ },
+ "tools": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "input_schema": {
+ "type": "object",
+ "additionalProperties": {
+ "oneOf": [
+ {
+ "type": "null"
+ },
+ {
+ "type": "boolean"
+ },
+ {
+ "type": "number"
+ },
+ {
+ "type": "string"
+ },
+ {
+ "type": "array"
+ },
+ {
+ "type": "object"
+ }
+ ]
+ }
+ },
+ "name": {
+ "type": "string"
+ },
+ "description": {
+ "type": "string"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "input_schema",
+ "name"
+ ],
+ "title": "MCPListToolsTool"
+ }
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "id",
+ "type",
+ "server_label",
+ "tools"
+ ],
+ "title": "OpenAIResponseOutputMessageMCPListTools"
+ },
"OpenAIResponseObjectStream": {
"oneOf": [
{
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 74c4852d4..4e4f09eb0 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -5075,12 +5075,12 @@ components:
"OpenAIResponseOutputMessageFunctionToolCall":
type: object
properties:
- arguments:
- type: string
call_id:
type: string
name:
type: string
+ arguments:
+ type: string
type:
type: string
const: function_call
@@ -5091,12 +5091,10 @@ components:
type: string
additionalProperties: false
required:
- - arguments
- call_id
- name
+ - arguments
- type
- - id
- - status
title: >-
OpenAIResponseOutputMessageFunctionToolCall
"OpenAIResponseOutputMessageWebSearchToolCall":
@@ -5214,12 +5212,85 @@ components:
- $ref: '#/components/schemas/OpenAIResponseMessage'
- $ref: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
- $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
+ - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
+ - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
discriminator:
propertyName: type
mapping:
message: '#/components/schemas/OpenAIResponseMessage'
web_search_call: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
function_call: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
+ mcp_call: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
+ mcp_list_tools: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
+ OpenAIResponseOutputMessageMCPCall:
+ type: object
+ properties:
+ id:
+ type: string
+ type:
+ type: string
+ const: mcp_call
+ default: mcp_call
+ arguments:
+ type: string
+ name:
+ type: string
+ server_label:
+ type: string
+ error:
+ type: string
+ output:
+ type: string
+ additionalProperties: false
+ required:
+ - id
+ - type
+ - arguments
+ - name
+ - server_label
+ title: OpenAIResponseOutputMessageMCPCall
+ OpenAIResponseOutputMessageMCPListTools:
+ type: object
+ properties:
+ id:
+ type: string
+ type:
+ type: string
+ const: mcp_list_tools
+ default: mcp_list_tools
+ server_label:
+ type: string
+ tools:
+ type: array
+ items:
+ type: object
+ properties:
+ input_schema:
+ type: object
+ additionalProperties:
+ oneOf:
+ - type: 'null'
+ - type: boolean
+ - type: number
+ - type: string
+ - type: array
+ - type: object
+ name:
+ type: string
+ description:
+ type: string
+ additionalProperties: false
+ required:
+ - input_schema
+ - name
+ title: MCPListToolsTool
+ additionalProperties: false
+ required:
+ - id
+ - type
+ - server_label
+ - tools
+ title: OpenAIResponseOutputMessageMCPListTools
OpenAIResponseObjectStream:
oneOf:
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
diff --git a/llama_stack/apis/agents/openai_responses.py b/llama_stack/apis/agents/openai_responses.py
index 7740b5643..675e8f3ff 100644
--- a/llama_stack/apis/agents/openai_responses.py
+++ b/llama_stack/apis/agents/openai_responses.py
@@ -10,6 +10,9 @@ from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type, register_schema
+# NOTE(ashwin): this file is literally a copy of the OpenAI Responses API schema. We should probably
+# take their published YAML and generate this file automatically instead of maintaining it by hand.
+
@json_schema_type
class OpenAIResponseError(BaseModel):
@@ -79,16 +82,45 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
@json_schema_type
class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
- arguments: str
call_id: str
name: str
+ arguments: str
type: Literal["function_call"] = "function_call"
+ id: str | None = None
+ status: str | None = None
+
+
+@json_schema_type
+class OpenAIResponseOutputMessageMCPCall(BaseModel):
id: str
- status: str
+ type: Literal["mcp_call"] = "mcp_call"
+ arguments: str
+ name: str
+ server_label: str
+ error: str | None = None
+ output: str | None = None
+
+
+class MCPListToolsTool(BaseModel):
+ input_schema: dict[str, Any]
+ name: str
+ description: str | None = None
+
+
+@json_schema_type
+class OpenAIResponseOutputMessageMCPListTools(BaseModel):
+ id: str
+ type: Literal["mcp_list_tools"] = "mcp_list_tools"
+ server_label: str
+ tools: list[MCPListToolsTool]
OpenAIResponseOutput = Annotated[
- OpenAIResponseMessage | OpenAIResponseOutputMessageWebSearchToolCall | OpenAIResponseOutputMessageFunctionToolCall,
+ OpenAIResponseMessage
+ | OpenAIResponseOutputMessageWebSearchToolCall
+ | OpenAIResponseOutputMessageFunctionToolCall
+ | OpenAIResponseOutputMessageMCPCall
+ | OpenAIResponseOutputMessageMCPListTools,
Field(discriminator="type"),
]
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py
index 9acda3b8c..a7c6d684e 100644
--- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py
+++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py
@@ -14,6 +14,7 @@ from pydantic import BaseModel
from llama_stack.apis.agents import Order
from llama_stack.apis.agents.openai_responses import (
+ AllowedToolsFilter,
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIResponseInput,
@@ -22,7 +23,7 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
- OpenAIResponseInputToolFunction,
+ OpenAIResponseInputToolMCP,
OpenAIResponseMessage,
OpenAIResponseObject,
OpenAIResponseObjectStream,
@@ -51,11 +52,12 @@ from llama_stack.apis.inference.inference import (
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
-from llama_stack.apis.tools.tools import ToolGroups, ToolInvocationResult, ToolRuntime
+from llama_stack.apis.tools.tools import ToolGroups, ToolRuntime
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
+from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools
logger = get_logger(name=__name__, category="openai_responses")
@@ -168,6 +170,15 @@ class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
response: OpenAIResponseObject
+class ChatCompletionContext(BaseModel):
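+    """Bundled inputs for one chat-completion turn, including the mapping from MCP tool names to the MCP server definitions that provide them."""
+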
+ model: str
+ messages: list[OpenAIMessageParam]
+ tools: list[ChatCompletionToolParam] | None = None
+ mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
+ stream: bool
+ temperature: float | None
+
+
class OpenAIResponsesImpl:
def __init__(
self,
@@ -255,13 +266,32 @@ class OpenAIResponsesImpl:
temperature: float | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
):
+ output_messages: list[OpenAIResponseOutput] = []
+
stream = False if stream is None else stream
+        # Huge TODO: we need to run this inference + tool-execution sequence in a loop until the model stops calling tools
+
+ # Create context to run "chat completion"
input = await self._prepend_previous_response(input, previous_response_id)
messages = await _convert_response_input_to_chat_messages(input)
await self._prepend_instructions(messages, instructions)
- chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None
+ chat_tools, mcp_tool_to_server, mcp_list_message = (
+            await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
+ )
+ if mcp_list_message:
+ output_messages.append(mcp_list_message)
+ ctx = ChatCompletionContext(
+ model=model,
+ messages=messages,
+ tools=chat_tools,
+ mcp_tool_to_server=mcp_tool_to_server,
+ stream=stream,
+ temperature=temperature,
+ )
+
+ # Run inference
chat_response = await self.inference_api.openai_chat_completion(
model=model,
messages=messages,
@@ -270,6 +300,7 @@ class OpenAIResponsesImpl:
temperature=temperature,
)
+ # Collect output
if stream:
# TODO: refactor this into a separate method that handles streaming
chat_response_id = ""
@@ -328,11 +359,11 @@ class OpenAIResponsesImpl:
# dump and reload to map to our pydantic types
chat_response = OpenAIChatCompletion(**chat_response.model_dump())
- output_messages: list[OpenAIResponseOutput] = []
+ # Execute tool calls if any
for choice in chat_response.choices:
if choice.message.tool_calls and tools:
# Assume if the first tool is a function, all tools are functions
- if isinstance(tools[0], OpenAIResponseInputToolFunction):
+ if tools[0].type == "function":
for tool_call in choice.message.tool_calls:
output_messages.append(
OpenAIResponseOutputMessageFunctionToolCall(
@@ -344,11 +375,12 @@ class OpenAIResponsesImpl:
)
)
else:
- output_messages.extend(
- await self._execute_tool_and_return_final_output(model, stream, choice, messages, temperature)
- )
+ tool_messages = await self._execute_tool_and_return_final_output(choice, ctx)
+ output_messages.extend(tool_messages)
else:
output_messages.append(await _convert_chat_choice_to_response_message(choice))
+
+ # Create response object
response = OpenAIResponseObject(
created_at=chat_response.created,
id=f"resp-{uuid.uuid4()}",
@@ -359,9 +391,8 @@ class OpenAIResponsesImpl:
)
logger.debug(f"OpenAI Responses response: {response}")
+ # Store response if requested
if store:
- # Store in kvstore
-
new_input_id = f"msg_{uuid.uuid4()}"
if isinstance(input, str):
# synthesize a message from the input string
@@ -403,7 +434,36 @@ class OpenAIResponsesImpl:
async def _convert_response_tools_to_chat_tools(
self, tools: list[OpenAIResponseInputTool]
- ) -> list[ChatCompletionToolParam]:
+ ) -> tuple[
+ list[ChatCompletionToolParam],
+ dict[str, OpenAIResponseInputToolMCP],
+ OpenAIResponseOutput | None,
+ ]:
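+        """Convert Responses API tool specs into chat-completion tools; also return the MCP tool-name -> server mapping and an optional mcp_list_tools output item."""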
+ from llama_stack.apis.agents.openai_responses import (
+ MCPListToolsTool,
+ OpenAIResponseOutputMessageMCPListTools,
+ )
+ from llama_stack.apis.tools.tools import Tool
+
+ mcp_tool_to_server = {}
+
+ def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
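+            """Convert a Llama Stack tool definition into the OpenAI chat-completions tool format."""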
+ tool_def = ToolDefinition(
+ tool_name=tool_name,
+ description=tool.description,
+ parameters={
+ param.name: ToolParamDefinition(
+ param_type=param.parameter_type,
+ description=param.description,
+ required=param.required,
+ default=param.default,
+ )
+ for param in tool.parameters
+ },
+ )
+ return convert_tooldef_to_openai_tool(tool_def)
+
+ mcp_list_message = None
chat_tools: list[ChatCompletionToolParam] = []
for input_tool in tools:
# TODO: Handle other tool types
@@ -412,91 +472,95 @@ class OpenAIResponsesImpl:
elif input_tool.type == "web_search":
tool_name = "web_search"
tool = await self.tool_groups_api.get_tool(tool_name)
- tool_def = ToolDefinition(
- tool_name=tool_name,
- description=tool.description,
- parameters={
- param.name: ToolParamDefinition(
- param_type=param.parameter_type,
- description=param.description,
- required=param.required,
- default=param.default,
- )
- for param in tool.parameters
- },
+ if not tool:
+ raise ValueError(f"Tool {tool_name} not found")
+ chat_tools.append(make_openai_tool(tool_name, tool))
+ elif input_tool.type == "mcp":
+ always_allowed = None
+ never_allowed = None
+ if input_tool.allowed_tools:
+ if isinstance(input_tool.allowed_tools, list):
+ always_allowed = input_tool.allowed_tools
+ elif isinstance(input_tool.allowed_tools, AllowedToolsFilter):
+ always_allowed = input_tool.allowed_tools.always
+ never_allowed = input_tool.allowed_tools.never
+
+ tool_defs = await list_mcp_tools(
+ endpoint=input_tool.server_url,
+ headers=convert_header_list_to_dict(input_tool.headers or []),
)
- chat_tool = convert_tooldef_to_openai_tool(tool_def)
- chat_tools.append(chat_tool)
+
+ mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
+ id=f"mcp_list_{uuid.uuid4()}",
+ server_label=input_tool.server_label,
+ tools=[],
+ )
+ for t in tool_defs.data:
+ if never_allowed and t.name in never_allowed:
+ continue
+ if not always_allowed or t.name in always_allowed:
+ chat_tools.append(make_openai_tool(t.name, t))
+ if t.name in mcp_tool_to_server:
+ raise ValueError(f"Duplicate tool name {t.name} found for server {input_tool.server_label}")
+ mcp_tool_to_server[t.name] = input_tool
+ mcp_list_message.tools.append(
+ MCPListToolsTool(
+ name=t.name,
+ description=t.description,
+ input_schema={
+ "type": "object",
+ "properties": {
+ p.name: {
+ "type": p.parameter_type,
+ "description": p.description,
+ }
+ for p in t.parameters
+ },
+ "required": [p.name for p in t.parameters if p.required],
+ },
+ )
+ )
else:
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
- return chat_tools
+ return chat_tools, mcp_tool_to_server, mcp_list_message
async def _execute_tool_and_return_final_output(
self,
- model_id: str,
- stream: bool,
choice: OpenAIChoice,
- messages: list[OpenAIMessageParam],
- temperature: float,
+ ctx: ChatCompletionContext,
) -> list[OpenAIResponseOutput]:
output_messages: list[OpenAIResponseOutput] = []
- # If the choice is not an assistant message, we don't need to execute any tools
if not isinstance(choice.message, OpenAIAssistantMessageParam):
return output_messages
- # If the assistant message doesn't have any tool calls, we don't need to execute any tools
if not choice.message.tool_calls:
return output_messages
- # Copy the messages list to avoid mutating the original list
- messages = messages.copy()
+ next_turn_messages = ctx.messages.copy()
# Add the assistant message with tool_calls response to the messages list
- messages.append(choice.message)
+ next_turn_messages.append(choice.message)
for tool_call in choice.message.tool_calls:
- tool_call_id = tool_call.id
- function = tool_call.function
-
- # If for some reason the tool call doesn't have a function or id, we can't execute it
- if not function or not tool_call_id:
- continue
-
# TODO: telemetry spans for tool calls
- result = await self._execute_tool_call(function)
-
- # Handle tool call failure
- if not result:
- output_messages.append(
- OpenAIResponseOutputMessageWebSearchToolCall(
- id=tool_call_id,
- status="failed",
- )
- )
- continue
-
- output_messages.append(
- OpenAIResponseOutputMessageWebSearchToolCall(
- id=tool_call_id,
- status="completed",
- ),
- )
-
- result_content = ""
- # TODO: handle other result content types and lists
- if isinstance(result.content, str):
- result_content = result.content
- messages.append(OpenAIToolMessageParam(content=result_content, tool_call_id=tool_call_id))
+ tool_call_log, further_input = await self._execute_tool_call(tool_call, ctx)
+ if tool_call_log:
+ output_messages.append(tool_call_log)
+ if further_input:
+ next_turn_messages.append(further_input)
tool_results_chat_response = await self.inference_api.openai_chat_completion(
- model=model_id,
- messages=messages,
- stream=stream,
- temperature=temperature,
+ model=ctx.model,
+ messages=next_turn_messages,
+ stream=ctx.stream,
+ temperature=ctx.temperature,
)
- # type cast to appease mypy
+ # type cast to appease mypy: this is needed because we don't handle streaming properly :)
tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
+
+ # Huge TODO: these are NOT the final outputs, we must keep the loop going
tool_final_outputs = [
await _convert_chat_choice_to_response_message(choice) for choice in tool_results_chat_response.choices
]
@@ -506,15 +570,85 @@ class OpenAIResponsesImpl:
async def _execute_tool_call(
self,
- function: OpenAIChatCompletionToolCallFunction,
- ) -> ToolInvocationResult | None:
- if not function.name:
- return None
- function_args = json.loads(function.arguments) if function.arguments else {}
- logger.info(f"executing tool call: {function.name} with args: {function_args}")
- result = await self.tool_runtime_api.invoke_tool(
- tool_name=function.name,
- kwargs=function_args,
+ tool_call: OpenAIChatCompletionToolCall,
+ ctx: ChatCompletionContext,
+ ) -> tuple[OpenAIResponseOutput | None, OpenAIMessageParam | None]:
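+        """Execute a single tool call; return (response output item, tool message for the next inference turn), either of which may be None."""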
+ from llama_stack.providers.utils.inference.prompt_adapter import (
+ interleaved_content_as_str,
)
- logger.debug(f"tool call {function.name} completed with result: {result}")
- return result
+
+ tool_call_id = tool_call.id
+ function = tool_call.function
+
+ if not function or not tool_call_id or not function.name:
+ return None, None
+
+ error_exc = None
+ try:
+ if function.name in ctx.mcp_tool_to_server:
+ mcp_tool = ctx.mcp_tool_to_server[function.name]
+ result = await invoke_mcp_tool(
+ endpoint=mcp_tool.server_url,
+ headers=convert_header_list_to_dict(mcp_tool.headers or []),
+ tool_name=function.name,
+ kwargs=json.loads(function.arguments) if function.arguments else {},
+ )
+ else:
+ result = await self.tool_runtime_api.invoke_tool(
+ tool_name=function.name,
+ kwargs=json.loads(function.arguments) if function.arguments else {},
+ )
+ except Exception as e:
+ error_exc = e
+
+ if function.name in ctx.mcp_tool_to_server:
+ from llama_stack.apis.agents.openai_responses import OpenAIResponseOutputMessageMCPCall
+
+ message = OpenAIResponseOutputMessageMCPCall(
+ id=tool_call_id,
+ arguments=function.arguments,
+ name=function.name,
+ server_label=ctx.mcp_tool_to_server[function.name].server_label,
+ )
+ if error_exc:
+ message.error = str(error_exc)
+ elif (result.error_code and result.error_code > 0) or result.error_message:
+ message.error = f"Error (code {result.error_code}): {result.error_message}"
+ elif result.content:
+ message.output = interleaved_content_as_str(result.content)
+ else:
+ if function.name == "web_search":
+ message = OpenAIResponseOutputMessageWebSearchToolCall(
+ id=tool_call_id,
+ status="completed",
+ )
+ if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
+ message.status = "failed"
+ else:
+ raise ValueError(f"Unknown tool {function.name} called")
+
+ input_message = None
+        if not error_exc and result.content:
+ if isinstance(result.content, str):
+ content = result.content
+ elif isinstance(result.content, list):
+ from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
+
+ content = []
+ for item in result.content:
+ if isinstance(item, TextContentItem):
+ part = OpenAIChatCompletionContentPartTextParam(text=item.text)
+ elif isinstance(item, ImageContentItem):
+ if item.image.data:
+ url = f"data:image;base64,{item.image.data}"
+ else:
+ url = item.image.url
+ part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
+ else:
+ raise ValueError(f"Unknown result content type: {type(item)}")
+ content.append(part)
+ else:
+ raise ValueError(f"Unknown result content type: {type(result.content)}")
+ input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
+
+ return message, input_message
diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py
index 340e90ca1..3f0b9a188 100644
--- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py
+++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py
@@ -4,53 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import urlparse
-import exceptiongroup
-import httpx
-from mcp import ClientSession
-from mcp import types as mcp_types
-from mcp.client.sse import sse_client
-
-from llama_stack.apis.common.content_types import URL, ImageContentItem, TextContentItem
+from llama_stack.apis.common.content_types import URL
from llama_stack.apis.datatypes import Api
from llama_stack.apis.tools import (
ListToolDefsResponse,
- ToolDef,
ToolInvocationResult,
- ToolParameter,
ToolRuntime,
)
-from llama_stack.distribution.datatypes import AuthenticationRequiredError
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolsProtocolPrivate
+from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools
from .config import MCPProviderConfig
logger = get_logger(__name__, category="tools")
-@asynccontextmanager
-async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
- try:
- async with sse_client(endpoint, headers=headers) as streams:
- async with ClientSession(*streams) as session:
- await session.initialize()
- yield session
- except BaseException as e:
- if isinstance(e, exceptiongroup.BaseExceptionGroup):
- for exc in e.exceptions:
- if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401:
- raise AuthenticationRequiredError(exc) from exc
- elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401:
- raise AuthenticationRequiredError(e) from e
-
- raise
-
-
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
self.config = config
@@ -64,32 +37,8 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, Nee
# this endpoint should be retrieved by getting the tool group right?
if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required")
-
headers = await self.get_headers_from_request(mcp_endpoint.uri)
- tools = []
- async with sse_client_wrapper(mcp_endpoint.uri, headers) as session:
- tools_result = await session.list_tools()
- for tool in tools_result.tools:
- parameters = []
- for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
- parameters.append(
- ToolParameter(
- name=param_name,
- parameter_type=param_schema.get("type", "string"),
- description=param_schema.get("description", ""),
- )
- )
- tools.append(
- ToolDef(
- name=tool.name,
- description=tool.description,
- parameters=parameters,
- metadata={
- "endpoint": mcp_endpoint.uri,
- },
- )
- )
- return ListToolDefsResponse(data=tools)
+ return await list_mcp_tools(mcp_endpoint.uri, headers)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name)
@@ -100,23 +49,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, Nee
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
headers = await self.get_headers_from_request(endpoint)
- async with sse_client_wrapper(endpoint, headers) as session:
- result = await session.call_tool(tool.identifier, kwargs)
-
- content = []
- for item in result.content:
- if isinstance(item, mcp_types.TextContent):
- content.append(TextContentItem(text=item.text))
- elif isinstance(item, mcp_types.ImageContent):
- content.append(ImageContentItem(image=item.data))
- elif isinstance(item, mcp_types.EmbeddedResource):
- logger.warning(f"EmbeddedResource is not supported: {item}")
- else:
- raise ValueError(f"Unknown content type: {type(item)}")
- return ToolInvocationResult(
- content=content,
- error_code=1 if result.isError else 0,
- )
+ return await invoke_mcp_tool(endpoint, headers, tool_name, kwargs)
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
def canonicalize_uri(uri: str) -> str:
@@ -129,9 +62,5 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, Nee
for uri, values in provider_data.mcp_headers.items():
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
continue
- for entry in values:
- parts = entry.split(":")
- if len(parts) == 2:
- k, v = parts
- headers[k.strip()] = v.strip()
+ headers.update(convert_header_list_to_dict(values))
return headers
diff --git a/llama_stack/providers/utils/tools/mcp.py b/llama_stack/providers/utils/tools/mcp.py
new file mode 100644
index 000000000..afa4df766
--- /dev/null
+++ b/llama_stack/providers/utils/tools/mcp.py
@@ -0,0 +1,103 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from contextlib import asynccontextmanager
+from typing import Any
+
+import exceptiongroup
+import httpx
+from mcp import ClientSession
+from mcp import types as mcp_types
+from mcp.client.sse import sse_client
+
+from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem
+from llama_stack.apis.tools import (
+ ListToolDefsResponse,
+ ToolDef,
+ ToolInvocationResult,
+ ToolParameter,
+)
+from llama_stack.distribution.datatypes import AuthenticationRequiredError
+from llama_stack.log import get_logger
+
+logger = get_logger(__name__, category="tools")
+
+
+@asynccontextmanager
+async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
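+    """Open an MCP client session over SSE, surfacing HTTP 401 responses as AuthenticationRequiredError."""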
+ try:
+ async with sse_client(endpoint, headers=headers) as streams:
+ async with ClientSession(*streams) as session:
+ await session.initialize()
+ yield session
+ except BaseException as e:
+ if isinstance(e, exceptiongroup.BaseExceptionGroup):
+ for exc in e.exceptions:
+ if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401:
+ raise AuthenticationRequiredError(exc) from exc
+ elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401:
+ raise AuthenticationRequiredError(e) from e
+
+ raise
+
+
+def convert_header_list_to_dict(header_list: list[str]) -> dict[str, str]:
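+    """Convert a list of "Key: Value" header strings (e.g. ["Authorization: Bearer abc"]) into a dict."""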
+ headers = {}
+ for header in header_list:
+        parts = header.split(":", 1)  # split only on the first ":" so header values may themselves contain colons
+ if len(parts) == 2:
+ k, v = parts
+ headers[k.strip()] = v.strip()
+ return headers
+
+
+async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
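+    """List the tools exposed by the MCP server at `endpoint` and convert them to Llama Stack ToolDefs."""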
+ tools = []
+ async with sse_client_wrapper(endpoint, headers) as session:
+ tools_result = await session.list_tools()
+ for tool in tools_result.tools:
+ parameters = []
+ for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
+ parameters.append(
+ ToolParameter(
+ name=param_name,
+ parameter_type=param_schema.get("type", "string"),
+ description=param_schema.get("description", ""),
+ )
+ )
+ tools.append(
+ ToolDef(
+ name=tool.name,
+ description=tool.description,
+ parameters=parameters,
+ metadata={
+ "endpoint": endpoint,
+ },
+ )
+ )
+ return ListToolDefsResponse(data=tools)
+
+
+async def invoke_mcp_tool(
+ endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any]
+) -> ToolInvocationResult:
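+    """Call `tool_name` on the MCP server at `endpoint` and convert the returned content into Llama Stack content items."""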
+ async with sse_client_wrapper(endpoint, headers) as session:
+ result = await session.call_tool(tool_name, kwargs)
+
+ content: list[InterleavedContentItem] = []
+ for item in result.content:
+ if isinstance(item, mcp_types.TextContent):
+ content.append(TextContentItem(text=item.text))
+ elif isinstance(item, mcp_types.ImageContent):
+ content.append(ImageContentItem(image=item.data))
+ elif isinstance(item, mcp_types.EmbeddedResource):
+ logger.warning(f"EmbeddedResource is not supported: {item}")
+ else:
+ raise ValueError(f"Unknown content type: {type(item)}")
+ return ToolInvocationResult(
+ content=content,
+ error_code=1 if result.isError else 0,
+ )
diff --git a/tests/common/mcp.py b/tests/common/mcp.py
new file mode 100644
index 000000000..f602cbff2
--- /dev/null
+++ b/tests/common/mcp.py
@@ -0,0 +1,116 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+# Test helper: run a local MCP server, optionally requiring a bearer auth token.
+from contextlib import contextmanager
+
+
+@contextmanager
+def make_mcp_server(required_auth_token: str | None = None):
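+    """Run a throwaway FastMCP server (SSE transport) on a random local port for the duration of the context; yields {"server_url": ...}."""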
+ import threading
+ import time
+
+ import httpx
+ import uvicorn
+ from mcp import types
+ from mcp.server.fastmcp import Context, FastMCP
+ from mcp.server.sse import SseServerTransport
+ from starlette.applications import Starlette
+ from starlette.responses import Response
+ from starlette.routing import Mount, Route
+
+ server = FastMCP("FastMCP Test Server")
+
+ @server.tool()
+ async def greet_everyone(
+ url: str, ctx: Context
+ ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
+ return [types.TextContent(type="text", text="Hello, world!")]
+
+ @server.tool()
+ async def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
+ """
+        Returns the boiling point of a liquid in Celsius or Fahrenheit.
+
+        :param liquid_name: The name of the liquid
+        :param celcius: Whether to return the boiling point in Celsius
+        :return: The boiling point of the liquid in Celsius or Fahrenheit
+ """
+ if liquid_name.lower() == "polyjuice":
+ if celcius:
+ return -100
+ else:
+ return -212
+ else:
+ return -1
+
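+    # The SSE transport serves the event stream at GET /sse and receives client messages at POST /messages/.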
+ sse = SseServerTransport("/messages/")
+
+ async def handle_sse(request):
+ from starlette.exceptions import HTTPException
+
+ auth_header = request.headers.get("Authorization")
+ auth_token = None
+ if auth_header and auth_header.startswith("Bearer "):
+ auth_token = auth_header.split(" ")[1]
+
+ if required_auth_token and auth_token != required_auth_token:
+ raise HTTPException(status_code=401, detail="Unauthorized")
+
+ async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
+ await server._mcp_server.run(
+ streams[0],
+ streams[1],
+ server._mcp_server.create_initialization_options(),
+ )
+ return Response()
+
+ app = Starlette(
+ routes=[
+ Route("/sse", endpoint=handle_sse),
+ Mount("/messages/", app=sse.handle_post_message),
+ ],
+ )
+
+ def get_open_port():
+ import socket
+
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+ sock.bind(("", 0))
+ return sock.getsockname()[1]
+
+ port = get_open_port()
+
+ config = uvicorn.Config(app, host="0.0.0.0", port=port)
+ server_instance = uvicorn.Server(config)
+ app.state.uvicorn_server = server_instance
+
+ def run_server():
+ server_instance.run()
+
+ # Start the server in a new thread
+ server_thread = threading.Thread(target=run_server, daemon=True)
+ server_thread.start()
+
+    # Poll until the server is ready
+ timeout = 10
+ start_time = time.time()
+
+ server_url = f"http://localhost:{port}/sse"
+ while time.time() - start_time < timeout:
+ try:
+ response = httpx.get(server_url)
+ if response.status_code in [200, 401]:
+ break
+ except httpx.RequestError:
+ pass
+ time.sleep(0.1)
+
+ yield {"server_url": server_url}
+
+ # Tell server to exit
+ server_instance.should_exit = True
+ server_thread.join(timeout=5)
diff --git a/tests/verifications/openai_api/fixtures/test_cases/responses.yaml b/tests/verifications/openai_api/fixtures/test_cases/responses.yaml
index 262d82526..a50abef44 100644
--- a/tests/verifications/openai_api/fixtures/test_cases/responses.yaml
+++ b/tests/verifications/openai_api/fixtures/test_cases/responses.yaml
@@ -31,6 +31,20 @@ test_response_web_search:
search_context_size: "low"
output: "128"
+test_response_mcp_tool:
+ test_name: test_response_mcp_tool
+ test_params:
+ case:
+ - case_id: "boiling_point_tool"
+ input: "What is the boiling point of polyjuice?"
+ tools:
+ - type: mcp
+ server_label: "localmcp"
+ server_url: ""
+ headers:
+ Authorization: "Bearer test-token"
+ output: "Hello, world!"
+
test_response_custom_tool:
test_name: test_response_custom_tool
test_params:
diff --git a/tests/verifications/openai_api/test_responses.py b/tests/verifications/openai_api/test_responses.py
index e279b9b38..b1c04e2f3 100644
--- a/tests/verifications/openai_api/test_responses.py
+++ b/tests/verifications/openai_api/test_responses.py
@@ -4,9 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
+import json
import pytest
+from tests.common.mcp import make_mcp_server
from tests.verifications.openai_api.fixtures.fixtures import (
case_id_generator,
get_base_test_name,
@@ -124,6 +126,47 @@ def test_response_non_streaming_web_search(request, openai_client, model, provid
assert case["output"].lower() in response.output_text.lower().strip()
+@pytest.mark.parametrize(
+ "case",
+ responses_test_cases["test_response_mcp_tool"]["test_params"]["case"],
+ ids=case_id_generator,
+)
+def test_response_non_streaming_mcp_tool(request, openai_client, model, provider, verification_config, case):
+ test_name_base = get_base_test_name(request)
+ if should_skip_test(verification_config, provider, model, test_name_base):
+ pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
+
+ with make_mcp_server() as mcp_server_info:
+ tools = case["tools"]
+ for tool in tools:
+ if tool["type"] == "mcp":
+ tool["server_url"] = mcp_server_info["server_url"]
+
+ response = openai_client.responses.create(
+ model=model,
+ input=case["input"],
+ tools=tools,
+ stream=False,
+ )
+ assert len(response.output) >= 3
+ list_tools = response.output[0]
+ assert list_tools.type == "mcp_list_tools"
+ assert list_tools.server_label == "localmcp"
+ assert len(list_tools.tools) == 2
+ assert {t["name"] for t in list_tools.tools} == {"get_boiling_point", "greet_everyone"}
+
+ call = response.output[1]
+ assert call.type == "mcp_call"
+ assert call.name == "get_boiling_point"
+ assert json.loads(call.arguments) == {"liquid_name": "polyjuice", "celcius": True}
+ assert call.error is None
+ assert "-100" in call.output
+
+ message = response.output[2]
+ text_content = message.content[0].text
+ assert "boiling point" in text_content.lower()
+
+
@pytest.mark.parametrize(
"case",
responses_test_cases["test_response_custom_tool"]["test_params"]["case"],