feat: enable MCP execution in Response implementation

This commit is contained in:
Ashwin Bharambe 2025-05-22 20:21:47 -07:00
parent a411029d7e
commit 5937d94da5
9 changed files with 728 additions and 174 deletions

116
tests/common/mcp.py Normal file
View file

@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# we want the mcp server to be authenticated OR not, depends
from contextlib import contextmanager
@contextmanager
def make_mcp_server(required_auth_token: str | None = None):
import threading
import time
import httpx
import uvicorn
from mcp import types
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.responses import Response
from starlette.routing import Mount, Route
server = FastMCP("FastMCP Test Server")
@server.tool()
async def greet_everyone(
url: str, ctx: Context
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
return [types.TextContent(type="text", text="Hello, world!")]
@server.tool()
async def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
"""
Returns the boiling point of a liquid in Celcius or Fahrenheit.
:param liquid_name: The name of the liquid
:param celcius: Whether to return the boiling point in Celcius
:return: The boiling point of the liquid in Celcius or Fahrenheit
"""
if liquid_name.lower() == "polyjuice":
if celcius:
return -100
else:
return -212
else:
return -1
sse = SseServerTransport("/messages/")
async def handle_sse(request):
from starlette.exceptions import HTTPException
auth_header = request.headers.get("Authorization")
auth_token = None
if auth_header and auth_header.startswith("Bearer "):
auth_token = auth_header.split(" ")[1]
if required_auth_token and auth_token != required_auth_token:
raise HTTPException(status_code=401, detail="Unauthorized")
async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
await server._mcp_server.run(
streams[0],
streams[1],
server._mcp_server.create_initialization_options(),
)
return Response()
app = Starlette(
routes=[
Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse.handle_post_message),
],
)
def get_open_port():
import socket
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("", 0))
return sock.getsockname()[1]
port = get_open_port()
config = uvicorn.Config(app, host="0.0.0.0", port=port)
server_instance = uvicorn.Server(config)
app.state.uvicorn_server = server_instance
def run_server():
server_instance.run()
# Start the server in a new thread
server_thread = threading.Thread(target=run_server, daemon=True)
server_thread.start()
# Polling until the server is ready
timeout = 10
start_time = time.time()
server_url = f"http://localhost:{port}/sse"
while time.time() - start_time < timeout:
try:
response = httpx.get(server_url)
if response.status_code in [200, 401]:
break
except httpx.RequestError:
pass
time.sleep(0.1)
yield {"server_url": server_url}
# Tell server to exit
server_instance.should_exit = True
server_thread.join(timeout=5)