diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py index ed6683c72..375cfe82d 100644 --- a/tests/unit/providers/inference/test_remote_vllm.py +++ b/tests/unit/providers/inference/test_remote_vllm.py @@ -650,16 +650,22 @@ async def test_health_status_success(vllm_inference_adapter): This test verifies that the health method returns a HealthResponse with status OK, only when the connection to the vLLM server is successful. """ - # Mock the requests.get method to return a successful response - with patch('requests.get') as mock_get: - mock_response = MagicMock() - mock_response.status_code = 200 - mock_get.return_value = mock_response + # Mock the client.models.list method to return successfully + # Set vllm_inference_adapter.client to None to ensure _create_client is called + vllm_inference_adapter.client = None + with patch.object(vllm_inference_adapter, '_create_client') as mock_create_client: + # Create mock client and models + mock_client = MagicMock() + mock_models = MagicMock() + mock_client.models = mock_models + mock_create_client.return_value = mock_client # Call the health method health_response = await vllm_inference_adapter.health() # Verify the response assert health_response["status"] == HealthStatus.OK - mock_get.assert_called_once() + + # Verify that models.list was called + mock_models.list.assert_called_once() @pytest.mark.asyncio @@ -670,12 +676,19 @@ async def test_health_status_failure(vllm_inference_adapter): This test verifies that the health method returns a HealthResponse with status ERROR and an appropriate error message when the connection to the vLLM server fails. """ - # Mock the requests.get method to raise an exception - with patch('requests.get') as mock_get: - mock_get.side_effect = Exception("Connection failed") + vllm_inference_adapter.client = None + # Mock the client.models.list method to raise an exception + with patch.object(vllm_inference_adapter, '_create_client') as mock_create_client: + # Create mock client and models + mock_client = MagicMock() + mock_models = MagicMock() + mock_models.list.side_effect = Exception("Connection failed") + mock_client.models = mock_models + mock_create_client.return_value = mock_client + # Call the health method health_response = await vllm_inference_adapter.health() # Verify the response assert health_response["status"] == HealthStatus.ERROR assert "Health check failed: Connection failed" in health_response["message"] - # Verify that requests.get was called - mock_get.assert_called_once() + + mock_models.list.assert_called_once()