feat: enable MCP execution in Responses implementation

Ashwin Bharambe 2025-05-22 20:21:47 -07:00
parent a411029d7e
commit 5937d94da5
9 changed files with 728 additions and 174 deletions
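
For context, a minimal sketch (not part of this commit) of how a client could exercise the new MCP path through the OpenAI-compatible Responses endpoint; the base URL, API key, model name, and MCP server URL are placeholders, and the tool definition mirrors the verification test added at the bottom of this diff:

# Sketch only: base_url, api_key, model, and server_url are placeholder values.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.responses.create(
    model="meta-llama/Llama-3.3-70B-Instruct",
    input="What is the boiling point of polyjuice?",
    tools=[
        {
            "type": "mcp",
            "server_label": "localmcp",
            "server_url": "http://localhost:8000/sse",
            "headers": {"Authorization": "Bearer test-token"},
        }
    ],
    stream=False,
)

# With this change, the output should contain an mcp_list_tools item,
# an mcp_call item, and a final assistant message, in that order.
for item in response.output:
    print(item.type)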

View file

@ -7218,15 +7218,15 @@
"OpenAIResponseOutputMessageFunctionToolCall": {
"type": "object",
"properties": {
"arguments": {
"type": "string"
},
"call_id": {
"type": "string"
},
"name": {
"type": "string"
},
"arguments": {
"type": "string"
},
"type": {
"type": "string",
"const": "function_call",
@ -7241,12 +7241,10 @@
},
"additionalProperties": false,
"required": [
"arguments",
"call_id",
"name",
"type",
"id",
"status"
"arguments",
"type"
],
"title": "OpenAIResponseOutputMessageFunctionToolCall"
},
@ -7412,6 +7410,12 @@
},
{
"$ref": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall"
},
{
"$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPCall"
},
{
"$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
}
],
"discriminator": {
@ -7419,10 +7423,118 @@
"mapping": {
"message": "#/components/schemas/OpenAIResponseMessage",
"web_search_call": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall",
"function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall"
"function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall",
"mcp_call": "#/components/schemas/OpenAIResponseOutputMessageMCPCall",
"mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
}
}
},
"OpenAIResponseOutputMessageMCPCall": {
"type": "object",
"properties": {
"id": {
"type": "string"
},
"type": {
"type": "string",
"const": "mcp_call",
"default": "mcp_call"
},
"arguments": {
"type": "string"
},
"name": {
"type": "string"
},
"server_label": {
"type": "string"
},
"error": {
"type": "string"
},
"output": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"id",
"type",
"arguments",
"name",
"server_label"
],
"title": "OpenAIResponseOutputMessageMCPCall"
},
"OpenAIResponseOutputMessageMCPListTools": {
"type": "object",
"properties": {
"id": {
"type": "string"
},
"type": {
"type": "string",
"const": "mcp_list_tools",
"default": "mcp_list_tools"
},
"server_label": {
"type": "string"
},
"tools": {
"type": "array",
"items": {
"type": "object",
"properties": {
"input_schema": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
},
"name": {
"type": "string"
},
"description": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"input_schema",
"name"
],
"title": "MCPListToolsTool"
}
}
},
"additionalProperties": false,
"required": [
"id",
"type",
"server_label",
"tools"
],
"title": "OpenAIResponseOutputMessageMCPListTools"
},
"OpenAIResponseObjectStream": {
"oneOf": [
{

View file

@ -5075,12 +5075,12 @@ components:
"OpenAIResponseOutputMessageFunctionToolCall":
type: object
properties:
arguments:
type: string
call_id:
type: string
name:
type: string
arguments:
type: string
type:
type: string
const: function_call
@ -5091,12 +5091,10 @@ components:
type: string
additionalProperties: false
required:
- arguments
- call_id
- name
- arguments
- type
- id
- status
title: >-
OpenAIResponseOutputMessageFunctionToolCall
"OpenAIResponseOutputMessageWebSearchToolCall":
@ -5214,12 +5212,85 @@ components:
- $ref: '#/components/schemas/OpenAIResponseMessage'
- $ref: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
- $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
- $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
- $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
discriminator:
propertyName: type
mapping:
message: '#/components/schemas/OpenAIResponseMessage'
web_search_call: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
function_call: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
mcp_call: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
mcp_list_tools: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
OpenAIResponseOutputMessageMCPCall:
type: object
properties:
id:
type: string
type:
type: string
const: mcp_call
default: mcp_call
arguments:
type: string
name:
type: string
server_label:
type: string
error:
type: string
output:
type: string
additionalProperties: false
required:
- id
- type
- arguments
- name
- server_label
title: OpenAIResponseOutputMessageMCPCall
OpenAIResponseOutputMessageMCPListTools:
type: object
properties:
id:
type: string
type:
type: string
const: mcp_list_tools
default: mcp_list_tools
server_label:
type: string
tools:
type: array
items:
type: object
properties:
input_schema:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
name:
type: string
description:
type: string
additionalProperties: false
required:
- input_schema
- name
title: MCPListToolsTool
additionalProperties: false
required:
- id
- type
- server_label
- tools
title: OpenAIResponseOutputMessageMCPListTools
OpenAIResponseObjectStream:
oneOf:
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'

View file

@ -10,6 +10,9 @@ from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type, register_schema
# NOTE(ashwin): this file is literally a copy of the OpenAI responses API schema. We should probably
# take their YAML and generate this file automatically. Their YAML is available.
@json_schema_type
class OpenAIResponseError(BaseModel):
@ -79,16 +82,45 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
@json_schema_type
class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
arguments: str
call_id: str
name: str
arguments: str
type: Literal["function_call"] = "function_call"
id: str | None = None
status: str | None = None
@json_schema_type
class OpenAIResponseOutputMessageMCPCall(BaseModel):
id: str
type: Literal["mcp_call"] = "mcp_call"
arguments: str
name: str
server_label: str
error: str | None = None
output: str | None = None
class MCPListToolsTool(BaseModel):
input_schema: dict[str, Any]
name: str
description: str | None = None
@json_schema_type
class OpenAIResponseOutputMessageMCPListTools(BaseModel):
id: str
type: Literal["mcp_list_tools"] = "mcp_list_tools"
server_label: str
tools: list[MCPListToolsTool]
OpenAIResponseOutput = Annotated[
OpenAIResponseMessage | OpenAIResponseOutputMessageWebSearchToolCall | OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseMessage
| OpenAIResponseOutputMessageWebSearchToolCall
| OpenAIResponseOutputMessageFunctionToolCall
| OpenAIResponseOutputMessageMCPCall
| OpenAIResponseOutputMessageMCPListTools,
Field(discriminator="type"),
]
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
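
To illustrate (again, not part of the diff) how the new output item types defined above compose, a hypothetical construction with made-up ids and values:

# Illustrative only: ids, schema, and values are made up.
list_tools_item = OpenAIResponseOutputMessageMCPListTools(
    id="mcp_list_123",
    server_label="localmcp",
    tools=[
        MCPListToolsTool(
            name="get_boiling_point",
            description="Returns the boiling point of a liquid.",
            input_schema={
                "type": "object",
                "properties": {"liquid_name": {"type": "string"}},
                "required": ["liquid_name"],
            },
        )
    ],
)

mcp_call_item = OpenAIResponseOutputMessageMCPCall(
    id="mcp_call_456",
    arguments='{"liquid_name": "polyjuice"}',
    name="get_boiling_point",
    server_label="localmcp",
    output="-100",
)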

View file

@ -14,6 +14,7 @@ from pydantic import BaseModel
from llama_stack.apis.agents import Order
from llama_stack.apis.agents.openai_responses import (
AllowedToolsFilter,
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIResponseInput,
@ -22,7 +23,7 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolMCP,
OpenAIResponseMessage,
OpenAIResponseObject,
OpenAIResponseObjectStream,
@ -51,11 +52,12 @@ from llama_stack.apis.inference.inference import (
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.apis.tools.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.tools.tools import ToolGroups, ToolRuntime
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools
logger = get_logger(name=__name__, category="openai_responses")
@ -168,6 +170,15 @@ class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
response: OpenAIResponseObject
class ChatCompletionContext(BaseModel):
model: str
messages: list[OpenAIMessageParam]
tools: list[ChatCompletionToolParam] | None = None
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
stream: bool
temperature: float | None
class OpenAIResponsesImpl:
def __init__(
self,
@ -255,13 +266,32 @@ class OpenAIResponsesImpl:
temperature: float | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
):
output_messages: list[OpenAIResponseOutput] = []
stream = False if stream is None else stream
# Huge TODO: we need to run this in a loop, until morale improves
# Create context to run "chat completion"
input = await self._prepend_previous_response(input, previous_response_id)
messages = await _convert_response_input_to_chat_messages(input)
await self._prepend_instructions(messages, instructions)
chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None
chat_tools, mcp_tool_to_server, mcp_list_message = (
await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
)
if mcp_list_message:
output_messages.append(mcp_list_message)
ctx = ChatCompletionContext(
model=model,
messages=messages,
tools=chat_tools,
mcp_tool_to_server=mcp_tool_to_server,
stream=stream,
temperature=temperature,
)
# Run inference
chat_response = await self.inference_api.openai_chat_completion(
model=model,
messages=messages,
@ -270,6 +300,7 @@ class OpenAIResponsesImpl:
temperature=temperature,
)
# Collect output
if stream:
# TODO: refactor this into a separate method that handles streaming
chat_response_id = ""
@ -328,11 +359,11 @@ class OpenAIResponsesImpl:
# dump and reload to map to our pydantic types
chat_response = OpenAIChatCompletion(**chat_response.model_dump())
output_messages: list[OpenAIResponseOutput] = []
# Execute tool calls if any
for choice in chat_response.choices:
if choice.message.tool_calls and tools:
# Assume if the first tool is a function, all tools are functions
if isinstance(tools[0], OpenAIResponseInputToolFunction):
if tools[0].type == "function":
for tool_call in choice.message.tool_calls:
output_messages.append(
OpenAIResponseOutputMessageFunctionToolCall(
@ -344,11 +375,12 @@ class OpenAIResponsesImpl:
)
)
else:
output_messages.extend(
await self._execute_tool_and_return_final_output(model, stream, choice, messages, temperature)
)
tool_messages = await self._execute_tool_and_return_final_output(choice, ctx)
output_messages.extend(tool_messages)
else:
output_messages.append(await _convert_chat_choice_to_response_message(choice))
# Create response object
response = OpenAIResponseObject(
created_at=chat_response.created,
id=f"resp-{uuid.uuid4()}",
@ -359,9 +391,8 @@ class OpenAIResponsesImpl:
)
logger.debug(f"OpenAI Responses response: {response}")
# Store response if requested
if store:
# Store in kvstore
new_input_id = f"msg_{uuid.uuid4()}"
if isinstance(input, str):
# synthesize a message from the input string
@ -403,7 +434,36 @@ class OpenAIResponsesImpl:
async def _convert_response_tools_to_chat_tools(
self, tools: list[OpenAIResponseInputTool]
) -> list[ChatCompletionToolParam]:
) -> tuple[
list[ChatCompletionToolParam],
dict[str, OpenAIResponseInputToolMCP],
OpenAIResponseOutput | None,
]:
from llama_stack.apis.agents.openai_responses import (
MCPListToolsTool,
OpenAIResponseOutputMessageMCPListTools,
)
from llama_stack.apis.tools.tools import Tool
mcp_tool_to_server = {}
def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
tool_def = ToolDefinition(
tool_name=tool_name,
description=tool.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool.parameters
},
)
return convert_tooldef_to_openai_tool(tool_def)
mcp_list_message = None
chat_tools: list[ChatCompletionToolParam] = []
for input_tool in tools:
# TODO: Handle other tool types
@ -412,91 +472,95 @@ class OpenAIResponsesImpl:
elif input_tool.type == "web_search":
tool_name = "web_search"
tool = await self.tool_groups_api.get_tool(tool_name)
tool_def = ToolDefinition(
tool_name=tool_name,
description=tool.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool.parameters
},
if not tool:
raise ValueError(f"Tool {tool_name} not found")
chat_tools.append(make_openai_tool(tool_name, tool))
elif input_tool.type == "mcp":
always_allowed = None
never_allowed = None
if input_tool.allowed_tools:
if isinstance(input_tool.allowed_tools, list):
always_allowed = input_tool.allowed_tools
elif isinstance(input_tool.allowed_tools, AllowedToolsFilter):
always_allowed = input_tool.allowed_tools.always
never_allowed = input_tool.allowed_tools.never
tool_defs = await list_mcp_tools(
endpoint=input_tool.server_url,
headers=convert_header_list_to_dict(input_tool.headers or []),
)
chat_tool = convert_tooldef_to_openai_tool(tool_def)
chat_tools.append(chat_tool)
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
id=f"mcp_list_{uuid.uuid4()}",
status="completed",
server_label=input_tool.server_label,
tools=[],
)
for t in tool_defs.data:
if never_allowed and t.name in never_allowed:
continue
if not always_allowed or t.name in always_allowed:
chat_tools.append(make_openai_tool(t.name, t))
if t.name in mcp_tool_to_server:
raise ValueError(f"Duplicate tool name {t.name} found for server {input_tool.server_label}")
mcp_tool_to_server[t.name] = input_tool
mcp_list_message.tools.append(
MCPListToolsTool(
name=t.name,
description=t.description,
input_schema={
"type": "object",
"properties": {
p.name: {
"type": p.parameter_type,
"description": p.description,
}
for p in t.parameters
},
"required": [p.name for p in t.parameters if p.required],
},
)
)
else:
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
return chat_tools
return chat_tools, mcp_tool_to_server, mcp_list_message
async def _execute_tool_and_return_final_output(
self,
model_id: str,
stream: bool,
choice: OpenAIChoice,
messages: list[OpenAIMessageParam],
temperature: float,
ctx: ChatCompletionContext,
) -> list[OpenAIResponseOutput]:
output_messages: list[OpenAIResponseOutput] = []
# If the choice is not an assistant message, we don't need to execute any tools
if not isinstance(choice.message, OpenAIAssistantMessageParam):
return output_messages
# If the assistant message doesn't have any tool calls, we don't need to execute any tools
if not choice.message.tool_calls:
return output_messages
# Copy the messages list to avoid mutating the original list
messages = messages.copy()
next_turn_messages = ctx.messages.copy()
# Add the assistant message with tool_calls response to the messages list
messages.append(choice.message)
next_turn_messages.append(choice.message)
for tool_call in choice.message.tool_calls:
tool_call_id = tool_call.id
function = tool_call.function
# If for some reason the tool call doesn't have a function or id, we can't execute it
if not function or not tool_call_id:
continue
# TODO: telemetry spans for tool calls
result = await self._execute_tool_call(function)
# Handle tool call failure
if not result:
output_messages.append(
OpenAIResponseOutputMessageWebSearchToolCall(
id=tool_call_id,
status="failed",
)
)
continue
output_messages.append(
OpenAIResponseOutputMessageWebSearchToolCall(
id=tool_call_id,
status="completed",
),
)
result_content = ""
# TODO: handle other result content types and lists
if isinstance(result.content, str):
result_content = result.content
messages.append(OpenAIToolMessageParam(content=result_content, tool_call_id=tool_call_id))
tool_call_log, further_input = await self._execute_tool_call(tool_call, ctx)
if tool_call_log:
output_messages.append(tool_call_log)
if further_input:
next_turn_messages.append(further_input)
tool_results_chat_response = await self.inference_api.openai_chat_completion(
model=model_id,
messages=messages,
stream=stream,
temperature=temperature,
model=ctx.model,
messages=next_turn_messages,
stream=ctx.stream,
temperature=ctx.temperature,
)
# type cast to appease mypy
# type cast to appease mypy: this is needed because we don't handle streaming properly :)
tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
# Huge TODO: these are NOT the final outputs, we must keep the loop going
tool_final_outputs = [
await _convert_chat_choice_to_response_message(choice) for choice in tool_results_chat_response.choices
]
@ -506,15 +570,85 @@ class OpenAIResponsesImpl:
async def _execute_tool_call(
self,
function: OpenAIChatCompletionToolCallFunction,
) -> ToolInvocationResult | None:
if not function.name:
return None
function_args = json.loads(function.arguments) if function.arguments else {}
logger.info(f"executing tool call: {function.name} with args: {function_args}")
result = await self.tool_runtime_api.invoke_tool(
tool_name=function.name,
kwargs=function_args,
tool_call: OpenAIChatCompletionToolCall,
ctx: ChatCompletionContext,
) -> tuple[OpenAIResponseOutput | None, OpenAIMessageParam | None]:
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
logger.debug(f"tool call {function.name} completed with result: {result}")
return result
tool_call_id = tool_call.id
function = tool_call.function
if not function or not tool_call_id or not function.name:
return None, None
error_exc = None
try:
if function.name in ctx.mcp_tool_to_server:
mcp_tool = ctx.mcp_tool_to_server[function.name]
result = await invoke_mcp_tool(
endpoint=mcp_tool.server_url,
headers=convert_header_list_to_dict(mcp_tool.headers or []),
tool_name=function.name,
kwargs=json.loads(function.arguments) if function.arguments else {},
)
else:
result = await self.tool_runtime_api.invoke_tool(
tool_name=function.name,
kwargs=json.loads(function.arguments) if function.arguments else {},
)
except Exception as e:
error_exc = e
if function.name in ctx.mcp_tool_to_server:
from llama_stack.apis.agents.openai_responses import OpenAIResponseOutputMessageMCPCall
message = OpenAIResponseOutputMessageMCPCall(
id=tool_call_id,
arguments=function.arguments,
name=function.name,
server_label=ctx.mcp_tool_to_server[function.name].server_label,
)
if error_exc:
message.error = str(error_exc)
elif (result.error_code and result.error_code > 0) or result.error_message:
message.error = f"Error (code {result.error_code}): {result.error_message}"
elif result.content:
message.output = interleaved_content_as_str(result.content)
else:
if function.name == "web_search":
message = OpenAIResponseOutputMessageWebSearchToolCall(
id=tool_call_id,
status="completed",
)
if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
message.status = "failed"
else:
raise ValueError(f"Unknown tool {function.name} called")
input_message = None
if result.content:
if isinstance(result.content, str):
content = result.content
elif isinstance(result.content, list):
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
content = []
for item in result.content:
if isinstance(item, TextContentItem):
part = OpenAIChatCompletionContentPartTextParam(text=item.text)
elif isinstance(item, ImageContentItem):
if item.image.data:
url = f"data:image;base64,{item.image.data}"
else:
url = item.image.url
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
else:
raise ValueError(f"Unknown result content type: {type(item)}")
content.append(part)
else:
raise ValueError(f"Unknown result content type: {type(result.content)}")
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
return message, input_message

View file

@ -4,53 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import urlparse
import exceptiongroup
import httpx
from mcp import ClientSession
from mcp import types as mcp_types
from mcp.client.sse import sse_client
from llama_stack.apis.common.content_types import URL, ImageContentItem, TextContentItem
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.datatypes import Api
from llama_stack.apis.tools import (
ListToolDefsResponse,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.datatypes import AuthenticationRequiredError
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools
from .config import MCPProviderConfig
logger = get_logger(__name__, category="tools")
@asynccontextmanager
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
try:
async with sse_client(endpoint, headers=headers) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
yield session
except BaseException as e:
if isinstance(e, exceptiongroup.BaseExceptionGroup):
for exc in e.exceptions:
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401:
raise AuthenticationRequiredError(exc) from exc
elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401:
raise AuthenticationRequiredError(e) from e
raise
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
self.config = config
@ -64,32 +37,8 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, Nee
# this endpoint should be retrieved by getting the tool group right?
if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required")
headers = await self.get_headers_from_request(mcp_endpoint.uri)
tools = []
async with sse_client_wrapper(mcp_endpoint.uri, headers) as session:
tools_result = await session.list_tools()
for tool in tools_result.tools:
parameters = []
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
parameters.append(
ToolParameter(
name=param_name,
parameter_type=param_schema.get("type", "string"),
description=param_schema.get("description", ""),
)
)
tools.append(
ToolDef(
name=tool.name,
description=tool.description,
parameters=parameters,
metadata={
"endpoint": mcp_endpoint.uri,
},
)
)
return ListToolDefsResponse(data=tools)
return await list_mcp_tools(mcp_endpoint.uri, headers)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name)
@ -100,23 +49,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, Nee
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
headers = await self.get_headers_from_request(endpoint)
async with sse_client_wrapper(endpoint, headers) as session:
result = await session.call_tool(tool.identifier, kwargs)
content = []
for item in result.content:
if isinstance(item, mcp_types.TextContent):
content.append(TextContentItem(text=item.text))
elif isinstance(item, mcp_types.ImageContent):
content.append(ImageContentItem(image=item.data))
elif isinstance(item, mcp_types.EmbeddedResource):
logger.warning(f"EmbeddedResource is not supported: {item}")
else:
raise ValueError(f"Unknown content type: {type(item)}")
return ToolInvocationResult(
content=content,
error_code=1 if result.isError else 0,
)
return await invoke_mcp_tool(endpoint, headers, tool_name, kwargs)
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
def canonicalize_uri(uri: str) -> str:
@ -129,9 +62,5 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, Nee
for uri, values in provider_data.mcp_headers.items():
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
continue
for entry in values:
parts = entry.split(":")
if len(parts) == 2:
k, v = parts
headers[k.strip()] = v.strip()
headers.update(convert_header_list_to_dict(values))
return headers

View file

@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from contextlib import asynccontextmanager
from typing import Any
import exceptiongroup
import httpx
from mcp import ClientSession
from mcp import types as mcp_types
from mcp.client.sse import sse_client
from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem
from llama_stack.apis.tools import (
ListToolDefsResponse,
ToolDef,
ToolInvocationResult,
ToolParameter,
)
from llama_stack.distribution.datatypes import AuthenticationRequiredError
from llama_stack.log import get_logger
logger = get_logger(__name__, category="tools")
@asynccontextmanager
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
try:
async with sse_client(endpoint, headers=headers) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
yield session
except BaseException as e:
if isinstance(e, exceptiongroup.BaseExceptionGroup):
for exc in e.exceptions:
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401:
raise AuthenticationRequiredError(exc) from exc
elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401:
raise AuthenticationRequiredError(e) from e
raise
def convert_header_list_to_dict(header_list: list[str]) -> dict[str, str]:
headers = {}
for header in header_list:
parts = header.split(":")
if len(parts) == 2:
k, v = parts
headers[k.strip()] = v.strip()
return headers
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
tools = []
async with sse_client_wrapper(endpoint, headers) as session:
tools_result = await session.list_tools()
for tool in tools_result.tools:
parameters = []
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
parameters.append(
ToolParameter(
name=param_name,
parameter_type=param_schema.get("type", "string"),
description=param_schema.get("description", ""),
)
)
tools.append(
ToolDef(
name=tool.name,
description=tool.description,
parameters=parameters,
metadata={
"endpoint": endpoint,
},
)
)
return ListToolDefsResponse(data=tools)
async def invoke_mcp_tool(
endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any]
) -> ToolInvocationResult:
async with sse_client_wrapper(endpoint, headers) as session:
result = await session.call_tool(tool_name, kwargs)
content: list[InterleavedContentItem] = []
for item in result.content:
if isinstance(item, mcp_types.TextContent):
content.append(TextContentItem(text=item.text))
elif isinstance(item, mcp_types.ImageContent):
content.append(ImageContentItem(image=item.data))
elif isinstance(item, mcp_types.EmbeddedResource):
logger.warning(f"EmbeddedResource is not supported: {item}")
else:
raise ValueError(f"Unknown content type: {type(item)}")
return ToolInvocationResult(
content=content,
error_code=1 if result.isError else 0,
)
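
A minimal usage sketch for these new shared helpers (not part of the diff; the endpoint, header, tool name, and arguments below are placeholders):

# Sketch only: endpoint, header, tool name, and arguments are placeholders.
import asyncio

from llama_stack.providers.utils.tools.mcp import (
    convert_header_list_to_dict,
    invoke_mcp_tool,
    list_mcp_tools,
)


async def main() -> None:
    endpoint = "http://localhost:8000/sse"
    headers = convert_header_list_to_dict(["Authorization: Bearer test-token"])

    # Discover the tools exposed by the MCP server.
    tool_defs = await list_mcp_tools(endpoint, headers)
    for tool_def in tool_defs.data:
        print(tool_def.name, tool_def.description)

    # Invoke one of them; a 401 from the server surfaces as AuthenticationRequiredError.
    result = await invoke_mcp_tool(endpoint, headers, "get_boiling_point", {"liquid_name": "polyjuice"})
    print(result.error_code, result.content)


asyncio.run(main())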

tests/common/mcp.py (new file, 116 lines)
View file

@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# we want the mcp server to be authenticated OR not, depends
from contextlib import contextmanager
@contextmanager
def make_mcp_server(required_auth_token: str | None = None):
import threading
import time
import httpx
import uvicorn
from mcp import types
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.responses import Response
from starlette.routing import Mount, Route
server = FastMCP("FastMCP Test Server")
@server.tool()
async def greet_everyone(
url: str, ctx: Context
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
return [types.TextContent(type="text", text="Hello, world!")]
@server.tool()
async def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
"""
Returns the boiling point of a liquid in Celcius or Fahrenheit.
:param liquid_name: The name of the liquid
:param celcius: Whether to return the boiling point in Celcius
:return: The boiling point of the liquid in Celcius or Fahrenheit
"""
if liquid_name.lower() == "polyjuice":
if celcius:
return -100
else:
return -212
else:
return -1
sse = SseServerTransport("/messages/")
async def handle_sse(request):
from starlette.exceptions import HTTPException
auth_header = request.headers.get("Authorization")
auth_token = None
if auth_header and auth_header.startswith("Bearer "):
auth_token = auth_header.split(" ")[1]
if required_auth_token and auth_token != required_auth_token:
raise HTTPException(status_code=401, detail="Unauthorized")
async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
await server._mcp_server.run(
streams[0],
streams[1],
server._mcp_server.create_initialization_options(),
)
return Response()
app = Starlette(
routes=[
Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse.handle_post_message),
],
)
def get_open_port():
import socket
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("", 0))
return sock.getsockname()[1]
port = get_open_port()
config = uvicorn.Config(app, host="0.0.0.0", port=port)
server_instance = uvicorn.Server(config)
app.state.uvicorn_server = server_instance
def run_server():
server_instance.run()
# Start the server in a new thread
server_thread = threading.Thread(target=run_server, daemon=True)
server_thread.start()
# Polling until the server is ready
timeout = 10
start_time = time.time()
server_url = f"http://localhost:{port}/sse"
while time.time() - start_time < timeout:
try:
response = httpx.get(server_url)
if response.status_code in [200, 401]:
break
except httpx.RequestError:
pass
time.sleep(0.1)
yield {"server_url": server_url}
# Tell server to exit
server_instance.should_exit = True
server_thread.join(timeout=5)

View file

@ -31,6 +31,20 @@ test_response_web_search:
search_context_size: "low"
output: "128"
test_response_mcp_tool:
test_name: test_response_mcp_tool
test_params:
case:
- case_id: "boiling_point_tool"
input: "What is the boiling point of polyjuice?"
tools:
- type: mcp
server_label: "localmcp"
server_url: "<FILLED_BY_TEST_RUNNER>"
headers:
Authorization: "Bearer test-token"
output: "Hello, world!"
test_response_custom_tool:
test_name: test_response_custom_tool
test_params:

View file

@ -4,9 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import pytest
from tests.common.mcp import make_mcp_server
from tests.verifications.openai_api.fixtures.fixtures import (
case_id_generator,
get_base_test_name,
@ -124,6 +126,47 @@ def test_response_non_streaming_web_search(request, openai_client, model, provid
assert case["output"].lower() in response.output_text.lower().strip()
@pytest.mark.parametrize(
"case",
responses_test_cases["test_response_mcp_tool"]["test_params"]["case"],
ids=case_id_generator,
)
def test_response_non_streaming_mcp_tool(request, openai_client, model, provider, verification_config, case):
test_name_base = get_base_test_name(request)
if should_skip_test(verification_config, provider, model, test_name_base):
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
with make_mcp_server() as mcp_server_info:
tools = case["tools"]
for tool in tools:
if tool["type"] == "mcp":
tool["server_url"] = mcp_server_info["server_url"]
response = openai_client.responses.create(
model=model,
input=case["input"],
tools=tools,
stream=False,
)
assert len(response.output) >= 3
list_tools = response.output[0]
assert list_tools.type == "mcp_list_tools"
assert list_tools.server_label == "localmcp"
assert len(list_tools.tools) == 2
assert {t["name"] for t in list_tools.tools} == {"get_boiling_point", "greet_everyone"}
call = response.output[1]
assert call.type == "mcp_call"
assert call.name == "get_boiling_point"
assert json.loads(call.arguments) == {"liquid_name": "polyjuice", "celcius": True}
assert call.error is None
assert "-100" in call.output
message = response.output[2]
text_content = message.content[0].text
assert "boiling point" in text_content.lower()
@pytest.mark.parametrize(
"case",
responses_test_cases["test_response_custom_tool"]["test_params"]["case"],