feat: enable MCP execution in Responses impl (#2240)

## Test Plan

```
pytest -s -v 'tests/verifications/openai_api/test_responses.py' \
  --provider=stack:together --model meta-llama/Llama-4-Scout-17B-16E-Instruct
```
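With this change, the Responses implementation can list and execute MCP tools end to end. As a rough client-side sketch of what the new verification test below exercises (the base URL, model, and MCP server URL are placeholders, not part of this diff):

```python
# Hypothetical usage sketch; base_url, model, and server_url are assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.responses.create(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    input="What is the boiling point of polyjuice?",
    tools=[
        {
            "type": "mcp",
            "server_label": "localmcp",
            "server_url": "http://localhost:8000/sse",
            # "headers": {"Authorization": "Bearer test-token"},  # if the server requires auth
        }
    ],
    stream=False,
)

# Expected output items: an mcp_list_tools entry, one mcp_call per tool
# invocation, and the final assistant message.
for item in response.output:
    print(item.type)
```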
Author: Ashwin Bharambe
Date: 2025-05-24 14:20:42 -07:00 (committed via GitHub)
Commit: 3faf1e4a79 (parent: 66f09f24ed)
15 changed files with 865 additions and 382 deletions


@ -24,7 +24,7 @@ jobs:
matrix:
# Listing tests manually since some of them currently fail
# TODO: generate matrix list from tests/integration when fixed
test-type: [agents, inference, datasets, inspect, scoring, post_training, providers]
test-type: [agents, inference, datasets, inspect, scoring, post_training, providers, tool_runtime]
client-type: [library, http]
fail-fast: false # we want to run all tests regardless of failure
@ -90,7 +90,7 @@ jobs:
else
stack_config="http://localhost:8321"
fi
uv run pytest -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \
uv run pytest -s -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \
-k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \
--text-model="meta-llama/Llama-3.2-3B-Instruct" \
--embedding-model=all-MiniLM-L6-v2

.gitignore (vendored, 1 change)

@ -6,6 +6,7 @@ dev_requirements.txt
build
.DS_Store
llama_stack/configs/*
.cursor/
xcuserdata/
*.hmap
.DS_Store


@ -7218,15 +7218,15 @@
"OpenAIResponseOutputMessageFunctionToolCall": {
"type": "object",
"properties": {
"arguments": {
"type": "string"
},
"call_id": {
"type": "string"
},
"name": {
"type": "string"
},
"arguments": {
"type": "string"
},
"type": {
"type": "string",
"const": "function_call",
@ -7241,12 +7241,10 @@
},
"additionalProperties": false,
"required": [
"arguments",
"call_id",
"name",
"type",
"id",
"status"
"arguments",
"type"
],
"title": "OpenAIResponseOutputMessageFunctionToolCall"
},
@ -7412,6 +7410,12 @@
},
{
"$ref": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall"
},
{
"$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPCall"
},
{
"$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
}
],
"discriminator": {
@ -7419,10 +7423,118 @@
"mapping": {
"message": "#/components/schemas/OpenAIResponseMessage",
"web_search_call": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall",
"function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall"
"function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall",
"mcp_call": "#/components/schemas/OpenAIResponseOutputMessageMCPCall",
"mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
}
}
},
"OpenAIResponseOutputMessageMCPCall": {
"type": "object",
"properties": {
"id": {
"type": "string"
},
"type": {
"type": "string",
"const": "mcp_call",
"default": "mcp_call"
},
"arguments": {
"type": "string"
},
"name": {
"type": "string"
},
"server_label": {
"type": "string"
},
"error": {
"type": "string"
},
"output": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"id",
"type",
"arguments",
"name",
"server_label"
],
"title": "OpenAIResponseOutputMessageMCPCall"
},
"OpenAIResponseOutputMessageMCPListTools": {
"type": "object",
"properties": {
"id": {
"type": "string"
},
"type": {
"type": "string",
"const": "mcp_list_tools",
"default": "mcp_list_tools"
},
"server_label": {
"type": "string"
},
"tools": {
"type": "array",
"items": {
"type": "object",
"properties": {
"input_schema": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
},
"name": {
"type": "string"
},
"description": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"input_schema",
"name"
],
"title": "MCPListToolsTool"
}
}
},
"additionalProperties": false,
"required": [
"id",
"type",
"server_label",
"tools"
],
"title": "OpenAIResponseOutputMessageMCPListTools"
},
"OpenAIResponseObjectStream": {
"oneOf": [
{


@ -5075,12 +5075,12 @@ components:
"OpenAIResponseOutputMessageFunctionToolCall":
type: object
properties:
arguments:
type: string
call_id:
type: string
name:
type: string
arguments:
type: string
type:
type: string
const: function_call
@ -5091,12 +5091,10 @@ components:
type: string
additionalProperties: false
required:
- arguments
- call_id
- name
- arguments
- type
- id
- status
title: >-
OpenAIResponseOutputMessageFunctionToolCall
"OpenAIResponseOutputMessageWebSearchToolCall":
@ -5214,12 +5212,85 @@ components:
- $ref: '#/components/schemas/OpenAIResponseMessage'
- $ref: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
- $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
- $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
- $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
discriminator:
propertyName: type
mapping:
message: '#/components/schemas/OpenAIResponseMessage'
web_search_call: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
function_call: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
mcp_call: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
mcp_list_tools: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
OpenAIResponseOutputMessageMCPCall:
type: object
properties:
id:
type: string
type:
type: string
const: mcp_call
default: mcp_call
arguments:
type: string
name:
type: string
server_label:
type: string
error:
type: string
output:
type: string
additionalProperties: false
required:
- id
- type
- arguments
- name
- server_label
title: OpenAIResponseOutputMessageMCPCall
OpenAIResponseOutputMessageMCPListTools:
type: object
properties:
id:
type: string
type:
type: string
const: mcp_list_tools
default: mcp_list_tools
server_label:
type: string
tools:
type: array
items:
type: object
properties:
input_schema:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
name:
type: string
description:
type: string
additionalProperties: false
required:
- input_schema
- name
title: MCPListToolsTool
additionalProperties: false
required:
- id
- type
- server_label
- tools
title: OpenAIResponseOutputMessageMCPListTools
OpenAIResponseObjectStream:
oneOf:
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'


@ -10,6 +10,9 @@ from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type, register_schema
# NOTE(ashwin): this file is literally a copy of the OpenAI responses API schema. We should probably
# take their YAML and generate this file automatically. Their YAML is available.
@json_schema_type
class OpenAIResponseError(BaseModel):
@ -79,16 +82,45 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
@json_schema_type
class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
arguments: str
call_id: str
name: str
arguments: str
type: Literal["function_call"] = "function_call"
id: str | None = None
status: str | None = None
@json_schema_type
class OpenAIResponseOutputMessageMCPCall(BaseModel):
id: str
type: Literal["mcp_call"] = "mcp_call"
arguments: str
name: str
server_label: str
error: str | None = None
output: str | None = None
class MCPListToolsTool(BaseModel):
input_schema: dict[str, Any]
name: str
description: str | None = None
@json_schema_type
class OpenAIResponseOutputMessageMCPListTools(BaseModel):
id: str
type: Literal["mcp_list_tools"] = "mcp_list_tools"
server_label: str
tools: list[MCPListToolsTool]
OpenAIResponseOutput = Annotated[
OpenAIResponseMessage | OpenAIResponseOutputMessageWebSearchToolCall | OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseMessage
| OpenAIResponseOutputMessageWebSearchToolCall
| OpenAIResponseOutputMessageFunctionToolCall
| OpenAIResponseOutputMessageMCPCall
| OpenAIResponseOutputMessageMCPListTools,
Field(discriminator="type"),
]
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
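For reference, a small construction sketch of the new output item models defined above; the identifiers and values below are invented, and the shapes follow the generated JSON schema earlier in this diff:

```python
# Illustrative only; ids and field values are made up.
from llama_stack.apis.agents.openai_responses import (
    MCPListToolsTool,
    OpenAIResponseOutputMessageMCPCall,
    OpenAIResponseOutputMessageMCPListTools,
)

list_tools = OpenAIResponseOutputMessageMCPListTools(
    id="mcp_list_123",
    server_label="localmcp",
    tools=[
        MCPListToolsTool(
            name="get_boiling_point",
            description="Returns the boiling point of a liquid.",
            input_schema={
                "type": "object",
                "properties": {"liquid_name": {"type": "string"}},
                "required": ["liquid_name"],
            },
        )
    ],
)

call = OpenAIResponseOutputMessageMCPCall(
    id="mcp_call_456",
    arguments='{"liquid_name": "polyjuice"}',
    name="get_boiling_point",
    server_label="localmcp",
    output="-100",
)
```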


@ -14,6 +14,7 @@ from pydantic import BaseModel
from llama_stack.apis.agents import Order
from llama_stack.apis.agents.openai_responses import (
AllowedToolsFilter,
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIResponseInput,
@ -22,7 +23,7 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolMCP,
OpenAIResponseMessage,
OpenAIResponseObject,
OpenAIResponseObjectStream,
@ -51,11 +52,12 @@ from llama_stack.apis.inference.inference import (
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.apis.tools.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.tools.tools import ToolGroups, ToolRuntime
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools
logger = get_logger(name=__name__, category="openai_responses")
@ -168,6 +170,15 @@ class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
response: OpenAIResponseObject
class ChatCompletionContext(BaseModel):
model: str
messages: list[OpenAIMessageParam]
tools: list[ChatCompletionToolParam] | None = None
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
stream: bool
temperature: float | None
class OpenAIResponsesImpl:
def __init__(
self,
@ -255,13 +266,32 @@ class OpenAIResponsesImpl:
temperature: float | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
):
output_messages: list[OpenAIResponseOutput] = []
stream = False if stream is None else stream
# Huge TODO: we need to run this in a loop, until morale improves
# Create context to run "chat completion"
input = await self._prepend_previous_response(input, previous_response_id)
messages = await _convert_response_input_to_chat_messages(input)
await self._prepend_instructions(messages, instructions)
chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None
chat_tools, mcp_tool_to_server, mcp_list_message = (
await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
)
if mcp_list_message:
output_messages.append(mcp_list_message)
ctx = ChatCompletionContext(
model=model,
messages=messages,
tools=chat_tools,
mcp_tool_to_server=mcp_tool_to_server,
stream=stream,
temperature=temperature,
)
# Run inference
chat_response = await self.inference_api.openai_chat_completion(
model=model,
messages=messages,
@ -270,6 +300,7 @@ class OpenAIResponsesImpl:
temperature=temperature,
)
# Collect output
if stream:
# TODO: refactor this into a separate method that handles streaming
chat_response_id = ""
@ -328,11 +359,11 @@ class OpenAIResponsesImpl:
# dump and reload to map to our pydantic types
chat_response = OpenAIChatCompletion(**chat_response.model_dump())
output_messages: list[OpenAIResponseOutput] = []
# Execute tool calls if any
for choice in chat_response.choices:
if choice.message.tool_calls and tools:
# Assume if the first tool is a function, all tools are functions
if isinstance(tools[0], OpenAIResponseInputToolFunction):
if tools[0].type == "function":
for tool_call in choice.message.tool_calls:
output_messages.append(
OpenAIResponseOutputMessageFunctionToolCall(
@ -344,11 +375,12 @@ class OpenAIResponsesImpl:
)
)
else:
output_messages.extend(
await self._execute_tool_and_return_final_output(model, stream, choice, messages, temperature)
)
tool_messages = await self._execute_tool_and_return_final_output(choice, ctx)
output_messages.extend(tool_messages)
else:
output_messages.append(await _convert_chat_choice_to_response_message(choice))
# Create response object
response = OpenAIResponseObject(
created_at=chat_response.created,
id=f"resp-{uuid.uuid4()}",
@ -359,9 +391,8 @@ class OpenAIResponsesImpl:
)
logger.debug(f"OpenAI Responses response: {response}")
# Store response if requested
if store:
# Store in kvstore
new_input_id = f"msg_{uuid.uuid4()}"
if isinstance(input, str):
# synthesize a message from the input string
@ -403,15 +434,20 @@ class OpenAIResponsesImpl:
async def _convert_response_tools_to_chat_tools(
self, tools: list[OpenAIResponseInputTool]
) -> list[ChatCompletionToolParam]:
chat_tools: list[ChatCompletionToolParam] = []
for input_tool in tools:
# TODO: Handle other tool types
if input_tool.type == "function":
chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
elif input_tool.type == "web_search":
tool_name = "web_search"
tool = await self.tool_groups_api.get_tool(tool_name)
) -> tuple[
list[ChatCompletionToolParam],
dict[str, OpenAIResponseInputToolMCP],
OpenAIResponseOutput | None,
]:
from llama_stack.apis.agents.openai_responses import (
MCPListToolsTool,
OpenAIResponseOutputMessageMCPListTools,
)
from llama_stack.apis.tools.tools import Tool
mcp_tool_to_server = {}
def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
tool_def = ToolDefinition(
tool_name=tool_name,
description=tool.description,
@ -425,78 +461,106 @@ class OpenAIResponsesImpl:
for param in tool.parameters
},
)
chat_tool = convert_tooldef_to_openai_tool(tool_def)
chat_tools.append(chat_tool)
return convert_tooldef_to_openai_tool(tool_def)
mcp_list_message = None
chat_tools: list[ChatCompletionToolParam] = []
for input_tool in tools:
# TODO: Handle other tool types
if input_tool.type == "function":
chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
elif input_tool.type == "web_search":
tool_name = "web_search"
tool = await self.tool_groups_api.get_tool(tool_name)
if not tool:
raise ValueError(f"Tool {tool_name} not found")
chat_tools.append(make_openai_tool(tool_name, tool))
elif input_tool.type == "mcp":
always_allowed = None
never_allowed = None
if input_tool.allowed_tools:
if isinstance(input_tool.allowed_tools, list):
always_allowed = input_tool.allowed_tools
elif isinstance(input_tool.allowed_tools, AllowedToolsFilter):
always_allowed = input_tool.allowed_tools.always
never_allowed = input_tool.allowed_tools.never
tool_defs = await list_mcp_tools(
endpoint=input_tool.server_url,
headers=input_tool.headers or {},
)
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
id=f"mcp_list_{uuid.uuid4()}",
status="completed",
server_label=input_tool.server_label,
tools=[],
)
for t in tool_defs.data:
if never_allowed and t.name in never_allowed:
continue
if not always_allowed or t.name in always_allowed:
chat_tools.append(make_openai_tool(t.name, t))
if t.name in mcp_tool_to_server:
raise ValueError(f"Duplicate tool name {t.name} found for server {input_tool.server_label}")
mcp_tool_to_server[t.name] = input_tool
mcp_list_message.tools.append(
MCPListToolsTool(
name=t.name,
description=t.description,
input_schema={
"type": "object",
"properties": {
p.name: {
"type": p.parameter_type,
"description": p.description,
}
for p in t.parameters
},
"required": [p.name for p in t.parameters if p.required],
},
)
)
else:
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
return chat_tools
return chat_tools, mcp_tool_to_server, mcp_list_message
async def _execute_tool_and_return_final_output(
self,
model_id: str,
stream: bool,
choice: OpenAIChoice,
messages: list[OpenAIMessageParam],
temperature: float,
ctx: ChatCompletionContext,
) -> list[OpenAIResponseOutput]:
output_messages: list[OpenAIResponseOutput] = []
# If the choice is not an assistant message, we don't need to execute any tools
if not isinstance(choice.message, OpenAIAssistantMessageParam):
return output_messages
# If the assistant message doesn't have any tool calls, we don't need to execute any tools
if not choice.message.tool_calls:
return output_messages
# Copy the messages list to avoid mutating the original list
messages = messages.copy()
next_turn_messages = ctx.messages.copy()
# Add the assistant message with tool_calls response to the messages list
messages.append(choice.message)
next_turn_messages.append(choice.message)
for tool_call in choice.message.tool_calls:
tool_call_id = tool_call.id
function = tool_call.function
# If for some reason the tool call doesn't have a function or id, we can't execute it
if not function or not tool_call_id:
continue
# TODO: telemetry spans for tool calls
result = await self._execute_tool_call(function)
# Handle tool call failure
if not result:
output_messages.append(
OpenAIResponseOutputMessageWebSearchToolCall(
id=tool_call_id,
status="failed",
)
)
continue
output_messages.append(
OpenAIResponseOutputMessageWebSearchToolCall(
id=tool_call_id,
status="completed",
),
)
result_content = ""
# TODO: handle other result content types and lists
if isinstance(result.content, str):
result_content = result.content
messages.append(OpenAIToolMessageParam(content=result_content, tool_call_id=tool_call_id))
tool_call_log, further_input = await self._execute_tool_call(tool_call, ctx)
if tool_call_log:
output_messages.append(tool_call_log)
if further_input:
next_turn_messages.append(further_input)
tool_results_chat_response = await self.inference_api.openai_chat_completion(
model=model_id,
messages=messages,
stream=stream,
temperature=temperature,
model=ctx.model,
messages=next_turn_messages,
stream=ctx.stream,
temperature=ctx.temperature,
)
# type cast to appease mypy
# type cast to appease mypy: this is needed because we don't handle streaming properly :)
tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
# Huge TODO: these are NOT the final outputs, we must keep the loop going
tool_final_outputs = [
await _convert_chat_choice_to_response_message(choice) for choice in tool_results_chat_response.choices
]
@ -506,15 +570,86 @@ class OpenAIResponsesImpl:
async def _execute_tool_call(
self,
function: OpenAIChatCompletionToolCallFunction,
) -> ToolInvocationResult | None:
if not function.name:
return None
function_args = json.loads(function.arguments) if function.arguments else {}
logger.info(f"executing tool call: {function.name} with args: {function_args}")
tool_call: OpenAIChatCompletionToolCall,
ctx: ChatCompletionContext,
) -> tuple[OpenAIResponseOutput | None, OpenAIMessageParam | None]:
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
tool_call_id = tool_call.id
function = tool_call.function
if not function or not tool_call_id or not function.name:
return None, None
error_exc = None
result = None
try:
if function.name in ctx.mcp_tool_to_server:
mcp_tool = ctx.mcp_tool_to_server[function.name]
result = await invoke_mcp_tool(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
tool_name=function.name,
kwargs=json.loads(function.arguments) if function.arguments else {},
)
else:
result = await self.tool_runtime_api.invoke_tool(
tool_name=function.name,
kwargs=function_args,
kwargs=json.loads(function.arguments) if function.arguments else {},
)
logger.debug(f"tool call {function.name} completed with result: {result}")
return result
except Exception as e:
error_exc = e
if function.name in ctx.mcp_tool_to_server:
from llama_stack.apis.agents.openai_responses import OpenAIResponseOutputMessageMCPCall
message = OpenAIResponseOutputMessageMCPCall(
id=tool_call_id,
arguments=function.arguments,
name=function.name,
server_label=ctx.mcp_tool_to_server[function.name].server_label,
)
if error_exc:
message.error = str(error_exc)
elif (result.error_code and result.error_code > 0) or result.error_message:
message.error = f"Error (code {result.error_code}): {result.error_message}"
elif result.content:
message.output = interleaved_content_as_str(result.content)
else:
if function.name == "web_search":
message = OpenAIResponseOutputMessageWebSearchToolCall(
id=tool_call_id,
status="completed",
)
if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
message.status = "failed"
else:
raise ValueError(f"Unknown tool {function.name} called")
input_message = None
if result and result.content:
if isinstance(result.content, str):
content = result.content
elif isinstance(result.content, list):
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
content = []
for item in result.content:
if isinstance(item, TextContentItem):
part = OpenAIChatCompletionContentPartTextParam(text=item.text)
elif isinstance(item, ImageContentItem):
if item.image.data:
url = f"data:image;base64,{item.image.data}"
else:
url = item.image.url
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
else:
raise ValueError(f"Unknown result content type: {type(item)}")
content.append(part)
else:
raise ValueError(f"Unknown result content type: {type(result.content)}")
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
return message, input_message


@ -4,53 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import urlparse
import exceptiongroup
import httpx
from mcp import ClientSession
from mcp import types as mcp_types
from mcp.client.sse import sse_client
from llama_stack.apis.common.content_types import URL, ImageContentItem, TextContentItem
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.datatypes import Api
from llama_stack.apis.tools import (
ListToolDefsResponse,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.datatypes import AuthenticationRequiredError
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools
from .config import MCPProviderConfig
logger = get_logger(__name__, category="tools")
@asynccontextmanager
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
try:
async with sse_client(endpoint, headers=headers) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
yield session
except BaseException as e:
if isinstance(e, exceptiongroup.BaseExceptionGroup):
for exc in e.exceptions:
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401:
raise AuthenticationRequiredError(exc) from exc
elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401:
raise AuthenticationRequiredError(e) from e
raise
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
self.config = config
@ -64,32 +37,8 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, Nee
# this endpoint should be retrieved by getting the tool group right?
if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required")
headers = await self.get_headers_from_request(mcp_endpoint.uri)
tools = []
async with sse_client_wrapper(mcp_endpoint.uri, headers) as session:
tools_result = await session.list_tools()
for tool in tools_result.tools:
parameters = []
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
parameters.append(
ToolParameter(
name=param_name,
parameter_type=param_schema.get("type", "string"),
description=param_schema.get("description", ""),
)
)
tools.append(
ToolDef(
name=tool.name,
description=tool.description,
parameters=parameters,
metadata={
"endpoint": mcp_endpoint.uri,
},
)
)
return ListToolDefsResponse(data=tools)
return await list_mcp_tools(mcp_endpoint.uri, headers)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name)
@ -100,23 +49,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, Nee
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
headers = await self.get_headers_from_request(endpoint)
async with sse_client_wrapper(endpoint, headers) as session:
result = await session.call_tool(tool.identifier, kwargs)
content = []
for item in result.content:
if isinstance(item, mcp_types.TextContent):
content.append(TextContentItem(text=item.text))
elif isinstance(item, mcp_types.ImageContent):
content.append(ImageContentItem(image=item.data))
elif isinstance(item, mcp_types.EmbeddedResource):
logger.warning(f"EmbeddedResource is not supported: {item}")
else:
raise ValueError(f"Unknown content type: {type(item)}")
return ToolInvocationResult(
content=content,
error_code=1 if result.isError else 0,
)
return await invoke_mcp_tool(endpoint, headers, tool_name, kwargs)
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
def canonicalize_uri(uri: str) -> str:
@ -129,9 +62,5 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, Nee
for uri, values in provider_data.mcp_headers.items():
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
continue
for entry in values:
parts = entry.split(":")
if len(parts) == 2:
k, v = parts
headers[k.strip()] = v.strip()
headers.update(convert_header_list_to_dict(values))
return headers


@ -0,0 +1,110 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from contextlib import asynccontextmanager
from typing import Any
try:
# for python < 3.11
import exceptiongroup
BaseExceptionGroup = exceptiongroup.BaseExceptionGroup
except ImportError:
pass
import httpx
from mcp import ClientSession
from mcp import types as mcp_types
from mcp.client.sse import sse_client
from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem
from llama_stack.apis.tools import (
ListToolDefsResponse,
ToolDef,
ToolInvocationResult,
ToolParameter,
)
from llama_stack.distribution.datatypes import AuthenticationRequiredError
from llama_stack.log import get_logger
logger = get_logger(__name__, category="tools")
@asynccontextmanager
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
try:
async with sse_client(endpoint, headers=headers) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
yield session
except BaseException as e:
if isinstance(e, BaseExceptionGroup):
for exc in e.exceptions:
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401:
raise AuthenticationRequiredError(exc) from exc
elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401:
raise AuthenticationRequiredError(e) from e
raise
def convert_header_list_to_dict(header_list: list[str]) -> dict[str, str]:
headers = {}
for header in header_list:
parts = header.split(":")
if len(parts) == 2:
k, v = parts
headers[k.strip()] = v.strip()
return headers
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
tools = []
async with sse_client_wrapper(endpoint, headers) as session:
tools_result = await session.list_tools()
for tool in tools_result.tools:
parameters = []
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
parameters.append(
ToolParameter(
name=param_name,
parameter_type=param_schema.get("type", "string"),
description=param_schema.get("description", ""),
)
)
tools.append(
ToolDef(
name=tool.name,
description=tool.description,
parameters=parameters,
metadata={
"endpoint": endpoint,
},
)
)
return ListToolDefsResponse(data=tools)
async def invoke_mcp_tool(
endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any]
) -> ToolInvocationResult:
async with sse_client_wrapper(endpoint, headers) as session:
result = await session.call_tool(tool_name, kwargs)
content: list[InterleavedContentItem] = []
for item in result.content:
if isinstance(item, mcp_types.TextContent):
content.append(TextContentItem(text=item.text))
elif isinstance(item, mcp_types.ImageContent):
content.append(ImageContentItem(image=item.data))
elif isinstance(item, mcp_types.EmbeddedResource):
logger.warning(f"EmbeddedResource is not supported: {item}")
else:
raise ValueError(f"Unknown content type: {type(item)}")
return ToolInvocationResult(
content=content,
error_code=1 if result.isError else 0,
)
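A usage sketch for the two helpers this new module exposes; the endpoint, auth token, and tool name are assumptions, matching the test MCP server added in tests/common/mcp.py below:

```python
# Sketch only; endpoint, token, and tool name are assumed values.
import asyncio

from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools


async def main() -> None:
    endpoint = "http://localhost:8000/sse"
    headers = {"Authorization": "Bearer test-token"}  # only needed if the server enforces auth

    tool_defs = await list_mcp_tools(endpoint, headers)
    for tool in tool_defs.data:
        print(tool.name, [p.name for p in tool.parameters])

    result = await invoke_mcp_tool(
        endpoint, headers, tool_name="get_boiling_point", kwargs={"liquid_name": "polyjuice"}
    )
    print(result.error_code, result.content)


asyncio.run(main())
```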


@ -67,6 +67,7 @@ unit = [
"aiosqlite",
"aiohttp",
"pypdf",
"mcp",
"chardet",
"qdrant-client",
"opentelemetry-exporter-otlp-proto-http",

tests/common/mcp.py (new file, 145 lines)

@ -0,0 +1,145 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# The MCP server can be run with or without authentication, depending on the test.
from contextlib import contextmanager
# Unfortunately the toolgroup id must be tied to the tool names because the registry
# indexes on both toolgroups and tools independently (and not jointly). That really
# needs to be fixed.
MCP_TOOLGROUP_ID = "mcp::localmcp"
@contextmanager
def make_mcp_server(required_auth_token: str | None = None):
import threading
import time
import httpx
import uvicorn
from mcp import types
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.responses import Response
from starlette.routing import Mount, Route
server = FastMCP("FastMCP Test Server", log_level="WARNING")
@server.tool()
async def greet_everyone(
url: str, ctx: Context
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
return [types.TextContent(type="text", text="Hello, world!")]
@server.tool()
async def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
"""
Returns the boiling point of a liquid in Celcius or Fahrenheit.
:param liquid_name: The name of the liquid
:param celcius: Whether to return the boiling point in Celcius
:return: The boiling point of the liquid in Celcius or Fahrenheit
"""
if liquid_name.lower() == "polyjuice":
if celcius:
return -100
else:
return -212
else:
return -1
sse = SseServerTransport("/messages/")
async def handle_sse(request):
from starlette.exceptions import HTTPException
auth_header = request.headers.get("Authorization")
auth_token = None
if auth_header and auth_header.startswith("Bearer "):
auth_token = auth_header.split(" ")[1]
if required_auth_token and auth_token != required_auth_token:
raise HTTPException(status_code=401, detail="Unauthorized")
async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
await server._mcp_server.run(
streams[0],
streams[1],
server._mcp_server.create_initialization_options(),
)
return Response()
app = Starlette(
routes=[
Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse.handle_post_message),
],
)
def get_open_port():
import socket
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("", 0))
return sock.getsockname()[1]
port = get_open_port()
# make uvicorn logs be less verbose
config = uvicorn.Config(app, host="0.0.0.0", port=port, log_level="warning")
server_instance = uvicorn.Server(config)
app.state.uvicorn_server = server_instance
def run_server():
server_instance.run()
# Start the server in a new thread
server_thread = threading.Thread(target=run_server, daemon=True)
server_thread.start()
# Polling until the server is ready
timeout = 10
start_time = time.time()
server_url = f"http://localhost:{port}/sse"
while time.time() - start_time < timeout:
try:
response = httpx.get(server_url)
if response.status_code in [200, 401]:
break
except httpx.RequestError:
pass
time.sleep(0.1)
try:
yield {"server_url": server_url}
finally:
print("Telling SSE server to exit")
server_instance.should_exit = True
time.sleep(0.5)
# Force shutdown if still running
if server_thread.is_alive():
try:
if hasattr(server_instance, "servers") and server_instance.servers:
for srv in server_instance.servers:
srv.close()
# Wait for graceful shutdown
server_thread.join(timeout=3)
if server_thread.is_alive():
print("Warning: Server thread still alive after shutdown attempt")
except Exception as e:
print(f"Error during server shutdown: {e}")
# CRITICAL: Reset SSE global state to prevent event loop contamination
# Reset the SSE AppStatus singleton that stores anyio.Event objects
from sse_starlette.sse import AppStatus
AppStatus.should_exit = False
AppStatus.should_exit_event = None
print("SSE server exited")


@ -5,117 +5,43 @@
# the root directory of this source tree.
import json
import socket
import threading
import time
import httpx
import mcp.types as types
import pytest
import uvicorn
from llama_stack_client import Agent
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.exceptions import HTTPException
from starlette.responses import Response
from starlette.routing import Mount, Route
from llama_stack import LlamaStackAsLibraryClient
from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import AuthenticationRequiredError
AUTH_TOKEN = "test-token"
from tests.common.mcp import MCP_TOOLGROUP_ID, make_mcp_server
@pytest.fixture(scope="module")
@pytest.fixture(scope="function")
def mcp_server():
server = FastMCP("FastMCP Test Server")
@server.tool()
async def greet_everyone(
url: str, ctx: Context
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
return [types.TextContent(type="text", text="Hello, world!")]
sse = SseServerTransport("/messages/")
async def handle_sse(request):
auth_header = request.headers.get("Authorization")
auth_token = None
if auth_header and auth_header.startswith("Bearer "):
auth_token = auth_header.split(" ")[1]
if auth_token != AUTH_TOKEN:
raise HTTPException(status_code=401, detail="Unauthorized")
async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
await server._mcp_server.run(
streams[0],
streams[1],
server._mcp_server.create_initialization_options(),
)
return Response()
app = Starlette(
routes=[
Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse.handle_post_message),
],
)
def get_open_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("", 0))
return sock.getsockname()[1]
port = get_open_port()
config = uvicorn.Config(app, host="0.0.0.0", port=port)
server_instance = uvicorn.Server(config)
app.state.uvicorn_server = server_instance
def run_server():
server_instance.run()
# Start the server in a new thread
server_thread = threading.Thread(target=run_server, daemon=True)
server_thread.start()
# Polling until the server is ready
timeout = 10
start_time = time.time()
while time.time() - start_time < timeout:
try:
response = httpx.get(f"http://localhost:{port}/sse")
if response.status_code == 401:
break
except httpx.RequestError:
pass
time.sleep(0.1)
yield port
# Tell server to exit
server_instance.should_exit = True
server_thread.join(timeout=5)
with make_mcp_server(required_auth_token=AUTH_TOKEN) as mcp_server_info:
yield mcp_server_info
def test_mcp_invocation(llama_stack_client, mcp_server):
port = mcp_server
test_toolgroup_id = "remote::mcptest"
if not isinstance(llama_stack_client, LlamaStackAsLibraryClient):
pytest.skip("The local MCP server only reliably reachable from library client.")
test_toolgroup_id = MCP_TOOLGROUP_ID
uri = mcp_server["server_url"]
# registering itself should fail since it requires listing tools
with pytest.raises(Exception, match="Unauthorized"):
llama_stack_client.toolgroups.register(
toolgroup_id=test_toolgroup_id,
provider_id="model-context-protocol",
mcp_endpoint=dict(uri=f"http://localhost:{port}/sse"),
mcp_endpoint=dict(uri=uri),
)
provider_data = {
"mcp_headers": {
f"http://localhost:{port}/sse": [
uri: [
f"Authorization: Bearer {AUTH_TOKEN}",
],
},
@ -133,24 +59,18 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
llama_stack_client.toolgroups.register(
toolgroup_id=test_toolgroup_id,
provider_id="model-context-protocol",
mcp_endpoint=dict(uri=f"http://localhost:{port}/sse"),
mcp_endpoint=dict(uri=uri),
extra_headers=auth_headers,
)
response = llama_stack_client.tools.list(
toolgroup_id=test_toolgroup_id,
extra_headers=auth_headers,
)
assert len(response) == 1
assert response[0].identifier == "greet_everyone"
assert response[0].type == "tool"
assert len(response[0].parameters) == 1
p = response[0].parameters[0]
assert p.name == "url"
assert p.parameter_type == "string"
assert p.required
assert len(response) == 2
assert {t.identifier for t in response} == {"greet_everyone", "get_boiling_point"}
response = llama_stack_client.tool_runtime.invoke_tool(
tool_name=response[0].identifier,
tool_name="greet_everyone",
kwargs=dict(url="https://www.google.com"),
extra_headers=auth_headers,
)
@ -159,7 +79,9 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
assert content[0].type == "text"
assert content[0].text == "Hello, world!"
models = llama_stack_client.models.list()
models = [
m for m in llama_stack_client.models.list() if m.model_type == ModelType.llm and "guard" not in m.identifier
]
model_id = models[0].identifier
print(f"Using model: {model_id}")
agent = Agent(
@ -174,7 +96,7 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
messages=[
{
"role": "user",
"content": "Yo. Use tools.",
"content": "Say hi to the world. Use tools to do so.",
}
],
stream=False,
@ -196,7 +118,6 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
third = steps[2]
assert third.step_type == "inference"
assert len(third.api_model_response.tool_calls) == 0
# when streaming, we currently don't check auth headers upfront and fail the request
# early. but we should at least be generating a 401 later in the process.
@ -205,7 +126,7 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
messages=[
{
"role": "user",
"content": "Yo. Use tools.",
"content": "What is the boiling point of polyjuice? Use tools to answer.",
}
],
stream=True,


@ -4,88 +4,22 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import socket
import threading
import time
import httpx
import mcp.types as types
import pytest
import uvicorn
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.routing import Mount, Route
from llama_stack import LlamaStackAsLibraryClient
from tests.common.mcp import MCP_TOOLGROUP_ID, make_mcp_server
@pytest.fixture(scope="module")
def mcp_server():
server = FastMCP("FastMCP Test Server")
def test_register_and_unregister_toolgroup(llama_stack_client):
# TODO: make this work for http client also but you need to ensure
# the MCP server is reachable from llama stack server
if not isinstance(llama_stack_client, LlamaStackAsLibraryClient):
pytest.skip("The local MCP server only reliably reachable from library client.")
@server.tool()
async def fetch(url: str, ctx: Context) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
headers = {"User-Agent": "MCP Test Server (github.com/modelcontextprotocol/python-sdk)"}
async with httpx.AsyncClient(follow_redirects=True, headers=headers) as client:
response = await client.get(url)
response.raise_for_status()
return [types.TextContent(type="text", text=response.text)]
sse = SseServerTransport("/messages/")
async def handle_sse(request):
async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
await server._mcp_server.run(
streams[0],
streams[1],
server._mcp_server.create_initialization_options(),
)
app = Starlette(
debug=True,
routes=[
Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse.handle_post_message),
],
)
def get_open_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("", 0))
return sock.getsockname()[1]
port = get_open_port()
def run_server():
uvicorn.run(app, host="0.0.0.0", port=port)
# Start the server in a new thread
server_thread = threading.Thread(target=run_server, daemon=True)
server_thread.start()
# Polling until the server is ready
timeout = 10
start_time = time.time()
while time.time() - start_time < timeout:
try:
response = httpx.get(f"http://localhost:{port}/sse")
if response.status_code == 200:
break
except (httpx.RequestError, httpx.HTTPStatusError):
pass
time.sleep(0.1)
yield port
def test_register_and_unregister_toolgroup(llama_stack_client, mcp_server):
"""
Integration test for registering and unregistering a toolgroup using the ToolGroups API.
"""
port = mcp_server
test_toolgroup_id = "remote::web-fetch"
test_toolgroup_id = MCP_TOOLGROUP_ID
provider_id = "model-context-protocol"
with make_mcp_server() as mcp_server_info:
# Cleanup before running the test
toolgroups = llama_stack_client.toolgroups.list()
for toolgroup in toolgroups:
@ -96,7 +30,7 @@ def test_register_and_unregister_toolgroup(llama_stack_client, mcp_server):
llama_stack_client.toolgroups.register(
toolgroup_id=test_toolgroup_id,
provider_id=provider_id,
mcp_endpoint=dict(uri=f"http://localhost:{port}/sse"),
mcp_endpoint=dict(uri=mcp_server_info["server_url"]),
)
# Verify registration


@ -31,6 +31,18 @@ test_response_web_search:
search_context_size: "low"
output: "128"
test_response_mcp_tool:
test_name: test_response_mcp_tool
test_params:
case:
- case_id: "boiling_point_tool"
input: "What is the boiling point of polyjuice?"
tools:
- type: mcp
server_label: "localmcp"
server_url: "<FILLED_BY_TEST_RUNNER>"
output: "Hello, world!"
test_response_custom_tool:
test_name: test_response_custom_tool
test_params:


@ -4,9 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import httpx
import pytest
from llama_stack import LlamaStackAsLibraryClient
from llama_stack.distribution.datatypes import AuthenticationRequiredError
from tests.common.mcp import make_mcp_server
from tests.verifications.openai_api.fixtures.fixtures import (
case_id_generator,
get_base_test_name,
@ -124,6 +129,79 @@ def test_response_non_streaming_web_search(request, openai_client, model, provid
assert case["output"].lower() in response.output_text.lower().strip()
@pytest.mark.parametrize(
"case",
responses_test_cases["test_response_mcp_tool"]["test_params"]["case"],
ids=case_id_generator,
)
def test_response_non_streaming_mcp_tool(request, openai_client, model, provider, verification_config, case):
test_name_base = get_base_test_name(request)
if should_skip_test(verification_config, provider, model, test_name_base):
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
with make_mcp_server() as mcp_server_info:
tools = case["tools"]
for tool in tools:
if tool["type"] == "mcp":
tool["server_url"] = mcp_server_info["server_url"]
response = openai_client.responses.create(
model=model,
input=case["input"],
tools=tools,
stream=False,
)
assert len(response.output) >= 3
list_tools = response.output[0]
assert list_tools.type == "mcp_list_tools"
assert list_tools.server_label == "localmcp"
assert len(list_tools.tools) == 2
assert {t["name"] for t in list_tools.tools} == {"get_boiling_point", "greet_everyone"}
call = response.output[1]
assert call.type == "mcp_call"
assert call.name == "get_boiling_point"
assert json.loads(call.arguments) == {"liquid_name": "polyjuice", "celcius": True}
assert call.error is None
assert "-100" in call.output
message = response.output[2]
text_content = message.content[0].text
assert "boiling point" in text_content.lower()
with make_mcp_server(required_auth_token="test-token") as mcp_server_info:
tools = case["tools"]
for tool in tools:
if tool["type"] == "mcp":
tool["server_url"] = mcp_server_info["server_url"]
exc_type = (
AuthenticationRequiredError
if isinstance(openai_client, LlamaStackAsLibraryClient)
else httpx.HTTPStatusError
)
with pytest.raises(exc_type):
openai_client.responses.create(
model=model,
input=case["input"],
tools=tools,
stream=False,
)
for tool in tools:
if tool["type"] == "mcp":
tool["server_url"] = mcp_server_info["server_url"]
tool["headers"] = {"Authorization": "Bearer test-token"}
response = openai_client.responses.create(
model=model,
input=case["input"],
tools=tools,
stream=False,
)
assert len(response.output) >= 3
@pytest.mark.parametrize(
"case",
responses_test_cases["test_response_custom_tool"]["test_params"]["case"],

uv.lock (generated, 2 changes)

@ -1544,6 +1544,7 @@ unit = [
{ name = "aiohttp" },
{ name = "aiosqlite" },
{ name = "chardet" },
{ name = "mcp" },
{ name = "openai" },
{ name = "opentelemetry-exporter-otlp-proto-http" },
{ name = "pypdf" },
@ -1576,6 +1577,7 @@ requires-dist = [
{ name = "llama-stack-client", specifier = ">=0.2.7" },
{ name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.7" },
{ name = "mcp", marker = "extra == 'test'" },
{ name = "mcp", marker = "extra == 'unit'" },
{ name = "myst-parser", marker = "extra == 'docs'" },
{ name = "nbval", marker = "extra == 'dev'" },
{ name = "openai", specifier = ">=1.66" },