Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-03 19:57:35 +00:00)
Add skip_if_provider_doesnt_support_rerank
This commit is contained in:
parent 2fb8756fe2
commit f2a398dcba

1 changed file with 11 additions and 11 deletions
@@ -26,10 +26,15 @@ DUMMY_IMAGE_URL = ImageContentItem(
 )
 DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")

-SUPPORTED_PROVIDERS = {"remote::nvidia"}
 PROVIDERS_SUPPORTING_MEDIA = {}  # Providers that support media input for rerank models


+def skip_if_provider_doesnt_support_rerank(inference_provider_type):
+    supported_providers = {"remote::nvidia"}
+    if inference_provider_type not in supported_providers:
+        pytest.skip(f"{inference_provider_type} doesn't support rerank models")
+
+
 def _validate_rerank_response(response: InferenceRerankResponse, items: list) -> None:
     """
     Validate that a rerank response has the correct structure and ordering.
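Note (context, not part of this diff): the new helper changes these guards from pytest.xfail to pytest.skip. Both calls stop the test immediately when invoked, but they are reported differently: skip means "not applicable for this provider", while xfail means "expected to fail". A minimal illustration of the reporting difference, assuming nothing beyond pytest itself:

    import pytest

    def test_reported_as_skipped():
        pytest.skip("provider doesn't support rerank models")   # summary shows 's' / SKIPPED
        assert False  # never reached

    def test_reported_as_xfail():
        pytest.xfail("provider doesn't support rerank models")  # summary shows 'x' / XFAIL
        assert False  # never reached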
@@ -90,8 +95,7 @@ def _validate_semantic_ranking(response: InferenceRerankResponse, items: list, e
     ],
 )
 def test_rerank_text(client_with_models, rerank_model_id, query, items, inference_provider_type):
-    if inference_provider_type not in SUPPORTED_PROVIDERS:
-        pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
+    skip_if_provider_doesnt_support_rerank(inference_provider_type)

     response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
     assert isinstance(response, list)
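Note: _validate_rerank_response is only partially visible in this diff; its docstring says it checks structure and ordering. A rough sketch of what such a check could look like, assuming each result entry exposes an index into items and a relevance_score (both attribute names are assumptions, not taken from this diff):

    def _validate_rerank_response_sketch(response, items):
        # Assumed shape: each entry has .index (position in `items`) and .relevance_score.
        assert len(response) <= len(items)
        seen = set()
        for entry in response:
            assert 0 <= entry.index < len(items)   # index points at a real item
            assert entry.index not in seen         # no item ranked twice
            seen.add(entry.index)
        scores = [entry.relevance_score for entry in response]
        assert scores == sorted(scores, reverse=True)  # best match first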
@@ -118,8 +122,7 @@ def test_rerank_text(client_with_models, rerank_model_id, query, items, inferenc
     ],
 )
 def test_rerank_image(client_with_models, rerank_model_id, query, items, inference_provider_type):
-    if inference_provider_type not in SUPPORTED_PROVIDERS:
-        pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
+    skip_if_provider_doesnt_support_rerank(inference_provider_type)

     if rerank_model_id not in PROVIDERS_SUPPORTING_MEDIA:
         error_type = (
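Note: the error_type expression for providers outside PROVIDERS_SUPPORTING_MEDIA is truncated in this hunk. The guard presumably expects the rerank call to fail when image items are sent to a provider without media support; a generic sketch of that pattern (the concrete exception type is an assumption; the real test derives a provider-specific one):

    import pytest

    def check_media_rejected(client, model_id, query, items):
        # Sketch only: expect the provider to reject media items with some error.
        with pytest.raises(Exception):
            client.inference.rerank(model=model_id, query=query, items=items)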
@@ -136,8 +139,7 @@ def test_rerank_image(client_with_models, rerank_model_id, query, items, inferen


 def test_rerank_max_results(client_with_models, rerank_model_id, inference_provider_type):
-    if inference_provider_type not in SUPPORTED_PROVIDERS:
-        pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
+    skip_if_provider_doesnt_support_rerank(inference_provider_type)

     items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2]
     max_num_results = 2
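Note: the call that consumes max_num_results is outside the visible lines. A sketch of the assertion such a test typically makes, assuming the rerank API accepts a max_num_results keyword mirroring the local variable (the keyword name is inferred, not confirmed by this hunk):

    def check_max_results(client, model_id, items, max_num_results=2):
        response = client.inference.rerank(
            model=model_id,
            query="which item is most relevant?",  # placeholder query for illustration
            items=items,
            max_num_results=max_num_results,       # assumed keyword name
        )
        # The ranking should be capped at max_num_results entries.
        assert len(response) <= max_num_results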
@@ -155,8 +157,7 @@ def test_rerank_max_results(client_with_models, rerank_model_id, inference_provi


 def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_id, inference_provider_type):
-    if inference_provider_type not in SUPPORTED_PROVIDERS:
-        pytest.xfail(f"{inference_provider_type} doesn't support rerank yet")
+    skip_if_provider_doesnt_support_rerank(inference_provider_type)

     items = [DUMMY_STRING, DUMMY_STRING2]
     response = client_with_models.inference.rerank(
@@ -205,8 +206,7 @@ def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_i
 def test_rerank_semantic_correctness(
     client_with_models, rerank_model_id, query, items, expected_first_item, inference_provider_type
 ):
-    if inference_provider_type not in SUPPORTED_PROVIDERS:
-        pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet.")
+    skip_if_provider_doesnt_support_rerank(inference_provider_type)

     response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
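Note: the assertions that follow the rerank call are outside this hunk. Given the expected_first_item parameter, the semantic-correctness check presumably verifies that the top-ranked result refers to the expected item; a sketch with the entry attribute name assumed:

    def check_top_item(response, items, expected_first_item):
        top = response[0]                                # results assumed ordered best-first
        assert items[top.index] == expected_first_item   # .index is an assumed attribute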