more test fixes

This commit is contained in:
Ashwin Bharambe 2025-05-24 12:02:28 -07:00
parent cc6662db7b
commit a5132b4857
3 changed files with 66 additions and 208 deletions

View file

@ -7,6 +7,11 @@
# we want the mcp server to be authenticated OR not, depends
from contextlib import contextmanager
# Unfortunately the toolgroup id must be tied to the tool names because the registry
# indexes on both toolgroups and tools independently (and not jointly). That really
# needs to be fixed.
MCP_TOOLGROUP_ID = "mcp::localmcp"
@contextmanager
def make_mcp_server(required_auth_token: str | None = None):
@ -22,7 +27,7 @@ def make_mcp_server(required_auth_token: str | None = None):
from starlette.responses import Response
from starlette.routing import Mount, Route
server = FastMCP("FastMCP Test Server")
server = FastMCP("FastMCP Test Server", log_level="WARNING")
@server.tool()
async def greet_everyone(
@ -84,7 +89,8 @@ def make_mcp_server(required_auth_token: str | None = None):
port = get_open_port()
config = uvicorn.Config(app, host="0.0.0.0", port=port)
# make uvicorn logs be less verbose
config = uvicorn.Config(app, host="0.0.0.0", port=port, log_level="warning")
server_instance = uvicorn.Server(config)
app.state.uvicorn_server = server_instance

View file

@ -5,120 +5,43 @@
# the root directory of this source tree.
import json
import socket
import threading
import time
import httpx
import mcp.types as types
import pytest
import uvicorn
from llama_stack_client import Agent
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.exceptions import HTTPException
from starlette.responses import Response
from starlette.routing import Mount, Route
from llama_stack import LlamaStackAsLibraryClient
from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import AuthenticationRequiredError
AUTH_TOKEN = "test-token"
from tests.common.mcp import MCP_TOOLGROUP_ID, make_mcp_server
@pytest.fixture(scope="module")
@pytest.fixture(scope="function")
def mcp_server():
server = FastMCP("FastMCP Test Server")
@server.tool()
async def greet_everyone(
url: str, ctx: Context
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
return [types.TextContent(type="text", text="Hello, world!")]
sse = SseServerTransport("/messages/")
async def handle_sse(request):
auth_header = request.headers.get("Authorization")
auth_token = None
if auth_header and auth_header.startswith("Bearer "):
auth_token = auth_header.split(" ")[1]
if auth_token != AUTH_TOKEN:
raise HTTPException(status_code=401, detail="Unauthorized")
async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
await server._mcp_server.run(
streams[0],
streams[1],
server._mcp_server.create_initialization_options(),
)
return Response()
app = Starlette(
routes=[
Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse.handle_post_message),
],
)
def get_open_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("", 0))
return sock.getsockname()[1]
port = get_open_port()
config = uvicorn.Config(app, host="0.0.0.0", port=port)
server_instance = uvicorn.Server(config)
app.state.uvicorn_server = server_instance
def run_server():
server_instance.run()
# Start the server in a new thread
server_thread = threading.Thread(target=run_server, daemon=True)
server_thread.start()
# Polling until the server is ready
timeout = 10
start_time = time.time()
while time.time() - start_time < timeout:
try:
response = httpx.get(f"http://localhost:{port}/sse")
if response.status_code == 401:
break
except httpx.RequestError:
pass
time.sleep(0.1)
yield port
# Tell server to exit
server_instance.should_exit = True
server_thread.join(timeout=5)
with make_mcp_server(required_auth_token=AUTH_TOKEN) as mcp_server_info:
yield mcp_server_info
def test_mcp_invocation(llama_stack_client, mcp_server):
if not isinstance(llama_stack_client, LlamaStackAsLibraryClient):
pytest.skip("The local MCP server only reliably reachable from library client.")
port = mcp_server
test_toolgroup_id = "remote::mcptest"
test_toolgroup_id = MCP_TOOLGROUP_ID
uri = mcp_server["server_url"]
# registering itself should fail since it requires listing tools
with pytest.raises(Exception, match="Unauthorized"):
llama_stack_client.toolgroups.register(
toolgroup_id=test_toolgroup_id,
provider_id="model-context-protocol",
mcp_endpoint=dict(uri=f"http://localhost:{port}/sse"),
mcp_endpoint=dict(uri=uri),
)
provider_data = {
"mcp_headers": {
f"http://localhost:{port}/sse": [
uri: [
f"Authorization: Bearer {AUTH_TOKEN}",
],
},
@ -136,24 +59,18 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
llama_stack_client.toolgroups.register(
toolgroup_id=test_toolgroup_id,
provider_id="model-context-protocol",
mcp_endpoint=dict(uri=f"http://localhost:{port}/sse"),
mcp_endpoint=dict(uri=uri),
extra_headers=auth_headers,
)
response = llama_stack_client.tools.list(
toolgroup_id=test_toolgroup_id,
extra_headers=auth_headers,
)
assert len(response) == 1
assert response[0].identifier == "greet_everyone"
assert response[0].type == "tool"
assert len(response[0].parameters) == 1
p = response[0].parameters[0]
assert p.name == "url"
assert p.parameter_type == "string"
assert p.required
assert len(response) == 2
assert {t.identifier for t in response} == {"greet_everyone", "get_boiling_point"}
response = llama_stack_client.tool_runtime.invoke_tool(
tool_name=response[0].identifier,
tool_name="greet_everyone",
kwargs=dict(url="https://www.google.com"),
extra_headers=auth_headers,
)
@ -162,7 +79,9 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
assert content[0].type == "text"
assert content[0].text == "Hello, world!"
models = llama_stack_client.models.list()
models = [
m for m in llama_stack_client.models.list() if m.model_type == ModelType.llm and "guard" not in m.identifier
]
model_id = models[0].identifier
print(f"Using model: {model_id}")
agent = Agent(
@ -177,7 +96,7 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
messages=[
{
"role": "user",
"content": "Yo. Use tools.",
"content": "Say hi to the world. Use tools to do so.",
}
],
stream=False,
@ -199,7 +118,6 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
third = steps[2]
assert third.step_type == "inference"
assert len(third.api_model_response.tool_calls) == 0
# when streaming, we currently don't check auth headers upfront and fail the request
# early. but we should at least be generating a 401 later in the process.
@ -208,7 +126,7 @@ def test_mcp_invocation(llama_stack_client, mcp_server):
messages=[
{
"role": "user",
"content": "Yo. Use tools.",
"content": "What is the boiling point of polyjuice? Use tools to answer.",
}
],
stream=True,

View file

@ -4,120 +4,54 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import socket
import threading
import time
import httpx
import mcp.types as types
import pytest
import uvicorn
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.routing import Mount, Route
from llama_stack import LlamaStackAsLibraryClient
from tests.common.mcp import MCP_TOOLGROUP_ID, make_mcp_server
@pytest.fixture(scope="module")
def mcp_server():
server = FastMCP("FastMCP Test Server")
def test_register_and_unregister_toolgroup(llama_stack_client):
# TODO: make this work for http client also but you need to ensure
# the MCP server is reachable from llama stack server
if not isinstance(llama_stack_client, LlamaStackAsLibraryClient):
pytest.skip("The local MCP server only reliably reachable from library client.")
@server.tool()
async def fetch(url: str, ctx: Context) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
headers = {"User-Agent": "MCP Test Server (github.com/modelcontextprotocol/python-sdk)"}
async with httpx.AsyncClient(follow_redirects=True, headers=headers) as client:
response = await client.get(url)
response.raise_for_status()
return [types.TextContent(type="text", text=response.text)]
sse = SseServerTransport("/messages/")
async def handle_sse(request):
async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
await server._mcp_server.run(
streams[0],
streams[1],
server._mcp_server.create_initialization_options(),
)
app = Starlette(
debug=True,
routes=[
Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse.handle_post_message),
],
)
def get_open_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("", 0))
return sock.getsockname()[1]
port = get_open_port()
def run_server():
uvicorn.run(app, host="0.0.0.0", port=port)
# Start the server in a new thread
server_thread = threading.Thread(target=run_server, daemon=True)
server_thread.start()
# Polling until the server is ready
timeout = 10
start_time = time.time()
while time.time() - start_time < timeout:
try:
response = httpx.get(f"http://localhost:{port}/sse")
if response.status_code == 200:
break
except (httpx.RequestError, httpx.HTTPStatusError):
pass
time.sleep(0.1)
yield port
def test_register_and_unregister_toolgroup(llama_stack_client, mcp_server):
"""
Integration test for registering and unregistering a toolgroup using the ToolGroups API.
"""
port = mcp_server
test_toolgroup_id = "remote::web-fetch"
test_toolgroup_id = MCP_TOOLGROUP_ID
provider_id = "model-context-protocol"
# Cleanup before running the test
toolgroups = llama_stack_client.toolgroups.list()
for toolgroup in toolgroups:
if toolgroup.identifier == test_toolgroup_id:
llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id)
with make_mcp_server() as mcp_server_info:
# Cleanup before running the test
toolgroups = llama_stack_client.toolgroups.list()
for toolgroup in toolgroups:
if toolgroup.identifier == test_toolgroup_id:
llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id)
# Register the toolgroup
llama_stack_client.toolgroups.register(
toolgroup_id=test_toolgroup_id,
provider_id=provider_id,
mcp_endpoint=dict(uri=f"http://localhost:{port}/sse"),
)
# Register the toolgroup
llama_stack_client.toolgroups.register(
toolgroup_id=test_toolgroup_id,
provider_id=provider_id,
mcp_endpoint=dict(uri=mcp_server_info["server_url"]),
)
# Verify registration
registered_toolgroup = llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)
assert registered_toolgroup is not None
assert registered_toolgroup.identifier == test_toolgroup_id
assert registered_toolgroup.provider_id == provider_id
# Verify registration
registered_toolgroup = llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)
assert registered_toolgroup is not None
assert registered_toolgroup.identifier == test_toolgroup_id
assert registered_toolgroup.provider_id == provider_id
# Verify tools listing
tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
assert isinstance(tools_list_response, list)
assert tools_list_response
# Verify tools listing
tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
assert isinstance(tools_list_response, list)
assert tools_list_response
# Unregister the toolgroup
llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id)
# Unregister the toolgroup
llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id)
# Verify it is unregistered
with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)
# Verify it is unregistered
with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)
# Verify tools are also unregistered
unregister_tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
assert isinstance(unregister_tools_list_response, list)
assert not unregister_tools_list_response
# Verify tools are also unregistered
unregister_tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
assert isinstance(unregister_tools_list_response, list)
assert not unregister_tools_list_response