diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 2414522a7..d78e82c9d 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -24,7 +24,7 @@ jobs: matrix: # Listing tests manually since some of them currently fail # TODO: generate matrix list from tests/integration when fixed - test-type: [agents, inference, datasets, inspect, scoring, post_training, providers] + test-type: [agents, inference, datasets, inspect, scoring, post_training, providers, tool_runtime] client-type: [library, http] fail-fast: false # we want to run all tests regardless of failure @@ -90,7 +90,7 @@ jobs: else stack_config="http://localhost:8321" fi - uv run pytest -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \ + uv run pytest -s -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \ -k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \ --text-model="meta-llama/Llama-3.2-3B-Instruct" \ --embedding-model=all-MiniLM-L6-v2 diff --git a/.gitignore b/.gitignore index 0ef25cdf1..2cc885604 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ dev_requirements.txt build .DS_Store llama_stack/configs/* +.cursor/ xcuserdata/ *.hmap .DS_Store diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 7a6e82001..99ae1c038 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -7218,15 +7218,15 @@ "OpenAIResponseOutputMessageFunctionToolCall": { "type": "object", "properties": { - "arguments": { - "type": "string" - }, "call_id": { "type": "string" }, "name": { "type": "string" }, + "arguments": { + "type": "string" + }, "type": { "type": "string", "const": "function_call", @@ -7241,12 +7241,10 @@ }, "additionalProperties": false, "required": [ - "arguments", "call_id", "name", - "type", - "id", - "status" + "arguments", + "type" ], "title": "OpenAIResponseOutputMessageFunctionToolCall" }, @@ -7412,6 +7410,12 @@ }, { "$ref": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall" + }, + { + "$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPCall" + }, + { + "$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools" } ], "discriminator": { @@ -7419,10 +7423,118 @@ "mapping": { "message": "#/components/schemas/OpenAIResponseMessage", "web_search_call": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall", - "function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall" + "function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall", + "mcp_call": "#/components/schemas/OpenAIResponseOutputMessageMCPCall", + "mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools" } } }, + "OpenAIResponseOutputMessageMCPCall": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "type": { + "type": "string", + "const": "mcp_call", + "default": "mcp_call" + }, + "arguments": { + "type": "string" + }, + "name": { + "type": "string" + }, + "server_label": { + "type": "string" + }, + "error": { + "type": "string" + }, + "output": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "id", + "type", + "arguments", + "name", + "server_label" + ], + "title": "OpenAIResponseOutputMessageMCPCall" + }, + "OpenAIResponseOutputMessageMCPListTools": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "type": { + "type": "string", + 
"const": "mcp_list_tools", + "default": "mcp_list_tools" + }, + "server_label": { + "type": "string" + }, + "tools": { + "type": "array", + "items": { + "type": "object", + "properties": { + "input_schema": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + }, + "name": { + "type": "string" + }, + "description": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "input_schema", + "name" + ], + "title": "MCPListToolsTool" + } + } + }, + "additionalProperties": false, + "required": [ + "id", + "type", + "server_label", + "tools" + ], + "title": "OpenAIResponseOutputMessageMCPListTools" + }, "OpenAIResponseObjectStream": { "oneOf": [ { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 74c4852d4..4e4f09eb0 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -5075,12 +5075,12 @@ components: "OpenAIResponseOutputMessageFunctionToolCall": type: object properties: - arguments: - type: string call_id: type: string name: type: string + arguments: + type: string type: type: string const: function_call @@ -5091,12 +5091,10 @@ components: type: string additionalProperties: false required: - - arguments - call_id - name + - arguments - type - - id - - status title: >- OpenAIResponseOutputMessageFunctionToolCall "OpenAIResponseOutputMessageWebSearchToolCall": @@ -5214,12 +5212,85 @@ components: - $ref: '#/components/schemas/OpenAIResponseMessage' - $ref: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall' - $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall' + - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPCall' + - $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools' discriminator: propertyName: type mapping: message: '#/components/schemas/OpenAIResponseMessage' web_search_call: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall' function_call: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall' + mcp_call: '#/components/schemas/OpenAIResponseOutputMessageMCPCall' + mcp_list_tools: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools' + OpenAIResponseOutputMessageMCPCall: + type: object + properties: + id: + type: string + type: + type: string + const: mcp_call + default: mcp_call + arguments: + type: string + name: + type: string + server_label: + type: string + error: + type: string + output: + type: string + additionalProperties: false + required: + - id + - type + - arguments + - name + - server_label + title: OpenAIResponseOutputMessageMCPCall + OpenAIResponseOutputMessageMCPListTools: + type: object + properties: + id: + type: string + type: + type: string + const: mcp_list_tools + default: mcp_list_tools + server_label: + type: string + tools: + type: array + items: + type: object + properties: + input_schema: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + name: + type: string + description: + type: string + additionalProperties: false + required: + - input_schema + - name + title: MCPListToolsTool + additionalProperties: false + required: + - id + - type + - server_label + - tools + title: OpenAIResponseOutputMessageMCPListTools OpenAIResponseObjectStream: oneOf: - $ref: 
'#/components/schemas/OpenAIResponseObjectStreamResponseCreated' diff --git a/llama_stack/apis/agents/openai_responses.py b/llama_stack/apis/agents/openai_responses.py index 7740b5643..675e8f3ff 100644 --- a/llama_stack/apis/agents/openai_responses.py +++ b/llama_stack/apis/agents/openai_responses.py @@ -10,6 +10,9 @@ from pydantic import BaseModel, Field from llama_stack.schema_utils import json_schema_type, register_schema +# NOTE(ashwin): this file is literally a copy of the OpenAI responses API schema. We should probably +# take their YAML and generate this file automatically. Their YAML is available. + @json_schema_type class OpenAIResponseError(BaseModel): @@ -79,16 +82,45 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel): @json_schema_type class OpenAIResponseOutputMessageFunctionToolCall(BaseModel): - arguments: str call_id: str name: str + arguments: str type: Literal["function_call"] = "function_call" + id: str | None = None + status: str | None = None + + +@json_schema_type +class OpenAIResponseOutputMessageMCPCall(BaseModel): id: str - status: str + type: Literal["mcp_call"] = "mcp_call" + arguments: str + name: str + server_label: str + error: str | None = None + output: str | None = None + + +class MCPListToolsTool(BaseModel): + input_schema: dict[str, Any] + name: str + description: str | None = None + + +@json_schema_type +class OpenAIResponseOutputMessageMCPListTools(BaseModel): + id: str + type: Literal["mcp_list_tools"] = "mcp_list_tools" + server_label: str + tools: list[MCPListToolsTool] OpenAIResponseOutput = Annotated[ - OpenAIResponseMessage | OpenAIResponseOutputMessageWebSearchToolCall | OpenAIResponseOutputMessageFunctionToolCall, + OpenAIResponseMessage + | OpenAIResponseOutputMessageWebSearchToolCall + | OpenAIResponseOutputMessageFunctionToolCall + | OpenAIResponseOutputMessageMCPCall + | OpenAIResponseOutputMessageMCPListTools, Field(discriminator="type"), ] register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput") diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py index 9acda3b8c..5f27ef906 100644 --- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py +++ b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py @@ -14,6 +14,7 @@ from pydantic import BaseModel from llama_stack.apis.agents import Order from llama_stack.apis.agents.openai_responses import ( + AllowedToolsFilter, ListOpenAIResponseInputItem, ListOpenAIResponseObject, OpenAIResponseInput, @@ -22,7 +23,7 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseInputMessageContentImage, OpenAIResponseInputMessageContentText, OpenAIResponseInputTool, - OpenAIResponseInputToolFunction, + OpenAIResponseInputToolMCP, OpenAIResponseMessage, OpenAIResponseObject, OpenAIResponseObjectStream, @@ -51,11 +52,12 @@ from llama_stack.apis.inference.inference import ( OpenAIToolMessageParam, OpenAIUserMessageParam, ) -from llama_stack.apis.tools.tools import ToolGroups, ToolInvocationResult, ToolRuntime +from llama_stack.apis.tools.tools import ToolGroups, ToolRuntime from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool from llama_stack.providers.utils.responses.responses_store import ResponsesStore +from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, 
list_mcp_tools logger = get_logger(name=__name__, category="openai_responses") @@ -168,6 +170,15 @@ class OpenAIResponsePreviousResponseWithInputItems(BaseModel): response: OpenAIResponseObject +class ChatCompletionContext(BaseModel): + model: str + messages: list[OpenAIMessageParam] + tools: list[ChatCompletionToolParam] | None = None + mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] + stream: bool + temperature: float | None + + class OpenAIResponsesImpl: def __init__( self, @@ -255,13 +266,32 @@ class OpenAIResponsesImpl: temperature: float | None = None, tools: list[OpenAIResponseInputTool] | None = None, ): + output_messages: list[OpenAIResponseOutput] = [] + stream = False if stream is None else stream + # Huge TODO: we need to run this in a loop, until morale improves + + # Create context to run "chat completion" input = await self._prepend_previous_response(input, previous_response_id) messages = await _convert_response_input_to_chat_messages(input) await self._prepend_instructions(messages, instructions) - chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None + chat_tools, mcp_tool_to_server, mcp_list_message = ( + await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None) + ) + if mcp_list_message: + output_messages.append(mcp_list_message) + ctx = ChatCompletionContext( + model=model, + messages=messages, + tools=chat_tools, + mcp_tool_to_server=mcp_tool_to_server, + stream=stream, + temperature=temperature, + ) + + # Run inference chat_response = await self.inference_api.openai_chat_completion( model=model, messages=messages, @@ -270,6 +300,7 @@ class OpenAIResponsesImpl: temperature=temperature, ) + # Collect output if stream: # TODO: refactor this into a separate method that handles streaming chat_response_id = "" @@ -328,11 +359,11 @@ class OpenAIResponsesImpl: # dump and reload to map to our pydantic types chat_response = OpenAIChatCompletion(**chat_response.model_dump()) - output_messages: list[OpenAIResponseOutput] = [] + # Execute tool calls if any for choice in chat_response.choices: if choice.message.tool_calls and tools: # Assume if the first tool is a function, all tools are functions - if isinstance(tools[0], OpenAIResponseInputToolFunction): + if tools[0].type == "function": for tool_call in choice.message.tool_calls: output_messages.append( OpenAIResponseOutputMessageFunctionToolCall( @@ -344,11 +375,12 @@ class OpenAIResponsesImpl: ) ) else: - output_messages.extend( - await self._execute_tool_and_return_final_output(model, stream, choice, messages, temperature) - ) + tool_messages = await self._execute_tool_and_return_final_output(choice, ctx) + output_messages.extend(tool_messages) else: output_messages.append(await _convert_chat_choice_to_response_message(choice)) + + # Create response object response = OpenAIResponseObject( created_at=chat_response.created, id=f"resp-{uuid.uuid4()}", @@ -359,9 +391,8 @@ class OpenAIResponsesImpl: ) logger.debug(f"OpenAI Responses response: {response}") + # Store response if requested if store: - # Store in kvstore - new_input_id = f"msg_{uuid.uuid4()}" if isinstance(input, str): # synthesize a message from the input string @@ -403,7 +434,36 @@ class OpenAIResponsesImpl: async def _convert_response_tools_to_chat_tools( self, tools: list[OpenAIResponseInputTool] - ) -> list[ChatCompletionToolParam]: + ) -> tuple[ + list[ChatCompletionToolParam], + dict[str, OpenAIResponseInputToolMCP], + OpenAIResponseOutput | None, + ]: + from 
llama_stack.apis.agents.openai_responses import ( + MCPListToolsTool, + OpenAIResponseOutputMessageMCPListTools, + ) + from llama_stack.apis.tools.tools import Tool + + mcp_tool_to_server = {} + + def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam: + tool_def = ToolDefinition( + tool_name=tool_name, + description=tool.description, + parameters={ + param.name: ToolParamDefinition( + param_type=param.parameter_type, + description=param.description, + required=param.required, + default=param.default, + ) + for param in tool.parameters + }, + ) + return convert_tooldef_to_openai_tool(tool_def) + + mcp_list_message = None chat_tools: list[ChatCompletionToolParam] = [] for input_tool in tools: # TODO: Handle other tool types @@ -412,91 +472,95 @@ class OpenAIResponsesImpl: elif input_tool.type == "web_search": tool_name = "web_search" tool = await self.tool_groups_api.get_tool(tool_name) - tool_def = ToolDefinition( - tool_name=tool_name, - description=tool.description, - parameters={ - param.name: ToolParamDefinition( - param_type=param.parameter_type, - description=param.description, - required=param.required, - default=param.default, - ) - for param in tool.parameters - }, + if not tool: + raise ValueError(f"Tool {tool_name} not found") + chat_tools.append(make_openai_tool(tool_name, tool)) + elif input_tool.type == "mcp": + always_allowed = None + never_allowed = None + if input_tool.allowed_tools: + if isinstance(input_tool.allowed_tools, list): + always_allowed = input_tool.allowed_tools + elif isinstance(input_tool.allowed_tools, AllowedToolsFilter): + always_allowed = input_tool.allowed_tools.always + never_allowed = input_tool.allowed_tools.never + + tool_defs = await list_mcp_tools( + endpoint=input_tool.server_url, + headers=input_tool.headers or {}, ) - chat_tool = convert_tooldef_to_openai_tool(tool_def) - chat_tools.append(chat_tool) + + mcp_list_message = OpenAIResponseOutputMessageMCPListTools( + id=f"mcp_list_{uuid.uuid4()}", + status="completed", + server_label=input_tool.server_label, + tools=[], + ) + for t in tool_defs.data: + if never_allowed and t.name in never_allowed: + continue + if not always_allowed or t.name in always_allowed: + chat_tools.append(make_openai_tool(t.name, t)) + if t.name in mcp_tool_to_server: + raise ValueError(f"Duplicate tool name {t.name} found for server {input_tool.server_label}") + mcp_tool_to_server[t.name] = input_tool + mcp_list_message.tools.append( + MCPListToolsTool( + name=t.name, + description=t.description, + input_schema={ + "type": "object", + "properties": { + p.name: { + "type": p.parameter_type, + "description": p.description, + } + for p in t.parameters + }, + "required": [p.name for p in t.parameters if p.required], + }, + ) + ) else: raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}") - return chat_tools + return chat_tools, mcp_tool_to_server, mcp_list_message async def _execute_tool_and_return_final_output( self, - model_id: str, - stream: bool, choice: OpenAIChoice, - messages: list[OpenAIMessageParam], - temperature: float, + ctx: ChatCompletionContext, ) -> list[OpenAIResponseOutput]: output_messages: list[OpenAIResponseOutput] = [] - # If the choice is not an assistant message, we don't need to execute any tools if not isinstance(choice.message, OpenAIAssistantMessageParam): return output_messages - # If the assistant message doesn't have any tool calls, we don't need to execute any tools if not choice.message.tool_calls: return output_messages - # 
Copy the messages list to avoid mutating the original list - messages = messages.copy() + next_turn_messages = ctx.messages.copy() # Add the assistant message with tool_calls response to the messages list - messages.append(choice.message) + next_turn_messages.append(choice.message) for tool_call in choice.message.tool_calls: - tool_call_id = tool_call.id - function = tool_call.function - - # If for some reason the tool call doesn't have a function or id, we can't execute it - if not function or not tool_call_id: - continue - # TODO: telemetry spans for tool calls - result = await self._execute_tool_call(function) - - # Handle tool call failure - if not result: - output_messages.append( - OpenAIResponseOutputMessageWebSearchToolCall( - id=tool_call_id, - status="failed", - ) - ) - continue - - output_messages.append( - OpenAIResponseOutputMessageWebSearchToolCall( - id=tool_call_id, - status="completed", - ), - ) - - result_content = "" - # TODO: handle other result content types and lists - if isinstance(result.content, str): - result_content = result.content - messages.append(OpenAIToolMessageParam(content=result_content, tool_call_id=tool_call_id)) + tool_call_log, further_input = await self._execute_tool_call(tool_call, ctx) + if tool_call_log: + output_messages.append(tool_call_log) + if further_input: + next_turn_messages.append(further_input) tool_results_chat_response = await self.inference_api.openai_chat_completion( - model=model_id, - messages=messages, - stream=stream, - temperature=temperature, + model=ctx.model, + messages=next_turn_messages, + stream=ctx.stream, + temperature=ctx.temperature, ) - # type cast to appease mypy + # type cast to appease mypy: this is needed because we don't handle streaming properly :) tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response) + + # Huge TODO: these are NOT the final outputs, we must keep the loop going tool_final_outputs = [ await _convert_chat_choice_to_response_message(choice) for choice in tool_results_chat_response.choices ] @@ -506,15 +570,86 @@ class OpenAIResponsesImpl: async def _execute_tool_call( self, - function: OpenAIChatCompletionToolCallFunction, - ) -> ToolInvocationResult | None: - if not function.name: - return None - function_args = json.loads(function.arguments) if function.arguments else {} - logger.info(f"executing tool call: {function.name} with args: {function_args}") - result = await self.tool_runtime_api.invoke_tool( - tool_name=function.name, - kwargs=function_args, + tool_call: OpenAIChatCompletionToolCall, + ctx: ChatCompletionContext, + ) -> tuple[OpenAIResponseOutput | None, OpenAIMessageParam | None]: + from llama_stack.providers.utils.inference.prompt_adapter import ( + interleaved_content_as_str, ) - logger.debug(f"tool call {function.name} completed with result: {result}") - return result + + tool_call_id = tool_call.id + function = tool_call.function + + if not function or not tool_call_id or not function.name: + return None, None + + error_exc = None + result = None + try: + if function.name in ctx.mcp_tool_to_server: + mcp_tool = ctx.mcp_tool_to_server[function.name] + result = await invoke_mcp_tool( + endpoint=mcp_tool.server_url, + headers=mcp_tool.headers or {}, + tool_name=function.name, + kwargs=json.loads(function.arguments) if function.arguments else {}, + ) + else: + result = await self.tool_runtime_api.invoke_tool( + tool_name=function.name, + kwargs=json.loads(function.arguments) if function.arguments else {}, + ) + except Exception as e: + error_exc = e + + 
if function.name in ctx.mcp_tool_to_server: + from llama_stack.apis.agents.openai_responses import OpenAIResponseOutputMessageMCPCall + + message = OpenAIResponseOutputMessageMCPCall( + id=tool_call_id, + arguments=function.arguments, + name=function.name, + server_label=ctx.mcp_tool_to_server[function.name].server_label, + ) + if error_exc: + message.error = str(error_exc) + elif (result.error_code and result.error_code > 0) or result.error_message: + message.error = f"Error (code {result.error_code}): {result.error_message}" + elif result.content: + message.output = interleaved_content_as_str(result.content) + else: + if function.name == "web_search": + message = OpenAIResponseOutputMessageWebSearchToolCall( + id=tool_call_id, + status="completed", + ) + if error_exc or (result.error_code and result.error_code > 0) or result.error_message: + message.status = "failed" + else: + raise ValueError(f"Unknown tool {function.name} called") + + input_message = None + if result and result.content: + if isinstance(result.content, str): + content = result.content + elif isinstance(result.content, list): + from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem + + content = [] + for item in result.content: + if isinstance(item, TextContentItem): + part = OpenAIChatCompletionContentPartTextParam(text=item.text) + elif isinstance(item, ImageContentItem): + if item.image.data: + url = f"data:image;base64,{item.image.data}" + else: + url = item.image.url + part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url)) + else: + raise ValueError(f"Unknown result content type: {type(item)}") + content.append(part) + else: + raise ValueError(f"Unknown result content type: {type(result.content)}") + input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id) + + return message, input_message diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index 340e90ca1..3f0b9a188 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -4,53 +4,26 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from contextlib import asynccontextmanager from typing import Any from urllib.parse import urlparse -import exceptiongroup -import httpx -from mcp import ClientSession -from mcp import types as mcp_types -from mcp.client.sse import sse_client - -from llama_stack.apis.common.content_types import URL, ImageContentItem, TextContentItem +from llama_stack.apis.common.content_types import URL from llama_stack.apis.datatypes import Api from llama_stack.apis.tools import ( ListToolDefsResponse, - ToolDef, ToolInvocationResult, - ToolParameter, ToolRuntime, ) -from llama_stack.distribution.datatypes import AuthenticationRequiredError from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger from llama_stack.providers.datatypes import ToolsProtocolPrivate +from llama_stack.providers.utils.tools.mcp import convert_header_list_to_dict, invoke_mcp_tool, list_mcp_tools from .config import MCPProviderConfig logger = get_logger(__name__, category="tools") -@asynccontextmanager -async def sse_client_wrapper(endpoint: str, headers: dict[str, str]): - try: - async with sse_client(endpoint, headers=headers) as streams: - async with ClientSession(*streams) as session: - await session.initialize() - yield session - except BaseException as e: - if isinstance(e, exceptiongroup.BaseExceptionGroup): - for exc in e.exceptions: - if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401: - raise AuthenticationRequiredError(exc) from exc - elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401: - raise AuthenticationRequiredError(e) from e - - raise - - class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData): def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]): self.config = config @@ -64,32 +37,8 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, Nee # this endpoint should be retrieved by getting the tool group right? 
if mcp_endpoint is None: raise ValueError("mcp_endpoint is required") - headers = await self.get_headers_from_request(mcp_endpoint.uri) - tools = [] - async with sse_client_wrapper(mcp_endpoint.uri, headers) as session: - tools_result = await session.list_tools() - for tool in tools_result.tools: - parameters = [] - for param_name, param_schema in tool.inputSchema.get("properties", {}).items(): - parameters.append( - ToolParameter( - name=param_name, - parameter_type=param_schema.get("type", "string"), - description=param_schema.get("description", ""), - ) - ) - tools.append( - ToolDef( - name=tool.name, - description=tool.description, - parameters=parameters, - metadata={ - "endpoint": mcp_endpoint.uri, - }, - ) - ) - return ListToolDefsResponse(data=tools) + return await list_mcp_tools(mcp_endpoint.uri, headers) async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult: tool = await self.tool_store.get_tool(tool_name) @@ -100,23 +49,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, Nee raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL") headers = await self.get_headers_from_request(endpoint) - async with sse_client_wrapper(endpoint, headers) as session: - result = await session.call_tool(tool.identifier, kwargs) - - content = [] - for item in result.content: - if isinstance(item, mcp_types.TextContent): - content.append(TextContentItem(text=item.text)) - elif isinstance(item, mcp_types.ImageContent): - content.append(ImageContentItem(image=item.data)) - elif isinstance(item, mcp_types.EmbeddedResource): - logger.warning(f"EmbeddedResource is not supported: {item}") - else: - raise ValueError(f"Unknown content type: {type(item)}") - return ToolInvocationResult( - content=content, - error_code=1 if result.isError else 0, - ) + return await invoke_mcp_tool(endpoint, headers, tool_name, kwargs) async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]: def canonicalize_uri(uri: str) -> str: @@ -129,9 +62,5 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, Nee for uri, values in provider_data.mcp_headers.items(): if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri): continue - for entry in values: - parts = entry.split(":") - if len(parts) == 2: - k, v = parts - headers[k.strip()] = v.strip() + headers.update(convert_header_list_to_dict(values)) return headers diff --git a/llama_stack/providers/utils/tools/mcp.py b/llama_stack/providers/utils/tools/mcp.py new file mode 100644 index 000000000..7bc372f07 --- /dev/null +++ b/llama_stack/providers/utils/tools/mcp.py @@ -0,0 +1,110 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
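+
+# Descriptive note (assumption based on this diff): shared MCP helpers used by both the
+# model-context-protocol tool_runtime provider and the meta_reference OpenAI Responses
+# implementation -- an SSE client wrapper with auth error handling, header-list parsing,
+# and the list_mcp_tools / invoke_mcp_tool entry points.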
+ +from contextlib import asynccontextmanager +from typing import Any + +try: + # for python < 3.11 + import exceptiongroup + + BaseExceptionGroup = exceptiongroup.BaseExceptionGroup +except ImportError: + pass + +import httpx +from mcp import ClientSession +from mcp import types as mcp_types +from mcp.client.sse import sse_client + +from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem +from llama_stack.apis.tools import ( + ListToolDefsResponse, + ToolDef, + ToolInvocationResult, + ToolParameter, +) +from llama_stack.distribution.datatypes import AuthenticationRequiredError +from llama_stack.log import get_logger + +logger = get_logger(__name__, category="tools") + + +@asynccontextmanager +async def sse_client_wrapper(endpoint: str, headers: dict[str, str]): + try: + async with sse_client(endpoint, headers=headers) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + yield session + except BaseException as e: + if isinstance(e, BaseExceptionGroup): + for exc in e.exceptions: + if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401: + raise AuthenticationRequiredError(exc) from exc + elif isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 401: + raise AuthenticationRequiredError(e) from e + + raise + + +def convert_header_list_to_dict(header_list: list[str]) -> dict[str, str]: + headers = {} + for header in header_list: + parts = header.split(":") + if len(parts) == 2: + k, v = parts + headers[k.strip()] = v.strip() + return headers + + +async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse: + tools = [] + async with sse_client_wrapper(endpoint, headers) as session: + tools_result = await session.list_tools() + for tool in tools_result.tools: + parameters = [] + for param_name, param_schema in tool.inputSchema.get("properties", {}).items(): + parameters.append( + ToolParameter( + name=param_name, + parameter_type=param_schema.get("type", "string"), + description=param_schema.get("description", ""), + ) + ) + tools.append( + ToolDef( + name=tool.name, + description=tool.description, + parameters=parameters, + metadata={ + "endpoint": endpoint, + }, + ) + ) + return ListToolDefsResponse(data=tools) + + +async def invoke_mcp_tool( + endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any] +) -> ToolInvocationResult: + async with sse_client_wrapper(endpoint, headers) as session: + result = await session.call_tool(tool_name, kwargs) + + content: list[InterleavedContentItem] = [] + for item in result.content: + if isinstance(item, mcp_types.TextContent): + content.append(TextContentItem(text=item.text)) + elif isinstance(item, mcp_types.ImageContent): + content.append(ImageContentItem(image=item.data)) + elif isinstance(item, mcp_types.EmbeddedResource): + logger.warning(f"EmbeddedResource is not supported: {item}") + else: + raise ValueError(f"Unknown content type: {type(item)}") + return ToolInvocationResult( + content=content, + error_code=1 if result.isError else 0, + ) diff --git a/pyproject.toml b/pyproject.toml index 0b5c7f6df..8d8137233 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,6 +67,7 @@ unit = [ "aiosqlite", "aiohttp", "pypdf", + "mcp", "chardet", "qdrant-client", "opentelemetry-exporter-otlp-proto-http", diff --git a/tests/common/mcp.py b/tests/common/mcp.py new file mode 100644 index 000000000..fd7040c6c --- /dev/null +++ b/tests/common/mcp.py @@ -0,0 +1,145 @@ +# Copyright (c) Meta 
Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# we want the mcp server to be authenticated OR not, depends +from contextlib import contextmanager + +# Unfortunately the toolgroup id must be tied to the tool names because the registry +# indexes on both toolgroups and tools independently (and not jointly). That really +# needs to be fixed. +MCP_TOOLGROUP_ID = "mcp::localmcp" + + +@contextmanager +def make_mcp_server(required_auth_token: str | None = None): + import threading + import time + + import httpx + import uvicorn + from mcp import types + from mcp.server.fastmcp import Context, FastMCP + from mcp.server.sse import SseServerTransport + from starlette.applications import Starlette + from starlette.responses import Response + from starlette.routing import Mount, Route + + server = FastMCP("FastMCP Test Server", log_level="WARNING") + + @server.tool() + async def greet_everyone( + url: str, ctx: Context + ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: + return [types.TextContent(type="text", text="Hello, world!")] + + @server.tool() + async def get_boiling_point(liquid_name: str, celcius: bool = True) -> int: + """ + Returns the boiling point of a liquid in Celcius or Fahrenheit. + + :param liquid_name: The name of the liquid + :param celcius: Whether to return the boiling point in Celcius + :return: The boiling point of the liquid in Celcius or Fahrenheit + """ + if liquid_name.lower() == "polyjuice": + if celcius: + return -100 + else: + return -212 + else: + return -1 + + sse = SseServerTransport("/messages/") + + async def handle_sse(request): + from starlette.exceptions import HTTPException + + auth_header = request.headers.get("Authorization") + auth_token = None + if auth_header and auth_header.startswith("Bearer "): + auth_token = auth_header.split(" ")[1] + + if required_auth_token and auth_token != required_auth_token: + raise HTTPException(status_code=401, detail="Unauthorized") + + async with sse.connect_sse(request.scope, request.receive, request._send) as streams: + await server._mcp_server.run( + streams[0], + streams[1], + server._mcp_server.create_initialization_options(), + ) + return Response() + + app = Starlette( + routes=[ + Route("/sse", endpoint=handle_sse), + Mount("/messages/", app=sse.handle_post_message), + ], + ) + + def get_open_port(): + import socket + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("", 0)) + return sock.getsockname()[1] + + port = get_open_port() + + # make uvicorn logs be less verbose + config = uvicorn.Config(app, host="0.0.0.0", port=port, log_level="warning") + server_instance = uvicorn.Server(config) + app.state.uvicorn_server = server_instance + + def run_server(): + server_instance.run() + + # Start the server in a new thread + server_thread = threading.Thread(target=run_server, daemon=True) + server_thread.start() + + # Polling until the server is ready + timeout = 10 + start_time = time.time() + + server_url = f"http://localhost:{port}/sse" + while time.time() - start_time < timeout: + try: + response = httpx.get(server_url) + if response.status_code in [200, 401]: + break + except httpx.RequestError: + pass + time.sleep(0.1) + + try: + yield {"server_url": server_url} + finally: + print("Telling SSE server to exit") + server_instance.should_exit = True + time.sleep(0.5) + + # Force shutdown if still running + if 
server_thread.is_alive(): + try: + if hasattr(server_instance, "servers") and server_instance.servers: + for srv in server_instance.servers: + srv.close() + + # Wait for graceful shutdown + server_thread.join(timeout=3) + if server_thread.is_alive(): + print("Warning: Server thread still alive after shutdown attempt") + except Exception as e: + print(f"Error during server shutdown: {e}") + + # CRITICAL: Reset SSE global state to prevent event loop contamination + # Reset the SSE AppStatus singleton that stores anyio.Event objects + from sse_starlette.sse import AppStatus + + AppStatus.should_exit = False + AppStatus.should_exit_event = None + print("SSE server exited") diff --git a/tests/integration/tool_runtime/test_mcp.py b/tests/integration/tool_runtime/test_mcp.py index e553c6a0b..dd8a6d823 100644 --- a/tests/integration/tool_runtime/test_mcp.py +++ b/tests/integration/tool_runtime/test_mcp.py @@ -5,117 +5,43 @@ # the root directory of this source tree. import json -import socket -import threading -import time -import httpx -import mcp.types as types import pytest -import uvicorn from llama_stack_client import Agent -from mcp.server.fastmcp import Context, FastMCP -from mcp.server.sse import SseServerTransport -from starlette.applications import Starlette -from starlette.exceptions import HTTPException -from starlette.responses import Response -from starlette.routing import Mount, Route from llama_stack import LlamaStackAsLibraryClient +from llama_stack.apis.models import ModelType from llama_stack.distribution.datatypes import AuthenticationRequiredError AUTH_TOKEN = "test-token" +from tests.common.mcp import MCP_TOOLGROUP_ID, make_mcp_server -@pytest.fixture(scope="module") + +@pytest.fixture(scope="function") def mcp_server(): - server = FastMCP("FastMCP Test Server") - - @server.tool() - async def greet_everyone( - url: str, ctx: Context - ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: - return [types.TextContent(type="text", text="Hello, world!")] - - sse = SseServerTransport("/messages/") - - async def handle_sse(request): - auth_header = request.headers.get("Authorization") - auth_token = None - if auth_header and auth_header.startswith("Bearer "): - auth_token = auth_header.split(" ")[1] - - if auth_token != AUTH_TOKEN: - raise HTTPException(status_code=401, detail="Unauthorized") - - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: - await server._mcp_server.run( - streams[0], - streams[1], - server._mcp_server.create_initialization_options(), - ) - return Response() - - app = Starlette( - routes=[ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse.handle_post_message), - ], - ) - - def get_open_port(): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.bind(("", 0)) - return sock.getsockname()[1] - - port = get_open_port() - - config = uvicorn.Config(app, host="0.0.0.0", port=port) - server_instance = uvicorn.Server(config) - app.state.uvicorn_server = server_instance - - def run_server(): - server_instance.run() - - # Start the server in a new thread - server_thread = threading.Thread(target=run_server, daemon=True) - server_thread.start() - - # Polling until the server is ready - timeout = 10 - start_time = time.time() - - while time.time() - start_time < timeout: - try: - response = httpx.get(f"http://localhost:{port}/sse") - if response.status_code == 401: - break - except httpx.RequestError: - pass - time.sleep(0.1) - - yield port - - # Tell server to exit - 
server_instance.should_exit = True - server_thread.join(timeout=5) + with make_mcp_server(required_auth_token=AUTH_TOKEN) as mcp_server_info: + yield mcp_server_info def test_mcp_invocation(llama_stack_client, mcp_server): - port = mcp_server - test_toolgroup_id = "remote::mcptest" + if not isinstance(llama_stack_client, LlamaStackAsLibraryClient): + pytest.skip("The local MCP server only reliably reachable from library client.") + + test_toolgroup_id = MCP_TOOLGROUP_ID + uri = mcp_server["server_url"] # registering itself should fail since it requires listing tools with pytest.raises(Exception, match="Unauthorized"): llama_stack_client.toolgroups.register( toolgroup_id=test_toolgroup_id, provider_id="model-context-protocol", - mcp_endpoint=dict(uri=f"http://localhost:{port}/sse"), + mcp_endpoint=dict(uri=uri), ) provider_data = { "mcp_headers": { - f"http://localhost:{port}/sse": [ + uri: [ f"Authorization: Bearer {AUTH_TOKEN}", ], }, @@ -133,24 +59,18 @@ def test_mcp_invocation(llama_stack_client, mcp_server): llama_stack_client.toolgroups.register( toolgroup_id=test_toolgroup_id, provider_id="model-context-protocol", - mcp_endpoint=dict(uri=f"http://localhost:{port}/sse"), + mcp_endpoint=dict(uri=uri), extra_headers=auth_headers, ) response = llama_stack_client.tools.list( toolgroup_id=test_toolgroup_id, extra_headers=auth_headers, ) - assert len(response) == 1 - assert response[0].identifier == "greet_everyone" - assert response[0].type == "tool" - assert len(response[0].parameters) == 1 - p = response[0].parameters[0] - assert p.name == "url" - assert p.parameter_type == "string" - assert p.required + assert len(response) == 2 + assert {t.identifier for t in response} == {"greet_everyone", "get_boiling_point"} response = llama_stack_client.tool_runtime.invoke_tool( - tool_name=response[0].identifier, + tool_name="greet_everyone", kwargs=dict(url="https://www.google.com"), extra_headers=auth_headers, ) @@ -159,7 +79,9 @@ def test_mcp_invocation(llama_stack_client, mcp_server): assert content[0].type == "text" assert content[0].text == "Hello, world!" - models = llama_stack_client.models.list() + models = [ + m for m in llama_stack_client.models.list() if m.model_type == ModelType.llm and "guard" not in m.identifier + ] model_id = models[0].identifier print(f"Using model: {model_id}") agent = Agent( @@ -174,7 +96,7 @@ def test_mcp_invocation(llama_stack_client, mcp_server): messages=[ { "role": "user", - "content": "Yo. Use tools.", + "content": "Say hi to the world. Use tools to do so.", } ], stream=False, @@ -196,7 +118,6 @@ def test_mcp_invocation(llama_stack_client, mcp_server): third = steps[2] assert third.step_type == "inference" - assert len(third.api_model_response.tool_calls) == 0 # when streaming, we currently don't check auth headers upfront and fail the request # early. but we should at least be generating a 401 later in the process. @@ -205,7 +126,7 @@ def test_mcp_invocation(llama_stack_client, mcp_server): messages=[ { "role": "user", - "content": "Yo. Use tools.", + "content": "What is the boiling point of polyjuice? Use tools to answer.", } ], stream=True, diff --git a/tests/integration/tool_runtime/test_registration.py b/tests/integration/tool_runtime/test_registration.py index b36237d05..b8cbd964a 100644 --- a/tests/integration/tool_runtime/test_registration.py +++ b/tests/integration/tool_runtime/test_registration.py @@ -4,120 +4,54 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-import socket -import threading -import time - -import httpx -import mcp.types as types import pytest -import uvicorn -from mcp.server.fastmcp import Context, FastMCP -from mcp.server.sse import SseServerTransport -from starlette.applications import Starlette -from starlette.routing import Mount, Route + +from llama_stack import LlamaStackAsLibraryClient +from tests.common.mcp import MCP_TOOLGROUP_ID, make_mcp_server -@pytest.fixture(scope="module") -def mcp_server(): - server = FastMCP("FastMCP Test Server") +def test_register_and_unregister_toolgroup(llama_stack_client): + # TODO: make this work for http client also but you need to ensure + # the MCP server is reachable from llama stack server + if not isinstance(llama_stack_client, LlamaStackAsLibraryClient): + pytest.skip("The local MCP server only reliably reachable from library client.") - @server.tool() - async def fetch(url: str, ctx: Context) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: - headers = {"User-Agent": "MCP Test Server (github.com/modelcontextprotocol/python-sdk)"} - async with httpx.AsyncClient(follow_redirects=True, headers=headers) as client: - response = await client.get(url) - response.raise_for_status() - return [types.TextContent(type="text", text=response.text)] - - sse = SseServerTransport("/messages/") - - async def handle_sse(request): - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: - await server._mcp_server.run( - streams[0], - streams[1], - server._mcp_server.create_initialization_options(), - ) - - app = Starlette( - debug=True, - routes=[ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse.handle_post_message), - ], - ) - - def get_open_port(): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.bind(("", 0)) - return sock.getsockname()[1] - - port = get_open_port() - - def run_server(): - uvicorn.run(app, host="0.0.0.0", port=port) - - # Start the server in a new thread - server_thread = threading.Thread(target=run_server, daemon=True) - server_thread.start() - - # Polling until the server is ready - timeout = 10 - start_time = time.time() - - while time.time() - start_time < timeout: - try: - response = httpx.get(f"http://localhost:{port}/sse") - if response.status_code == 200: - break - except (httpx.RequestError, httpx.HTTPStatusError): - pass - time.sleep(0.1) - - yield port - - -def test_register_and_unregister_toolgroup(llama_stack_client, mcp_server): - """ - Integration test for registering and unregistering a toolgroup using the ToolGroups API. 
- """ - port = mcp_server - test_toolgroup_id = "remote::web-fetch" + test_toolgroup_id = MCP_TOOLGROUP_ID provider_id = "model-context-protocol" - # Cleanup before running the test - toolgroups = llama_stack_client.toolgroups.list() - for toolgroup in toolgroups: - if toolgroup.identifier == test_toolgroup_id: - llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id) + with make_mcp_server() as mcp_server_info: + # Cleanup before running the test + toolgroups = llama_stack_client.toolgroups.list() + for toolgroup in toolgroups: + if toolgroup.identifier == test_toolgroup_id: + llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id) - # Register the toolgroup - llama_stack_client.toolgroups.register( - toolgroup_id=test_toolgroup_id, - provider_id=provider_id, - mcp_endpoint=dict(uri=f"http://localhost:{port}/sse"), - ) + # Register the toolgroup + llama_stack_client.toolgroups.register( + toolgroup_id=test_toolgroup_id, + provider_id=provider_id, + mcp_endpoint=dict(uri=mcp_server_info["server_url"]), + ) - # Verify registration - registered_toolgroup = llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id) - assert registered_toolgroup is not None - assert registered_toolgroup.identifier == test_toolgroup_id - assert registered_toolgroup.provider_id == provider_id + # Verify registration + registered_toolgroup = llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id) + assert registered_toolgroup is not None + assert registered_toolgroup.identifier == test_toolgroup_id + assert registered_toolgroup.provider_id == provider_id - # Verify tools listing - tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id) - assert isinstance(tools_list_response, list) - assert tools_list_response + # Verify tools listing + tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id) + assert isinstance(tools_list_response, list) + assert tools_list_response - # Unregister the toolgroup - llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id) + # Unregister the toolgroup + llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id) - # Verify it is unregistered - with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"): - llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id) + # Verify it is unregistered + with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"): + llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id) - # Verify tools are also unregistered - unregister_tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id) - assert isinstance(unregister_tools_list_response, list) - assert not unregister_tools_list_response + # Verify tools are also unregistered + unregister_tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id) + assert isinstance(unregister_tools_list_response, list) + assert not unregister_tools_list_response diff --git a/tests/verifications/openai_api/fixtures/test_cases/responses.yaml b/tests/verifications/openai_api/fixtures/test_cases/responses.yaml index 262d82526..d8b8d40c5 100644 --- a/tests/verifications/openai_api/fixtures/test_cases/responses.yaml +++ b/tests/verifications/openai_api/fixtures/test_cases/responses.yaml @@ -31,6 +31,18 @@ test_response_web_search: search_context_size: "low" output: "128" +test_response_mcp_tool: + test_name: test_response_mcp_tool + test_params: + case: + - case_id: 
"boiling_point_tool" + input: "What is the boiling point of polyjuice?" + tools: + - type: mcp + server_label: "localmcp" + server_url: "" + output: "Hello, world!" + test_response_custom_tool: test_name: test_response_custom_tool test_params: diff --git a/tests/verifications/openai_api/test_responses.py b/tests/verifications/openai_api/test_responses.py index e279b9b38..da6ed85bc 100644 --- a/tests/verifications/openai_api/test_responses.py +++ b/tests/verifications/openai_api/test_responses.py @@ -4,9 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import json +import httpx import pytest +from llama_stack import LlamaStackAsLibraryClient +from llama_stack.distribution.datatypes import AuthenticationRequiredError +from tests.common.mcp import make_mcp_server from tests.verifications.openai_api.fixtures.fixtures import ( case_id_generator, get_base_test_name, @@ -124,6 +129,79 @@ def test_response_non_streaming_web_search(request, openai_client, model, provid assert case["output"].lower() in response.output_text.lower().strip() +@pytest.mark.parametrize( + "case", + responses_test_cases["test_response_mcp_tool"]["test_params"]["case"], + ids=case_id_generator, +) +def test_response_non_streaming_mcp_tool(request, openai_client, model, provider, verification_config, case): + test_name_base = get_base_test_name(request) + if should_skip_test(verification_config, provider, model, test_name_base): + pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.") + + with make_mcp_server() as mcp_server_info: + tools = case["tools"] + for tool in tools: + if tool["type"] == "mcp": + tool["server_url"] = mcp_server_info["server_url"] + + response = openai_client.responses.create( + model=model, + input=case["input"], + tools=tools, + stream=False, + ) + assert len(response.output) >= 3 + list_tools = response.output[0] + assert list_tools.type == "mcp_list_tools" + assert list_tools.server_label == "localmcp" + assert len(list_tools.tools) == 2 + assert {t["name"] for t in list_tools.tools} == {"get_boiling_point", "greet_everyone"} + + call = response.output[1] + assert call.type == "mcp_call" + assert call.name == "get_boiling_point" + assert json.loads(call.arguments) == {"liquid_name": "polyjuice", "celcius": True} + assert call.error is None + assert "-100" in call.output + + message = response.output[2] + text_content = message.content[0].text + assert "boiling point" in text_content.lower() + + with make_mcp_server(required_auth_token="test-token") as mcp_server_info: + tools = case["tools"] + for tool in tools: + if tool["type"] == "mcp": + tool["server_url"] = mcp_server_info["server_url"] + + exc_type = ( + AuthenticationRequiredError + if isinstance(openai_client, LlamaStackAsLibraryClient) + else httpx.HTTPStatusError + ) + with pytest.raises(exc_type): + openai_client.responses.create( + model=model, + input=case["input"], + tools=tools, + stream=False, + ) + + for tool in tools: + if tool["type"] == "mcp": + tool["server_url"] = mcp_server_info["server_url"] + tool["headers"] = {"Authorization": "Bearer test-token"} + + response = openai_client.responses.create( + model=model, + input=case["input"], + tools=tools, + stream=False, + ) + assert len(response.output) >= 3 + + @pytest.mark.parametrize( "case", responses_test_cases["test_response_custom_tool"]["test_params"]["case"], diff --git a/uv.lock b/uv.lock index f2c9d59c1..f6dae2944 100644 --- a/uv.lock +++ 
b/uv.lock @@ -1544,6 +1544,7 @@ unit = [ { name = "aiohttp" }, { name = "aiosqlite" }, { name = "chardet" }, + { name = "mcp" }, { name = "openai" }, { name = "opentelemetry-exporter-otlp-proto-http" }, { name = "pypdf" }, @@ -1576,6 +1577,7 @@ requires-dist = [ { name = "llama-stack-client", specifier = ">=0.2.7" }, { name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.7" }, { name = "mcp", marker = "extra == 'test'" }, + { name = "mcp", marker = "extra == 'unit'" }, { name = "myst-parser", marker = "extra == 'docs'" }, { name = "nbval", marker = "extra == 'dev'" }, { name = "openai", specifier = ">=1.66" },
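Usage sketch: a minimal example of the new `mcp` tool type, assuming an OpenAI-compatible `openai_client` and a registered `model` as in tests/verifications/openai_api/test_responses.py above; the prompt and the `localmcp` server label are taken from the `test_response_mcp_tool` case in responses.yaml.

    from tests.common.mcp import make_mcp_server

    # Spin up the local SSE-based MCP test server and point a Responses call at it.
    with make_mcp_server() as mcp_server_info:
        response = openai_client.responses.create(
            model=model,
            input="What is the boiling point of polyjuice?",
            tools=[
                {
                    "type": "mcp",
                    "server_label": "localmcp",
                    "server_url": mcp_server_info["server_url"],
                }
            ],
            stream=False,
        )
        # Expected output items: an mcp_list_tools entry, an mcp_call for
        # get_boiling_point, then the assistant message summarizing the result.
        for item in response.output:
            print(item.type)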