Add skip_if_provider_doesnt_support_rerank

This commit is contained in:
Jiayi 2025-09-29 11:30:46 -07:00
parent 2fb8756fe2
commit f2a398dcba

View file

@ -26,10 +26,15 @@ DUMMY_IMAGE_URL = ImageContentItem(
)
DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
SUPPORTED_PROVIDERS = {"remote::nvidia"}
PROVIDERS_SUPPORTING_MEDIA = {} # Providers that support media input for rerank models
def skip_if_provider_doesnt_support_rerank(inference_provider_type):
supported_providers = {"remote::nvidia"}
if inference_provider_type not in supported_providers:
pytest.skip(f"{inference_provider_type} doesn't support rerank models")
def _validate_rerank_response(response: InferenceRerankResponse, items: list) -> None:
"""
Validate that a rerank response has the correct structure and ordering.
@ -90,8 +95,7 @@ def _validate_semantic_ranking(response: InferenceRerankResponse, items: list, e
],
)
def test_rerank_text(client_with_models, rerank_model_id, query, items, inference_provider_type):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
skip_if_provider_doesnt_support_rerank(inference_provider_type)
response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
assert isinstance(response, list)
@ -118,8 +122,7 @@ def test_rerank_text(client_with_models, rerank_model_id, query, items, inferenc
],
)
def test_rerank_image(client_with_models, rerank_model_id, query, items, inference_provider_type):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
skip_if_provider_doesnt_support_rerank(inference_provider_type)
if rerank_model_id not in PROVIDERS_SUPPORTING_MEDIA:
error_type = (
@ -136,8 +139,7 @@ def test_rerank_image(client_with_models, rerank_model_id, query, items, inferen
def test_rerank_max_results(client_with_models, rerank_model_id, inference_provider_type):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
skip_if_provider_doesnt_support_rerank(inference_provider_type)
items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2]
max_num_results = 2
@ -155,8 +157,7 @@ def test_rerank_max_results(client_with_models, rerank_model_id, inference_provi
def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_id, inference_provider_type):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support rerank yet")
skip_if_provider_doesnt_support_rerank(inference_provider_type)
items = [DUMMY_STRING, DUMMY_STRING2]
response = client_with_models.inference.rerank(
@ -205,8 +206,7 @@ def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_i
def test_rerank_semantic_correctness(
client_with_models, rerank_model_id, query, items, expected_first_item, inference_provider_type
):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet.")
skip_if_provider_doesnt_support_rerank(inference_provider_type)
response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)