feat: enable MCP execution in Responses impl (#2240)

## Test Plan

```
pytest -s -v 'tests/verifications/openai_api/test_responses.py' \
  --provider=stack:together --model meta-llama/Llama-4-Scout-17B-16E-Instruct
```
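
For context, here is a minimal sketch of the flow this commit enables, assuming the Responses implementation accepts an OpenAI-style `mcp` tool entry and that `make_mcp_server` (used in the test diff below) yields the local server's URL. The config name, input, and tool fields are illustrative, not the committed API:

```python
# Hypothetical sketch (not the committed test): drive MCP tool execution
# through the Responses API against a local MCP server.
from llama_stack import LlamaStackAsLibraryClient
from tests.common.mcp import make_mcp_server  # helper introduced by this change

client = LlamaStackAsLibraryClient("together")  # config name is illustrative
client.initialize()

with make_mcp_server() as mcp_server_info:
    response = client.responses.create(
        model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
        input="Fetch https://example.com and summarize it.",
        tools=[
            {
                # Assumed OpenAI-style MCP tool declaration; field names
                # may differ in the actual implementation.
                "type": "mcp",
                "server_label": "local-test-server",
                "server_url": mcp_server_info["server_url"],
            }
        ],
    )
    print(response)
```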
Ashwin Bharambe authored on 2025-05-24 14:20:42 -07:00, committed by GitHub
parent 66f09f24ed · commit 3faf1e4a79
15 changed files with 865 additions and 382 deletions


```diff
@@ -4,120 +4,54 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import socket
-import threading
-import time
-
-import httpx
-import mcp.types as types
 import pytest
-import uvicorn
-from mcp.server.fastmcp import Context, FastMCP
-from mcp.server.sse import SseServerTransport
-from starlette.applications import Starlette
-from starlette.routing import Mount, Route
+
+from llama_stack import LlamaStackAsLibraryClient
+from tests.common.mcp import MCP_TOOLGROUP_ID, make_mcp_server
 
 
-@pytest.fixture(scope="module")
-def mcp_server():
-    server = FastMCP("FastMCP Test Server")
+def test_register_and_unregister_toolgroup(llama_stack_client):
+    # TODO: make this work for http client also but you need to ensure
+    # the MCP server is reachable from llama stack server
+    if not isinstance(llama_stack_client, LlamaStackAsLibraryClient):
+        pytest.skip("The local MCP server only reliably reachable from library client.")
 
-    @server.tool()
-    async def fetch(url: str, ctx: Context) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
-        headers = {"User-Agent": "MCP Test Server (github.com/modelcontextprotocol/python-sdk)"}
-        async with httpx.AsyncClient(follow_redirects=True, headers=headers) as client:
-            response = await client.get(url)
-            response.raise_for_status()
-            return [types.TextContent(type="text", text=response.text)]
-
-    sse = SseServerTransport("/messages/")
-
-    async def handle_sse(request):
-        async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
-            await server._mcp_server.run(
-                streams[0],
-                streams[1],
-                server._mcp_server.create_initialization_options(),
-            )
-
-    app = Starlette(
-        debug=True,
-        routes=[
-            Route("/sse", endpoint=handle_sse),
-            Mount("/messages/", app=sse.handle_post_message),
-        ],
-    )
-
-    def get_open_port():
-        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
-            sock.bind(("", 0))
-            return sock.getsockname()[1]
-
-    port = get_open_port()
-
-    def run_server():
-        uvicorn.run(app, host="0.0.0.0", port=port)
-
-    # Start the server in a new thread
-    server_thread = threading.Thread(target=run_server, daemon=True)
-    server_thread.start()
-
-    # Polling until the server is ready
-    timeout = 10
-    start_time = time.time()
-    while time.time() - start_time < timeout:
-        try:
-            response = httpx.get(f"http://localhost:{port}/sse")
-            if response.status_code == 200:
-                break
-        except (httpx.RequestError, httpx.HTTPStatusError):
-            pass
-        time.sleep(0.1)
-
-    yield port
-
-
-def test_register_and_unregister_toolgroup(llama_stack_client, mcp_server):
-    """
-    Integration test for registering and unregistering a toolgroup using the ToolGroups API.
-    """
-    port = mcp_server
-    test_toolgroup_id = "remote::web-fetch"
+    test_toolgroup_id = MCP_TOOLGROUP_ID
     provider_id = "model-context-protocol"
 
-    # Cleanup before running the test
-    toolgroups = llama_stack_client.toolgroups.list()
-    for toolgroup in toolgroups:
-        if toolgroup.identifier == test_toolgroup_id:
-            llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id)
+    with make_mcp_server() as mcp_server_info:
+        # Cleanup before running the test
+        toolgroups = llama_stack_client.toolgroups.list()
+        for toolgroup in toolgroups:
+            if toolgroup.identifier == test_toolgroup_id:
+                llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id)
 
-    # Register the toolgroup
-    llama_stack_client.toolgroups.register(
-        toolgroup_id=test_toolgroup_id,
-        provider_id=provider_id,
-        mcp_endpoint=dict(uri=f"http://localhost:{port}/sse"),
-    )
+        # Register the toolgroup
+        llama_stack_client.toolgroups.register(
+            toolgroup_id=test_toolgroup_id,
+            provider_id=provider_id,
+            mcp_endpoint=dict(uri=mcp_server_info["server_url"]),
+        )
 
-    # Verify registration
-    registered_toolgroup = llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)
-    assert registered_toolgroup is not None
-    assert registered_toolgroup.identifier == test_toolgroup_id
-    assert registered_toolgroup.provider_id == provider_id
+        # Verify registration
+        registered_toolgroup = llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)
+        assert registered_toolgroup is not None
+        assert registered_toolgroup.identifier == test_toolgroup_id
+        assert registered_toolgroup.provider_id == provider_id
 
-    # Verify tools listing
-    tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
-    assert isinstance(tools_list_response, list)
-    assert tools_list_response
+        # Verify tools listing
+        tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
+        assert isinstance(tools_list_response, list)
+        assert tools_list_response
 
-    # Unregister the toolgroup
-    llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id)
+        # Unregister the toolgroup
+        llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id)
 
-    # Verify it is unregistered
-    with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
-        llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)
+        # Verify it is unregistered
+        with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
+            llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)
 
-    # Verify tools are also unregistered
-    unregister_tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
-    assert isinstance(unregister_tools_list_response, list)
-    assert not unregister_tools_list_response
+        # Verify tools are also unregistered
+        unregister_tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
+        assert isinstance(unregister_tools_list_response, list)
+        assert not unregister_tools_list_response
```
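
The `make_mcp_server` helper itself is not shown in this hunk; it lives in `tests/common/mcp.py`. A rough sketch of its shape, adapted from the inline fixture removed above, might look like the following; the real helper may differ:

```python
# Rough sketch of a make_mcp_server-style helper, adapted from the inline
# pytest fixture this commit removes; the real tests/common/mcp.py may differ.
import socket
import threading
import time
from contextlib import contextmanager

import httpx
import uvicorn
from mcp.server.fastmcp import FastMCP
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.routing import Mount, Route


@contextmanager
def make_mcp_server():
    server = FastMCP("FastMCP Test Server")
    sse = SseServerTransport("/messages/")

    async def handle_sse(request):
        async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
            await server._mcp_server.run(
                streams[0], streams[1], server._mcp_server.create_initialization_options()
            )

    app = Starlette(
        routes=[
            Route("/sse", endpoint=handle_sse),
            Mount("/messages/", app=sse.handle_post_message),
        ]
    )

    # Let the OS pick a free port for the throwaway server.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        port = sock.getsockname()[1]

    # Run uvicorn in a daemon thread so the test process owns its lifetime.
    threading.Thread(target=lambda: uvicorn.run(app, host="0.0.0.0", port=port), daemon=True).start()

    # Poll until the SSE endpoint starts answering, as the old fixture did.
    deadline = time.time() + 10
    while time.time() < deadline:
        try:
            if httpx.get(f"http://localhost:{port}/sse").status_code == 200:
                break
        except (httpx.RequestError, httpx.HTTPStatusError):
            pass
        time.sleep(0.1)

    yield {"server_url": f"http://localhost:{port}/sse"}
```

Packaging the server as a context manager rather than a module-scoped fixture lets each test own the server's lifetime and makes the `server_url` available exactly where the toolgroup is registered.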