diff --git a/tests/integration/inference/test_rerank.py b/tests/integration/inference/test_rerank.py
index ea17a54cb..4931c3d6c 100644
--- a/tests/integration/inference/test_rerank.py
+++ b/tests/integration/inference/test_rerank.py
@@ -26,10 +26,15 @@ DUMMY_IMAGE_URL = ImageContentItem(
 )
 DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
 
-SUPPORTED_PROVIDERS = {"remote::nvidia"}
 PROVIDERS_SUPPORTING_MEDIA = {}  # Providers that support media input for rerank models
 
 
+def skip_if_provider_doesnt_support_rerank(inference_provider_type):
+    supported_providers = {"remote::nvidia"}
+    if inference_provider_type not in supported_providers:
+        pytest.skip(f"{inference_provider_type} doesn't support rerank models")
+
+
 def _validate_rerank_response(response: InferenceRerankResponse, items: list) -> None:
     """
     Validate that a rerank response has the correct structure and ordering.
@@ -90,8 +95,7 @@ def _validate_semantic_ranking(response: InferenceRerankResponse, items: list, e
     ],
 )
 def test_rerank_text(client_with_models, rerank_model_id, query, items, inference_provider_type):
-    if inference_provider_type not in SUPPORTED_PROVIDERS:
-        pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
+    skip_if_provider_doesnt_support_rerank(inference_provider_type)
 
     response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
     assert isinstance(response, list)
@@ -118,8 +122,7 @@ def test_rerank_text(client_with_models, rerank_model_id, query, items, inferenc
     ],
 )
 def test_rerank_image(client_with_models, rerank_model_id, query, items, inference_provider_type):
-    if inference_provider_type not in SUPPORTED_PROVIDERS:
-        pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
+    skip_if_provider_doesnt_support_rerank(inference_provider_type)
 
     if rerank_model_id not in PROVIDERS_SUPPORTING_MEDIA:
         error_type = (
@@ -136,8 +139,7 @@ def test_rerank_image(client_with_models, rerank_model_id, query, items, inferen
 
 
 def test_rerank_max_results(client_with_models, rerank_model_id, inference_provider_type):
-    if inference_provider_type not in SUPPORTED_PROVIDERS:
-        pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
+    skip_if_provider_doesnt_support_rerank(inference_provider_type)
 
     items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2]
     max_num_results = 2
@@ -155,8 +157,7 @@ def test_rerank_max_results(client_with_models, rerank_model_id, inference_provi
 
 
 def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_id, inference_provider_type):
-    if inference_provider_type not in SUPPORTED_PROVIDERS:
-        pytest.xfail(f"{inference_provider_type} doesn't support rerank yet")
+    skip_if_provider_doesnt_support_rerank(inference_provider_type)
 
     items = [DUMMY_STRING, DUMMY_STRING2]
     response = client_with_models.inference.rerank(
@@ -205,8 +206,7 @@ def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_i
 def test_rerank_semantic_correctness(
     client_with_models, rerank_model_id, query, items, expected_first_item, inference_provider_type
 ):
-    if inference_provider_type not in SUPPORTED_PROVIDERS:
-        pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet.")
+    skip_if_provider_doesnt_support_rerank(inference_provider_type)
 
     response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
 