feat(responses): implement full multi-turn support (#2295)

I think the implementation needs more simplification. Spent way too much
time trying to get the tests pass with models not co-operating :(
Finally had to switch claude-sonnet to get things to pass reliably.

### Test Plan

```
export TAVILY_SEARCH_API_KEY=...
export OPENAI_API_KEY=...

uv run pytest -p no:warnings \
   -s -v tests/verifications/openai_api/test_responses.py \
 --provider=stack:starter \
  --model openai/gpt-4o
```
This commit is contained in:
Ashwin Bharambe 2025-06-02 15:35:49 -07:00 committed by GitHub
parent cac7d404a2
commit dbe4e84aca
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 593 additions and 136 deletions

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
# we want the mcp server to be authenticated OR not, depends
from collections.abc import Callable
from contextlib import contextmanager
# Unfortunately the toolgroup id must be tied to the tool names because the registry
@ -13,15 +14,158 @@ from contextlib import contextmanager
MCP_TOOLGROUP_ID = "mcp::localmcp"
def default_tools():
"""Default tools for backward compatibility."""
from mcp import types
from mcp.server.fastmcp import Context
async def greet_everyone(
url: str, ctx: Context
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
return [types.TextContent(type="text", text="Hello, world!")]
async def get_boiling_point(liquid_name: str, celsius: bool = True) -> int:
"""
Returns the boiling point of a liquid in Celsius or Fahrenheit.
:param liquid_name: The name of the liquid
:param celsius: Whether to return the boiling point in Celsius
:return: The boiling point of the liquid in Celcius or Fahrenheit
"""
if liquid_name.lower() == "myawesomeliquid":
if celsius:
return -100
else:
return -212
else:
return -1
return {"greet_everyone": greet_everyone, "get_boiling_point": get_boiling_point}
def dependency_tools():
"""Tools with natural dependencies for multi-turn testing."""
from mcp import types
from mcp.server.fastmcp import Context
async def get_user_id(username: str, ctx: Context) -> str:
"""
Get the user ID for a given username. This ID is needed for other operations.
:param username: The username to look up
:return: The user ID for the username
"""
# Simple mapping for testing
user_mapping = {"alice": "user_12345", "bob": "user_67890", "charlie": "user_11111", "admin": "user_00000"}
return user_mapping.get(username.lower(), "user_99999")
async def get_user_permissions(user_id: str, ctx: Context) -> str:
"""
Get the permissions for a user ID. Requires a valid user ID from get_user_id.
:param user_id: The user ID to check permissions for
:return: The permissions for the user
"""
# Permission mapping based on user IDs
permission_mapping = {
"user_12345": "read,write", # alice
"user_67890": "read", # bob
"user_11111": "admin", # charlie
"user_00000": "superadmin", # admin
"user_99999": "none", # unknown users
}
return permission_mapping.get(user_id, "none")
async def check_file_access(user_id: str, filename: str, ctx: Context) -> str:
"""
Check if a user can access a specific file. Requires a valid user ID.
:param user_id: The user ID to check access for
:param filename: The filename to check access to
:return: Whether the user can access the file (yes/no)
"""
# Get permissions first
permission_mapping = {
"user_12345": "read,write", # alice
"user_67890": "read", # bob
"user_11111": "admin", # charlie
"user_00000": "superadmin", # admin
"user_99999": "none", # unknown users
}
permissions = permission_mapping.get(user_id, "none")
# Check file access based on permissions and filename
if permissions == "superadmin":
access = "yes"
elif permissions == "admin":
access = "yes" if not filename.startswith("secret_") else "no"
elif "write" in permissions:
access = "yes" if filename.endswith(".txt") else "no"
elif "read" in permissions:
access = "yes" if filename.endswith(".txt") or filename.endswith(".md") else "no"
else:
access = "no"
return [types.TextContent(type="text", text=access)]
async def get_experiment_id(experiment_name: str, ctx: Context) -> str:
"""
Get the experiment ID for a given experiment name. This ID is needed to get results.
:param experiment_name: The name of the experiment
:return: The experiment ID
"""
# Simple mapping for testing
experiment_mapping = {
"temperature_test": "exp_001",
"pressure_test": "exp_002",
"chemical_reaction": "exp_003",
"boiling_point": "exp_004",
}
exp_id = experiment_mapping.get(experiment_name.lower(), "exp_999")
return exp_id
async def get_experiment_results(experiment_id: str, ctx: Context) -> str:
"""
Get the results for an experiment ID. Requires a valid experiment ID from get_experiment_id.
:param experiment_id: The experiment ID to get results for
:return: The experiment results
"""
# Results mapping based on experiment IDs
results_mapping = {
"exp_001": "Temperature: 25°C, Status: Success",
"exp_002": "Pressure: 1.2 atm, Status: Success",
"exp_003": "Yield: 85%, Status: Complete",
"exp_004": "Boiling Point: 100°C, Status: Verified",
"exp_999": "No results found",
}
results = results_mapping.get(experiment_id, "Invalid experiment ID")
return results
return {
"get_user_id": get_user_id,
"get_user_permissions": get_user_permissions,
"check_file_access": check_file_access,
"get_experiment_id": get_experiment_id,
"get_experiment_results": get_experiment_results,
}
@contextmanager
def make_mcp_server(required_auth_token: str | None = None):
def make_mcp_server(required_auth_token: str | None = None, tools: dict[str, Callable] | None = None):
"""
Create an MCP server with the specified tools.
:param required_auth_token: Optional auth token required for access
:param tools: Dictionary of tool_name -> tool_function. If None, uses default tools.
"""
import threading
import time
import httpx
import uvicorn
from mcp import types
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.fastmcp import FastMCP
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.responses import Response
@ -29,35 +173,18 @@ def make_mcp_server(required_auth_token: str | None = None):
server = FastMCP("FastMCP Test Server", log_level="WARNING")
@server.tool()
async def greet_everyone(
url: str, ctx: Context
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
return [types.TextContent(type="text", text="Hello, world!")]
tools = tools or default_tools()
@server.tool()
async def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
"""
Returns the boiling point of a liquid in Celcius or Fahrenheit.
:param liquid_name: The name of the liquid
:param celcius: Whether to return the boiling point in Celcius
:return: The boiling point of the liquid in Celcius or Fahrenheit
"""
if liquid_name.lower() == "polyjuice":
if celcius:
return -100
else:
return -212
else:
return -1
# Register all tools with the server
for tool_func in tools.values():
server.tool()(tool_func)
sse = SseServerTransport("/messages/")
async def handle_sse(request):
from starlette.exceptions import HTTPException
auth_header = request.headers.get("Authorization")
auth_header: str | None = request.headers.get("Authorization")
auth_token = None
if auth_header and auth_header.startswith("Bearer "):
auth_token = auth_header.split(" ")[1]