From 9adabb09cfe6c034a7b5c7b4e9b3678f4afe05fc Mon Sep 17 00:00:00 2001 From: Kai Wu Date: Mon, 29 Sep 2025 15:21:52 -0700 Subject: [PATCH] keep related tests only --- tests/unit/providers/utils/tools/test_mcp.py | 253 ------------------- 1 file changed, 253 deletions(-) diff --git a/tests/unit/providers/utils/tools/test_mcp.py b/tests/unit/providers/utils/tools/test_mcp.py index a29192aaf..80d574f2d 100644 --- a/tests/unit/providers/utils/tools/test_mcp.py +++ b/tests/unit/providers/utils/tools/test_mcp.py @@ -256,145 +256,6 @@ class TestMCPProtocol: assert MCPProtol.SSE.value == 2 -class TestClientWrapper: - """Test cases for client_wrapper function.""" - - @pytest.fixture - def mock_client_session(self): - """Mock ClientSession for testing.""" - session = Mock() - session.initialize = AsyncMock() - return session - - @pytest.fixture - def mock_client_streams(self): - """Mock client streams.""" - return (Mock(), Mock()) - - @pytest.mark.asyncio - async def test_successful_streamable_http_connection(self, mock_client_session, mock_client_streams): - """Test successful connection with STREAMABLE_HTTP protocol.""" - endpoint = "http://example.com/mcp" - headers = {"Authorization": "Bearer token"} - - # Create a proper context manager mock - mock_http_context = AsyncMock() - mock_http_context.__aenter__ = AsyncMock(return_value=mock_client_streams) - mock_http_context.__aexit__ = AsyncMock(return_value=False) - - mock_session_context = AsyncMock() - mock_session_context.__aenter__ = AsyncMock(return_value=mock_client_session) - mock_session_context.__aexit__ = AsyncMock(return_value=False) - - with patch('llama_stack.providers.utils.tools.mcp.streamablehttp_client') as mock_http_client, \ - patch('llama_stack.providers.utils.tools.mcp.ClientSession') as mock_session_class: - - mock_http_client.return_value = mock_http_context - mock_session_class.return_value = mock_session_context - - async with client_wrapper(endpoint, headers) as session: - assert session == mock_client_session - mock_client_session.initialize.assert_called_once() - assert protocol_cache.get(endpoint) == MCPProtol.STREAMABLE_HTTP - - - @pytest.mark.asyncio - async def test_cached_protocol_preference(self, mock_client_session, mock_client_streams): - """Test that cached protocol is tried first.""" - endpoint = "http://example.com/mcp" - headers = {"Authorization": "Bearer token"} - - # Set SSE as cached protocol - protocol_cache[endpoint] = MCPProtol.SSE - - # Create proper context manager mocks - mock_sse_context = AsyncMock() - mock_sse_context.__aenter__ = AsyncMock(return_value=mock_client_streams) - mock_sse_context.__aexit__ = AsyncMock(return_value=False) - - mock_session_context = AsyncMock() - mock_session_context.__aenter__ = AsyncMock(return_value=mock_client_session) - mock_session_context.__aexit__ = AsyncMock(return_value=False) - - with patch('llama_stack.providers.utils.tools.mcp.sse_client') as mock_sse_client, \ - patch('llama_stack.providers.utils.tools.mcp.streamablehttp_client') as mock_http_client, \ - patch('llama_stack.providers.utils.tools.mcp.ClientSession') as mock_session_class: - - mock_sse_client.return_value = mock_sse_context - mock_session_class.return_value = mock_session_context - - async with client_wrapper(endpoint, headers) as session: - assert session == mock_client_session - # SSE should be tried first due to cache - mock_sse_client.assert_called_once() - mock_http_client.assert_not_called() - - @pytest.mark.asyncio - async def test_authentication_error_raises_exception(self): - """Test that 401 errors raise AuthenticationRequiredError.""" - endpoint = "http://example.com/mcp" - headers = {"Authorization": "Bearer invalid"} - - protocol_cache.clear() - - # Create a proper HTTP 401 error - response = Mock() - response.status_code = 401 - http_error = httpx.HTTPStatusError("Unauthorized", request=Mock(), response=response) - - with patch('llama_stack.providers.utils.tools.mcp.streamablehttp_client') as mock_http_client: - mock_http_client.return_value.__aenter__ = AsyncMock(side_effect=http_error) - mock_http_client.return_value.__aexit__ = AsyncMock() - - with pytest.raises(AuthenticationRequiredError): - async with client_wrapper(endpoint, headers): - pass - - @pytest.mark.asyncio - async def test_connection_error_handling(self): - """Test handling of connection errors.""" - endpoint = "http://example.com/mcp" - headers = {} - - protocol_cache.clear() - - connect_error = httpx.ConnectError("Connection refused") - - with patch('llama_stack.providers.utils.tools.mcp.streamablehttp_client') as mock_http_client, \ - patch('llama_stack.providers.utils.tools.mcp.sse_client') as mock_sse_client: - - mock_http_client.return_value.__aenter__ = AsyncMock(side_effect=connect_error) - mock_http_client.return_value.__aexit__ = AsyncMock() - mock_sse_client.return_value.__aenter__ = AsyncMock(side_effect=connect_error) - mock_sse_client.return_value.__aexit__ = AsyncMock() - - with pytest.raises(ConnectionError, match="Failed to connect to MCP server"): - async with client_wrapper(endpoint, headers): - pass - - @pytest.mark.asyncio - async def test_timeout_error_handling(self): - """Test handling of timeout errors.""" - endpoint = "http://example.com/mcp" - headers = {} - - protocol_cache.clear() - - timeout_error = httpx.TimeoutException("Request timeout") - - with patch('llama_stack.providers.utils.tools.mcp.streamablehttp_client') as mock_http_client, \ - patch('llama_stack.providers.utils.tools.mcp.sse_client') as mock_sse_client: - - mock_http_client.return_value.__aenter__ = AsyncMock(side_effect=timeout_error) - mock_http_client.return_value.__aexit__ = AsyncMock() - mock_sse_client.return_value.__aenter__ = AsyncMock(side_effect=timeout_error) - mock_sse_client.return_value.__aexit__ = AsyncMock() - - with pytest.raises(TimeoutError, match="MCP server.*timed out"): - async with client_wrapper(endpoint, headers): - pass - - class TestListMcpTools: """Test cases for list_mcp_tools function.""" @@ -521,117 +382,3 @@ class TestListMcpTools: assert isinstance(result, ListToolDefsResponse) assert len(result.data) == 0 - - -class TestInvokeMcpTool: - """Test cases for invoke_mcp_tool function.""" - - @pytest.mark.asyncio - async def test_invoke_tool_success_with_text_content(self): - """Test successful tool invocation with text content.""" - endpoint = "http://example.com/mcp" - headers = {} - tool_name = "test_tool" - kwargs = {"param1": "value1", "param2": 42} - - # Mock MCP text content - mock_text_content = mcp_types.TextContent(type="text", text="Tool output text") - - mock_result = Mock() - mock_result.content = [mock_text_content] - mock_result.isError = False - - mock_session = Mock() - mock_session.call_tool = AsyncMock(return_value=mock_result) - - with patch('llama_stack.providers.utils.tools.mcp.client_wrapper') as mock_wrapper: - mock_wrapper.return_value.__aenter__ = AsyncMock(return_value=mock_session) - mock_wrapper.return_value.__aexit__ = AsyncMock() - - result = await invoke_mcp_tool(endpoint, headers, tool_name, kwargs) - - assert isinstance(result, ToolInvocationResult) - assert result.error_code == 0 - assert len(result.content) == 1 - - content_item = result.content[0] - assert isinstance(content_item, TextContentItem) - assert content_item.text == "Tool output text" - - mock_session.call_tool.assert_called_once_with(tool_name, kwargs) - - @pytest.mark.asyncio - async def test_invoke_tool_with_error(self): - """Test tool invocation when tool returns an error.""" - endpoint = "http://example.com/mcp" - headers = {} - tool_name = "error_tool" - kwargs = {} - - mock_text_content = mcp_types.TextContent(type="text", text="Error occurred") - - mock_result = Mock() - mock_result.content = [mock_text_content] - mock_result.isError = True - - mock_session = Mock() - mock_session.call_tool = AsyncMock(return_value=mock_result) - - with patch('llama_stack.providers.utils.tools.mcp.client_wrapper') as mock_wrapper: - mock_wrapper.return_value.__aenter__ = AsyncMock(return_value=mock_session) - mock_wrapper.return_value.__aexit__ = AsyncMock() - - result = await invoke_mcp_tool(endpoint, headers, tool_name, kwargs) - - assert isinstance(result, ToolInvocationResult) - assert result.error_code == 1 - assert len(result.content) == 1 - - @pytest.mark.asyncio - async def test_invoke_tool_with_embedded_resource_warning(self): - """Test tool invocation with unsupported EmbeddedResource content.""" - endpoint = "http://example.com/mcp" - headers = {} - tool_name = "resource_tool" - kwargs = {} - - # Mock MCP embedded resource content - mock_embedded_resource = mcp_types.EmbeddedResource( - type="resource", - resource={ - "uri": "file:///example.txt", - "text": "Resource content" - } - ) - - mock_result = Mock() - mock_result.content = [mock_embedded_resource] - mock_result.isError = False - - mock_session = Mock() - mock_session.call_tool = AsyncMock(return_value=mock_result) - - with patch('llama_stack.providers.utils.tools.mcp.client_wrapper') as mock_wrapper, \ - patch('llama_stack.providers.utils.tools.mcp.logger') as mock_logger: - - mock_wrapper.return_value.__aenter__ = AsyncMock(return_value=mock_session) - mock_wrapper.return_value.__aexit__ = AsyncMock() - - result = await invoke_mcp_tool(endpoint, headers, tool_name, kwargs) - - assert isinstance(result, ToolInvocationResult) - assert result.error_code == 0 - assert len(result.content) == 0 # EmbeddedResource is skipped - - # Should log a warning - mock_logger.warning.assert_called_once() - assert "EmbeddedResource is not supported" in str(mock_logger.warning.call_args) - - - -@pytest.fixture(autouse=True) -def clear_protocol_cache(): - """Clear protocol cache before each test.""" - protocol_cache.clear() - yield - protocol_cache.clear()