diff --git a/scripts/integration-tests.sh b/scripts/integration-tests.sh index 20ecd0c4d..a6f00813b 100755 --- a/scripts/integration-tests.sh +++ b/scripts/integration-tests.sh @@ -165,10 +165,18 @@ if [[ "$COLLECT_ONLY" == false ]]; then # Set MCP host for in-process MCP server tests # - For library client and server mode: localhost (both on same host) - # - For docker mode: host.docker.internal (container needs to reach host) + # - For docker mode on Linux: localhost (container uses host network, shares network namespace) + # - For docker mode on macOS/Windows: host.docker.internal (container uses bridge network) if [[ "$STACK_CONFIG" == docker:* ]]; then - export LLAMA_STACK_TEST_MCP_HOST="host.docker.internal" - echo "Setting MCP host: host.docker.internal (docker mode)" + if [[ "$(uname)" != "Darwin" ]] && [[ "$(uname)" != *"MINGW"* ]]; then + # On Linux with host network mode, container shares host network namespace + export LLAMA_STACK_TEST_MCP_HOST="localhost" + echo "Setting MCP host: localhost (docker mode with host network)" + else + # On macOS/Windows with bridge network, need special host access + export LLAMA_STACK_TEST_MCP_HOST="host.docker.internal" + echo "Setting MCP host: host.docker.internal (docker mode with bridge network)" + fi else export LLAMA_STACK_TEST_MCP_HOST="localhost" echo "Setting MCP host: localhost (library/server mode)" diff --git a/src/llama_stack/providers/utils/tools/mcp.py b/src/llama_stack/providers/utils/tools/mcp.py index 9c5e9cd96..05cdfa73b 100644 --- a/src/llama_stack/providers/utils/tools/mcp.py +++ b/src/llama_stack/providers/utils/tools/mcp.py @@ -89,6 +89,7 @@ async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerat # sse_client and streamablehttp_client have different signatures, but both # are called the same way here, so we cast to Any to avoid type errors client = cast(Any, sse_client) + async with client(endpoint, headers=headers) as client_streams: async with ClientSession(read_stream=client_streams[0], write_stream=client_streams[1]) as session: await session.initialize() diff --git a/tests/unit/providers/agents/meta_reference/test_safety_optional.py b/tests/unit/providers/agents/meta_reference/test_safety_optional.py index c2311b68f..10b15b26d 100644 --- a/tests/unit/providers/agents/meta_reference/test_safety_optional.py +++ b/tests/unit/providers/agents/meta_reference/test_safety_optional.py @@ -83,7 +83,7 @@ class TestProviderInitialization: new_callable=AsyncMock, ): # Should not raise any exception - provider = await get_provider_impl(config, mock_deps, policy=[], telemetry_enabled=False) + provider = await get_provider_impl(config, mock_deps, policy=[]) assert provider is not None async def test_initialization_without_safety_api(self, mock_persistence_config, mock_deps): @@ -97,7 +97,7 @@ class TestProviderInitialization: new_callable=AsyncMock, ): # Should not raise any exception - provider = await get_provider_impl(config, mock_deps, policy=[], telemetry_enabled=False) + provider = await get_provider_impl(config, mock_deps, policy=[]) assert provider is not None assert provider.safety_api is None