diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 2414522a7..04f2139cc 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -24,7 +24,7 @@ jobs: matrix: # Listing tests manually since some of them currently fail # TODO: generate matrix list from tests/integration when fixed - test-type: [agents, inference, datasets, inspect, scoring, post_training, providers] + test-type: [agents, inference, datasets, inspect, scoring, post_training, providers, tool_runtime] client-type: [library, http] fail-fast: false # we want to run all tests regardless of failure diff --git a/.gitignore b/.gitignore index 0ef25cdf1..2cc885604 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ dev_requirements.txt build .DS_Store llama_stack/configs/* +.cursor/ xcuserdata/ *.hmap .DS_Store diff --git a/llama_stack/providers/utils/tools/mcp.py b/llama_stack/providers/utils/tools/mcp.py index afa4df766..7bc372f07 100644 --- a/llama_stack/providers/utils/tools/mcp.py +++ b/llama_stack/providers/utils/tools/mcp.py @@ -7,7 +7,14 @@ from contextlib import asynccontextmanager from typing import Any -import exceptiongroup +try: + # for python < 3.11 + import exceptiongroup + + BaseExceptionGroup = exceptiongroup.BaseExceptionGroup +except ImportError: + pass + import httpx from mcp import ClientSession from mcp import types as mcp_types @@ -34,7 +41,7 @@ async def sse_client_wrapper(endpoint: str, headers: dict[str, str]): await session.initialize() yield session except BaseException as e: - if isinstance(e, exceptiongroup.BaseExceptionGroup): + if isinstance(e, BaseExceptionGroup): for exc in e.exceptions: if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401: raise AuthenticationRequiredError(exc) from exc diff --git a/tests/common/mcp.py b/tests/common/mcp.py index f602cbff2..b66653ace 100644 --- a/tests/common/mcp.py +++ b/tests/common/mcp.py @@ -109,8 +109,31 @@ def make_mcp_server(required_auth_token: str | None = None): pass time.sleep(0.1) - yield {"server_url": server_url} + try: + yield {"server_url": server_url} + finally: + print("Telling SSE server to exit") + server_instance.should_exit = True + time.sleep(0.5) - # Tell server to exit - server_instance.should_exit = True - server_thread.join(timeout=5) + # Force shutdown if still running + if server_thread.is_alive(): + try: + if hasattr(server_instance, "servers") and server_instance.servers: + for srv in server_instance.servers: + srv.close() + + # Wait for graceful shutdown + server_thread.join(timeout=3) + if server_thread.is_alive(): + print("Warning: Server thread still alive after shutdown attempt") + except Exception as e: + print(f"Error during server shutdown: {e}") + + # CRITICAL: Reset SSE global state to prevent event loop contamination + # Reset the SSE AppStatus singleton that stores anyio.Event objects + from sse_starlette.sse import AppStatus + + AppStatus.should_exit = False + AppStatus.should_exit_event = None + print("SSE server exited")