feat: enable MCP execution in Responses impl (#2240)

## Test Plan

```
pytest -s -v 'tests/verifications/openai_api/test_responses.py' \
  --provider=stack:together --model meta-llama/Llama-4-Scout-17B-16E-Instruct
```
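For orientation, here is a hedged sketch of what this change enables from a client's point of view: passing an `mcp` tool to the Responses API so the stack lists and invokes tools from an MCP server. The base URL, server label, server URL, and tool filter below are illustrative placeholders, and the snippet assumes an OpenAI-compatible client pointed at a running Llama Stack distribution.

```
from openai import OpenAI

# Illustrative: point the OpenAI client at a running Llama Stack's
# OpenAI-compatible endpoint (URL and api_key are placeholders).
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.responses.create(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    input="Which tools does the weather server expose, and what's the forecast for Paris?",
    tools=[
        {
            "type": "mcp",
            "server_label": "weather",                   # illustrative label
            "server_url": "http://localhost:8000/sse",   # illustrative MCP SSE endpoint
            # "headers": {"Authorization": "Bearer <token>"},  # optional
            # "allowed_tools": ["get_forecast"],               # optional filter
        }
    ],
)

# Output items now include mcp_list_tools and mcp_call entries alongside messages.
for item in response.output:
    print(item.type)
```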
Ashwin Bharambe authored on 2025-05-24 14:20:42 -07:00 (committed by GitHub)
commit 3faf1e4a79, parent 66f09f24ed
15 changed files with 865 additions and 382 deletions


@@ -10,6 +10,9 @@ from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type, register_schema
# NOTE(ashwin): this file is literally a copy of the OpenAI responses API schema. We should probably
# take their YAML and generate this file automatically. Their YAML is available.
@json_schema_type
class OpenAIResponseError(BaseModel):
@@ -79,16 +82,45 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
@json_schema_type
class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
arguments: str
call_id: str
name: str
arguments: str
type: Literal["function_call"] = "function_call"
id: str | None = None
status: str | None = None
@json_schema_type
class OpenAIResponseOutputMessageMCPCall(BaseModel):
id: str
status: str
type: Literal["mcp_call"] = "mcp_call"
arguments: str
name: str
server_label: str
error: str | None = None
output: str | None = None
class MCPListToolsTool(BaseModel):
input_schema: dict[str, Any]
name: str
description: str | None = None
@json_schema_type
class OpenAIResponseOutputMessageMCPListTools(BaseModel):
id: str
type: Literal["mcp_list_tools"] = "mcp_list_tools"
server_label: str
tools: list[MCPListToolsTool]
OpenAIResponseOutput = Annotated[
OpenAIResponseMessage | OpenAIResponseOutputMessageWebSearchToolCall | OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseMessage
| OpenAIResponseOutputMessageWebSearchToolCall
| OpenAIResponseOutputMessageFunctionToolCall
| OpenAIResponseOutputMessageMCPCall
| OpenAIResponseOutputMessageMCPListTools,
Field(discriminator="type"),
]
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
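For illustration only (not part of the diff), this is roughly how the two new output item types might be constructed; all ids and values are made up, and the fields mirror the schema above.

```
from llama_stack.apis.agents.openai_responses import (
    MCPListToolsTool,
    OpenAIResponseOutputMessageMCPCall,
    OpenAIResponseOutputMessageMCPListTools,
)

# The listing emitted when an MCP server's tools are enumerated.
list_tools_item = OpenAIResponseOutputMessageMCPListTools(
    id="mcp_list_123",  # illustrative id
    server_label="weather",
    tools=[
        MCPListToolsTool(
            name="get_forecast",
            description="Return a forecast for a city",
            input_schema={
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        )
    ],
)

# The record emitted when one of those tools is actually invoked.
call_item = OpenAIResponseOutputMessageMCPCall(
    id="mcp_call_456",  # illustrative id
    status="completed",
    name="get_forecast",
    arguments='{"city": "Paris"}',
    server_label="weather",
    output="Sunny, 22C",
)
```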


@@ -14,6 +14,7 @@ from pydantic import BaseModel
from llama_stack.apis.agents import Order
from llama_stack.apis.agents.openai_responses import (
AllowedToolsFilter,
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIResponseInput,
@@ -22,7 +23,7 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolMCP,
OpenAIResponseMessage,
OpenAIResponseObject,
OpenAIResponseObjectStream,
@@ -51,11 +52,12 @@ from llama_stack.apis.inference.inference import (
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.apis.tools.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.tools.tools import ToolGroups, ToolRuntime
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools
logger = get_logger(name=__name__, category="openai_responses")
@@ -168,6 +170,15 @@ class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
response: OpenAIResponseObject
class ChatCompletionContext(BaseModel):
model: str
messages: list[OpenAIMessageParam]
tools: list[ChatCompletionToolParam] | None = None
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
stream: bool
temperature: float | None
class OpenAIResponsesImpl:
def __init__(
self,
@@ -255,13 +266,32 @@ class OpenAIResponsesImpl:
temperature: float | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
):
output_messages: list[OpenAIResponseOutput] = []
stream = False if stream is None else stream
# Huge TODO: we need to run this in a loop, until morale improves
# Create context to run "chat completion"
input = await self._prepend_previous_response(input, previous_response_id)
messages = await _convert_response_input_to_chat_messages(input)
await self._prepend_instructions(messages, instructions)
chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None
chat_tools, mcp_tool_to_server, mcp_list_message = (
await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
)
if mcp_list_message:
output_messages.append(mcp_list_message)
ctx = ChatCompletionContext(
model=model,
messages=messages,
tools=chat_tools,
mcp_tool_to_server=mcp_tool_to_server,
stream=stream,
temperature=temperature,
)
# Run inference
chat_response = await self.inference_api.openai_chat_completion(
model=model,
messages=messages,
@@ -270,6 +300,7 @@ class OpenAIResponsesImpl:
temperature=temperature,
)
# Collect output
if stream:
# TODO: refactor this into a separate method that handles streaming
chat_response_id = ""
@@ -328,11 +359,11 @@ class OpenAIResponsesImpl:
# dump and reload to map to our pydantic types
chat_response = OpenAIChatCompletion(**chat_response.model_dump())
output_messages: list[OpenAIResponseOutput] = []
# Execute tool calls if any
for choice in chat_response.choices:
if choice.message.tool_calls and tools:
# Assume if the first tool is a function, all tools are functions
if isinstance(tools[0], OpenAIResponseInputToolFunction):
if tools[0].type == "function":
for tool_call in choice.message.tool_calls:
output_messages.append(
OpenAIResponseOutputMessageFunctionToolCall(
@@ -344,11 +375,12 @@ class OpenAIResponsesImpl:
)
)
else:
output_messages.extend(
await self._execute_tool_and_return_final_output(model, stream, choice, messages, temperature)
)
tool_messages = await self._execute_tool_and_return_final_output(choice, ctx)
output_messages.extend(tool_messages)
else:
output_messages.append(await _convert_chat_choice_to_response_message(choice))
# Create response object
response = OpenAIResponseObject(
created_at=chat_response.created,
id=f"resp-{uuid.uuid4()}",
@@ -359,9 +391,8 @@ class OpenAIResponsesImpl:
)
logger.debug(f"OpenAI Responses response: {response}")
# Store response if requested
if store:
# Store in kvstore
new_input_id = f"msg_{uuid.uuid4()}"
if isinstance(input, str):
# synthesize a message from the input string
@@ -403,7 +434,36 @@ class OpenAIResponsesImpl:
async def _convert_response_tools_to_chat_tools(
self, tools: list[OpenAIResponseInputTool]
) -> list[ChatCompletionToolParam]:
) -> tuple[
list[ChatCompletionToolParam],
dict[str, OpenAIResponseInputToolMCP],
OpenAIResponseOutput | None,
]:
from llama_stack.apis.agents.openai_responses import (
MCPListToolsTool,
OpenAIResponseOutputMessageMCPListTools,
)
from llama_stack.apis.tools.tools import Tool
mcp_tool_to_server = {}
def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
tool_def = ToolDefinition(
tool_name=tool_name,
description=tool.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool.parameters
},
)
return convert_tooldef_to_openai_tool(tool_def)
mcp_list_message = None
chat_tools: list[ChatCompletionToolParam] = []
for input_tool in tools:
# TODO: Handle other tool types
@@ -412,91 +472,95 @@ class OpenAIResponsesImpl:
elif input_tool.type == "web_search":
tool_name = "web_search"
tool = await self.tool_groups_api.get_tool(tool_name)
tool_def = ToolDefinition(
tool_name=tool_name,
description=tool.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool.parameters
},
if not tool:
raise ValueError(f"Tool {tool_name} not found")
chat_tools.append(make_openai_tool(tool_name, tool))
elif input_tool.type == "mcp":
always_allowed = None
never_allowed = None
if input_tool.allowed_tools:
if isinstance(input_tool.allowed_tools, list):
always_allowed = input_tool.allowed_tools
elif isinstance(input_tool.allowed_tools, AllowedToolsFilter):
always_allowed = input_tool.allowed_tools.always
never_allowed = input_tool.allowed_tools.never
tool_defs = await list_mcp_tools(
endpoint=input_tool.server_url,
headers=input_tool.headers or {},
)
chat_tool = convert_tooldef_to_openai_tool(tool_def)
chat_tools.append(chat_tool)
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
id=f"mcp_list_{uuid.uuid4()}",
status="completed",
server_label=input_tool.server_label,
tools=[],
)
for t in tool_defs.data:
if never_allowed and t.name in never_allowed:
continue
if not always_allowed or t.name in always_allowed:
chat_tools.append(make_openai_tool(t.name, t))
if t.name in mcp_tool_to_server:
raise ValueError(f"Duplicate tool name {t.name} found for server {input_tool.server_label}")
mcp_tool_to_server[t.name] = input_tool
mcp_list_message.tools.append(
MCPListToolsTool(
name=t.name,
description=t.description,
input_schema={
"type": "object",
"properties": {
p.name: {
"type": p.parameter_type,
"description": p.description,
}
for p in t.parameters
},
"required": [p.name for p in t.parameters if p.required],
},
)
)
else:
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
return chat_tools
return chat_tools, mcp_tool_to_server, mcp_list_message
async def _execute_tool_and_return_final_output(
self,
model_id: str,
stream: bool,
choice: OpenAIChoice,
messages: list[OpenAIMessageParam],
temperature: float,
ctx: ChatCompletionContext,
) -> list[OpenAIResponseOutput]:
output_messages: list[OpenAIResponseOutput] = []
# If the choice is not an assistant message, we don't need to execute any tools
if not isinstance(choice.message, OpenAIAssistantMessageParam):
return output_messages
# If the assistant message doesn't have any tool calls, we don't need to execute any tools
if not choice.message.tool_calls:
return output_messages
# Copy the messages list to avoid mutating the original list
messages = messages.copy()
next_turn_messages = ctx.messages.copy()
# Add the assistant message with tool_calls response to the messages list
messages.append(choice.message)
next_turn_messages.append(choice.message)
for tool_call in choice.message.tool_calls:
tool_call_id = tool_call.id
function = tool_call.function
# If for some reason the tool call doesn't have a function or id, we can't execute it
if not function or not tool_call_id:
continue
# TODO: telemetry spans for tool calls
result = await self._execute_tool_call(function)
# Handle tool call failure
if not result:
output_messages.append(
OpenAIResponseOutputMessageWebSearchToolCall(
id=tool_call_id,
status="failed",
)
)
continue
output_messages.append(
OpenAIResponseOutputMessageWebSearchToolCall(
id=tool_call_id,
status="completed",
),
)
result_content = ""
# TODO: handle other result content types and lists
if isinstance(result.content, str):
result_content = result.content
messages.append(OpenAIToolMessageParam(content=result_content, tool_call_id=tool_call_id))
tool_call_log, further_input = await self._execute_tool_call(tool_call, ctx)
if tool_call_log:
output_messages.append(tool_call_log)
if further_input:
next_turn_messages.append(further_input)
tool_results_chat_response = await self.inference_api.openai_chat_completion(
model=model_id,
messages=messages,
stream=stream,
temperature=temperature,
model=ctx.model,
messages=next_turn_messages,
stream=ctx.stream,
temperature=ctx.temperature,
)
# type cast to appease mypy
# type cast to appease mypy: this is needed because we don't handle streaming properly :)
tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
# Huge TODO: these are NOT the final outputs, we must keep the loop going
tool_final_outputs = [
await _convert_chat_choice_to_response_message(choice) for choice in tool_results_chat_response.choices
]
@@ -506,15 +570,86 @@ class OpenAIResponsesImpl:
async def _execute_tool_call(
self,
function: OpenAIChatCompletionToolCallFunction,
) -> ToolInvocationResult | None:
if not function.name:
return None
function_args = json.loads(function.arguments) if function.arguments else {}
logger.info(f"executing tool call: {function.name} with args: {function_args}")
result = await self.tool_runtime_api.invoke_tool(
tool_name=function.name,
kwargs=function_args,
tool_call: OpenAIChatCompletionToolCall,
ctx: ChatCompletionContext,
) -> tuple[OpenAIResponseOutput | None, OpenAIMessageParam | None]:
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
logger.debug(f"tool call {function.name} completed with result: {result}")
return result
tool_call_id = tool_call.id
function = tool_call.function
if not function or not tool_call_id or not function.name:
return None, None
error_exc = None
result = None
try:
if function.name in ctx.mcp_tool_to_server:
mcp_tool = ctx.mcp_tool_to_server[function.name]
result = await invoke_mcp_tool(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
tool_name=function.name,
kwargs=json.loads(function.arguments) if function.arguments else {},
)
else:
result = await self.tool_runtime_api.invoke_tool(
tool_name=function.name,
kwargs=json.loads(function.arguments) if function.arguments else {},
)
except Exception as e:
error_exc = e
if function.name in ctx.mcp_tool_to_server:
from llama_stack.apis.agents.openai_responses import OpenAIResponseOutputMessageMCPCall
message = OpenAIResponseOutputMessageMCPCall(
id=tool_call_id,
arguments=function.arguments,
name=function.name,
server_label=ctx.mcp_tool_to_server[function.name].server_label,
)
if error_exc:
message.error = str(error_exc)
elif (result.error_code and result.error_code > 0) or result.error_message:
message.error = f"Error (code {result.error_code}): {result.error_message}"
elif result.content:
message.output = interleaved_content_as_str(result.content)
else:
if function.name == "web_search":
message = OpenAIResponseOutputMessageWebSearchToolCall(
id=tool_call_id,
status="completed",
)
if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
message.status = "failed"
else:
raise ValueError(f"Unknown tool {function.name} called")
input_message = None
if result and result.content:
if isinstance(result.content, str):
content = result.content
elif isinstance(result.content, list):
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
content = []
for item in result.content:
if isinstance(item, TextContentItem):
part = OpenAIChatCompletionContentPartTextParam(text=item.text)
elif isinstance(item, ImageContentItem):
if item.image.data:
url = f"data:image;base64,{item.image.data}"
else:
url = item.image.url
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
else:
raise ValueError(f"Unknown result content type: {type(item)}")
content.append(part)
else:
raise ValueError(f"Unknown result content type: {type(result.content)}")
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
return message, input_message


@@ -4,53 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import urlparse
import exceptiongroup
import httpx
from mcp import ClientSession
from mcp import types as mcp_types
from mcp.client.sse import sse_client
from llama_stack.apis.common.content_types import URL, ImageContentItem, TextContentItem
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.datatypes import Api
from llama_stack.apis.tools import (
ListToolDefsResponse,
ToolDef,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.datatypes import AuthenticationRequiredError
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools
from .config import MCPProviderConfig
logger = get_logger(__name__, category="tools")
@asynccontextmanager
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
try:
async with sse_client(endpoint, headers=headers) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
yield session
except BaseException as e:
if isinstance(e, exceptiongroup.BaseExceptionGroup):
for exc in e.exceptions:
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401:
raise AuthenticationRequiredError(exc) from exc
elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401:
raise AuthenticationRequiredError(e) from e
raise
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
self.config = config
@@ -64,32 +37,8 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, Nee
# this endpoint should be retrieved by getting the tool group right?
if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required")
headers = await self.get_headers_from_request(mcp_endpoint.uri)
tools = []
async with sse_client_wrapper(mcp_endpoint.uri, headers) as session:
tools_result = await session.list_tools()
for tool in tools_result.tools:
parameters = []
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
parameters.append(
ToolParameter(
name=param_name,
parameter_type=param_schema.get("type", "string"),
description=param_schema.get("description", ""),
)
)
tools.append(
ToolDef(
name=tool.name,
description=tool.description,
parameters=parameters,
metadata={
"endpoint": mcp_endpoint.uri,
},
)
)
return ListToolDefsResponse(data=tools)
return await list_mcp_tools(mcp_endpoint.uri, headers)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name)
@@ -100,23 +49,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, Nee
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
headers = await self.get_headers_from_request(endpoint)
async with sse_client_wrapper(endpoint, headers) as session:
result = await session.call_tool(tool.identifier, kwargs)
content = []
for item in result.content:
if isinstance(item, mcp_types.TextContent):
content.append(TextContentItem(text=item.text))
elif isinstance(item, mcp_types.ImageContent):
content.append(ImageContentItem(image=item.data))
elif isinstance(item, mcp_types.EmbeddedResource):
logger.warning(f"EmbeddedResource is not supported: {item}")
else:
raise ValueError(f"Unknown content type: {type(item)}")
return ToolInvocationResult(
content=content,
error_code=1 if result.isError else 0,
)
return await invoke_mcp_tool(endpoint, headers, tool_name, kwargs)
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
def canonicalize_uri(uri: str) -> str:
@@ -129,9 +62,5 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, Nee
for uri, values in provider_data.mcp_headers.items():
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
continue
for entry in values:
parts = entry.split(":")
if len(parts) == 2:
k, v = parts
headers[k.strip()] = v.strip()
headers.update(convert_header_list_to_dict(values))
return headers


@@ -0,0 +1,110 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from contextlib import asynccontextmanager
from typing import Any
try:
# for python < 3.11
import exceptiongroup
BaseExceptionGroup = exceptiongroup.BaseExceptionGroup
except ImportError:
pass
import httpx
from mcp import ClientSession
from mcp import types as mcp_types
from mcp.client.sse import sse_client
from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem
from llama_stack.apis.tools import (
ListToolDefsResponse,
ToolDef,
ToolInvocationResult,
ToolParameter,
)
from llama_stack.distribution.datatypes import AuthenticationRequiredError
from llama_stack.log import get_logger
logger = get_logger(__name__, category="tools")
@asynccontextmanager
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
try:
async with sse_client(endpoint, headers=headers) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
yield session
except BaseException as e:
if isinstance(e, BaseExceptionGroup):
for exc in e.exceptions:
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401:
raise AuthenticationRequiredError(exc) from exc
elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401:
raise AuthenticationRequiredError(e) from e
raise
def convert_header_list_to_dict(header_list: list[str]) -> dict[str, str]:
headers = {}
for header in header_list:
parts = header.split(":")
if len(parts) == 2:
k, v = parts
headers[k.strip()] = v.strip()
return headers
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
tools = []
async with sse_client_wrapper(endpoint, headers) as session:
tools_result = await session.list_tools()
for tool in tools_result.tools:
parameters = []
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
parameters.append(
ToolParameter(
name=param_name,
parameter_type=param_schema.get("type", "string"),
description=param_schema.get("description", ""),
)
)
tools.append(
ToolDef(
name=tool.name,
description=tool.description,
parameters=parameters,
metadata={
"endpoint": endpoint,
},
)
)
return ListToolDefsResponse(data=tools)
async def invoke_mcp_tool(
endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any]
) -> ToolInvocationResult:
async with sse_client_wrapper(endpoint, headers) as session:
result = await session.call_tool(tool_name, kwargs)
content: list[InterleavedContentItem] = []
for item in result.content:
if isinstance(item, mcp_types.TextContent):
content.append(TextContentItem(text=item.text))
elif isinstance(item, mcp_types.ImageContent):
content.append(ImageContentItem(image=item.data))
elif isinstance(item, mcp_types.EmbeddedResource):
logger.warning(f"EmbeddedResource is not supported: {item}")
else:
raise ValueError(f"Unknown content type: {type(item)}")
return ToolInvocationResult(
content=content,
error_code=1 if result.isError else 0,
)
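As a usage note (not part of the diff), the two helpers in this new module could be exercised directly along these lines; the endpoint, auth header, tool name, and arguments below are illustrative.

```
import asyncio

from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools


async def main() -> None:
    endpoint = "http://localhost:8000/sse"          # illustrative MCP SSE endpoint
    headers = {"Authorization": "Bearer <token>"}   # illustrative auth header

    # Enumerate the tools the MCP server exposes.
    tool_defs = await list_mcp_tools(endpoint, headers)
    for tool in tool_defs.data:
        print(tool.name, tool.description)

    # Invoke one tool by name; kwargs must match its input schema.
    result = await invoke_mcp_tool(endpoint, headers, tool_name="get_forecast", kwargs={"city": "Paris"})
    print(result.error_code, result.content)


asyncio.run(main())
```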