forked from phoenix-oss/llama-stack-mirror
fix: solve unregister_toolgroup error (#1608)
# What does this PR do? Fixes issue #1537 that causes "500 Internal Server Error" when unregistering a toolgroup # (Closes #1537 ) ## Test Plan ```console $ pytest -s -v tests/integration/tool_runtime/test_registration.py --stack-config=ollama --env INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" INFO 2025-03-14 21:15:03,999 tests.integration.conftest:41 tests: Setting DISABLE_CODE_SANDBOX=1 for macOS /opt/homebrew/lib/python3.10/site-packages/pytest_asyncio/plugin.py:207: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset. The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session" warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET)) ===================================================== test session starts ===================================================== platform darwin -- Python 3.10.16, pytest-8.3.5, pluggy-1.5.0 -- /opt/homebrew/opt/python@3.10/bin/python3.10 cachedir: .pytest_cache rootdir: /Users/paolo/Projects/aiplatform/llama-stack configfile: pyproject.toml plugins: asyncio-0.25.3, anyio-4.8.0 asyncio: mode=strict, asyncio_default_fixture_loop_scope=None collected 1 item tests/integration/tool_runtime/test_registration.py::test_register_and_unregister_toolgroup[None-None-None-None-None] INFO 2025-03-14 21:15:04,478 llama_stack.providers.remote.inference.ollama.ollama:75 inference: checking connectivity to Ollama at `http://localhost:11434`... INFO 2025-03-14 21:15:05,350 llama_stack.providers.remote.inference.ollama.ollama:294 inference: Pulling embedding model `all-minilm:latest` if necessary... INFO: Started server process [78391] INFO: Waiting for application startup. INFO: Application startup complete. INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit) INFO: 127.0.0.1:57424 - "GET /sse HTTP/1.1" 200 OK INFO: 127.0.0.1:57434 - "GET /sse HTTP/1.1" 200 OK INFO 2025-03-14 21:15:16,129 mcp.client.sse:51 uncategorized: Connecting to SSE endpoint: http://localhost:8000/sse INFO: 127.0.0.1:57445 - "GET /sse HTTP/1.1" 200 OK INFO 2025-03-14 21:15:16,146 mcp.client.sse:71 uncategorized: Received endpoint URL: http://localhost:8000/messages/?session_id=c5b6fc01f8dc4b5e80e38eb1c1b22a9b INFO 2025-03-14 21:15:16,147 mcp.client.sse:140 uncategorized: Starting post writer with endpoint URL: http://localhost:8000/messages/?session_id=c5b6fc01f8dc4b5e80e38eb1c1b22a9b INFO: 127.0.0.1:57447 - "POST /messages/?session_id=c5b6fc01f8dc4b5e80e38eb1c1b22a9b HTTP/1.1" 202 Accepted INFO: 127.0.0.1:57447 - "POST /messages/?session_id=c5b6fc01f8dc4b5e80e38eb1c1b22a9b HTTP/1.1" 202 Accepted INFO: 127.0.0.1:57447 - "POST /messages/?session_id=c5b6fc01f8dc4b5e80e38eb1c1b22a9b HTTP/1.1" 202 Accepted INFO 2025-03-14 21:15:16,155 mcp.server.lowlevel.server:535 uncategorized: Processing request of type ListToolsRequest PASSED =============================================== 1 passed, 4 warnings in 12.17s ================================================ ``` --------- Signed-off-by: Paolo Dettori <dettori@us.ibm.com>
This commit is contained in:
parent
a2cf299906
commit
22814299b0
2 changed files with 126 additions and 2 deletions
|
@ -608,8 +608,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
tool_group = await self.get_tool_group(toolgroup_id)
|
tool_group = await self.get_tool_group(toolgroup_id)
|
||||||
if tool_group is None:
|
if tool_group is None:
|
||||||
raise ValueError(f"Tool group {toolgroup_id} not found")
|
raise ValueError(f"Tool group {toolgroup_id} not found")
|
||||||
tools = (await self.list_tools(toolgroup_id)).data
|
tools = await self.list_tools(toolgroup_id)
|
||||||
for tool in tools:
|
for tool in getattr(tools, "data", []):
|
||||||
await self.unregister_object(tool)
|
await self.unregister_object(tool)
|
||||||
await self.unregister_object(tool_group)
|
await self.unregister_object(tool_group)
|
||||||
|
|
||||||
|
|
124
tests/integration/tool_runtime/test_registration.py
Normal file
124
tests/integration/tool_runtime/test_registration.py
Normal file
|
@ -0,0 +1,124 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import socket
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import mcp.types as types
|
||||||
|
import pytest
|
||||||
|
import uvicorn
|
||||||
|
from llama_stack_client.types.shared_params.url import URL
|
||||||
|
from mcp.server.fastmcp import Context, FastMCP
|
||||||
|
from mcp.server.sse import SseServerTransport
|
||||||
|
from starlette.applications import Starlette
|
||||||
|
from starlette.routing import Mount, Route
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def mcp_server():
|
||||||
|
server = FastMCP("FastMCP Test Server")
|
||||||
|
|
||||||
|
@server.tool()
|
||||||
|
async def fetch(url: str, ctx: Context) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
|
||||||
|
headers = {"User-Agent": "MCP Test Server (github.com/modelcontextprotocol/python-sdk)"}
|
||||||
|
async with httpx.AsyncClient(follow_redirects=True, headers=headers) as client:
|
||||||
|
response = await client.get(url)
|
||||||
|
response.raise_for_status()
|
||||||
|
return [types.TextContent(type="text", text=response.text)]
|
||||||
|
|
||||||
|
sse = SseServerTransport("/messages/")
|
||||||
|
|
||||||
|
async def handle_sse(request):
|
||||||
|
async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
|
||||||
|
await server._mcp_server.run(
|
||||||
|
streams[0],
|
||||||
|
streams[1],
|
||||||
|
server._mcp_server.create_initialization_options(),
|
||||||
|
)
|
||||||
|
|
||||||
|
app = Starlette(
|
||||||
|
debug=True,
|
||||||
|
routes=[
|
||||||
|
Route("/sse", endpoint=handle_sse),
|
||||||
|
Mount("/messages/", app=sse.handle_post_message),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_open_port():
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||||
|
sock.bind(("", 0))
|
||||||
|
return sock.getsockname()[1]
|
||||||
|
|
||||||
|
port = get_open_port()
|
||||||
|
|
||||||
|
def run_server():
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=port)
|
||||||
|
|
||||||
|
# Start the server in a new thread
|
||||||
|
server_thread = threading.Thread(target=run_server, daemon=True)
|
||||||
|
server_thread.start()
|
||||||
|
|
||||||
|
# Polling until the server is ready
|
||||||
|
timeout = 10
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
while time.time() - start_time < timeout:
|
||||||
|
try:
|
||||||
|
response = httpx.get(f"http://localhost:{port}/sse")
|
||||||
|
if response.status_code == 200:
|
||||||
|
break
|
||||||
|
except (httpx.RequestError, httpx.HTTPStatusError):
|
||||||
|
pass
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
yield port
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_and_unregister_toolgroup(llama_stack_client, mcp_server):
|
||||||
|
"""
|
||||||
|
Integration test for registering and unregistering a toolgroup using the ToolGroups API.
|
||||||
|
"""
|
||||||
|
port = mcp_server
|
||||||
|
test_toolgroup_id = "remote::web-fetch"
|
||||||
|
provider_id = "model-context-protocol"
|
||||||
|
|
||||||
|
# Cleanup before running the test
|
||||||
|
toolgroups = llama_stack_client.toolgroups.list()
|
||||||
|
for toolgroup in toolgroups:
|
||||||
|
if toolgroup.identifier == test_toolgroup_id:
|
||||||
|
llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id)
|
||||||
|
|
||||||
|
# Register the toolgroup
|
||||||
|
llama_stack_client.toolgroups.register(
|
||||||
|
toolgroup_id=test_toolgroup_id,
|
||||||
|
provider_id=provider_id,
|
||||||
|
mcp_endpoint=URL(uri=f"http://localhost:{port}/sse"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify registration
|
||||||
|
registered_toolgroup = llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)
|
||||||
|
assert registered_toolgroup is not None
|
||||||
|
assert registered_toolgroup.identifier == test_toolgroup_id
|
||||||
|
assert registered_toolgroup.provider_id == provider_id
|
||||||
|
|
||||||
|
# Verify tools listing
|
||||||
|
tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
|
||||||
|
assert isinstance(tools_list_response, list)
|
||||||
|
assert tools_list_response
|
||||||
|
|
||||||
|
# Unregister the toolgroup
|
||||||
|
llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id)
|
||||||
|
|
||||||
|
# Verify it is unregistered
|
||||||
|
with pytest.raises(ValueError, match=f"Tool group '{test_toolgroup_id}' not found"):
|
||||||
|
llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)
|
||||||
|
|
||||||
|
# Verify tools are also unregistered
|
||||||
|
unregister_tools_list_response = llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
|
||||||
|
assert isinstance(unregister_tools_list_response, list)
|
||||||
|
assert not unregister_tools_list_response
|
Loading…
Add table
Add a link
Reference in a new issue