mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
Add skip_if_provider_doesnt_support_rerank
This commit is contained in:
parent
2fb8756fe2
commit
f2a398dcba
1 changed files with 11 additions and 11 deletions
|
@ -26,10 +26,15 @@ DUMMY_IMAGE_URL = ImageContentItem(
|
||||||
)
|
)
|
||||||
DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
|
DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
|
||||||
|
|
||||||
SUPPORTED_PROVIDERS = {"remote::nvidia"}
|
|
||||||
PROVIDERS_SUPPORTING_MEDIA = {} # Providers that support media input for rerank models
|
PROVIDERS_SUPPORTING_MEDIA = {} # Providers that support media input for rerank models
|
||||||
|
|
||||||
|
|
||||||
|
def skip_if_provider_doesnt_support_rerank(inference_provider_type):
|
||||||
|
supported_providers = {"remote::nvidia"}
|
||||||
|
if inference_provider_type not in supported_providers:
|
||||||
|
pytest.skip(f"{inference_provider_type} doesn't support rerank models")
|
||||||
|
|
||||||
|
|
||||||
def _validate_rerank_response(response: InferenceRerankResponse, items: list) -> None:
|
def _validate_rerank_response(response: InferenceRerankResponse, items: list) -> None:
|
||||||
"""
|
"""
|
||||||
Validate that a rerank response has the correct structure and ordering.
|
Validate that a rerank response has the correct structure and ordering.
|
||||||
|
@ -90,8 +95,7 @@ def _validate_semantic_ranking(response: InferenceRerankResponse, items: list, e
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_rerank_text(client_with_models, rerank_model_id, query, items, inference_provider_type):
|
def test_rerank_text(client_with_models, rerank_model_id, query, items, inference_provider_type):
|
||||||
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
skip_if_provider_doesnt_support_rerank(inference_provider_type)
|
||||||
pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
|
|
||||||
|
|
||||||
response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
|
response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
|
||||||
assert isinstance(response, list)
|
assert isinstance(response, list)
|
||||||
|
@ -118,8 +122,7 @@ def test_rerank_text(client_with_models, rerank_model_id, query, items, inferenc
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_rerank_image(client_with_models, rerank_model_id, query, items, inference_provider_type):
|
def test_rerank_image(client_with_models, rerank_model_id, query, items, inference_provider_type):
|
||||||
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
skip_if_provider_doesnt_support_rerank(inference_provider_type)
|
||||||
pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
|
|
||||||
|
|
||||||
if rerank_model_id not in PROVIDERS_SUPPORTING_MEDIA:
|
if rerank_model_id not in PROVIDERS_SUPPORTING_MEDIA:
|
||||||
error_type = (
|
error_type = (
|
||||||
|
@ -136,8 +139,7 @@ def test_rerank_image(client_with_models, rerank_model_id, query, items, inferen
|
||||||
|
|
||||||
|
|
||||||
def test_rerank_max_results(client_with_models, rerank_model_id, inference_provider_type):
|
def test_rerank_max_results(client_with_models, rerank_model_id, inference_provider_type):
|
||||||
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
skip_if_provider_doesnt_support_rerank(inference_provider_type)
|
||||||
pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
|
|
||||||
|
|
||||||
items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2]
|
items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2]
|
||||||
max_num_results = 2
|
max_num_results = 2
|
||||||
|
@ -155,8 +157,7 @@ def test_rerank_max_results(client_with_models, rerank_model_id, inference_provi
|
||||||
|
|
||||||
|
|
||||||
def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_id, inference_provider_type):
|
def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_id, inference_provider_type):
|
||||||
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
skip_if_provider_doesnt_support_rerank(inference_provider_type)
|
||||||
pytest.xfail(f"{inference_provider_type} doesn't support rerank yet")
|
|
||||||
|
|
||||||
items = [DUMMY_STRING, DUMMY_STRING2]
|
items = [DUMMY_STRING, DUMMY_STRING2]
|
||||||
response = client_with_models.inference.rerank(
|
response = client_with_models.inference.rerank(
|
||||||
|
@ -205,8 +206,7 @@ def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_i
|
||||||
def test_rerank_semantic_correctness(
|
def test_rerank_semantic_correctness(
|
||||||
client_with_models, rerank_model_id, query, items, expected_first_item, inference_provider_type
|
client_with_models, rerank_model_id, query, items, expected_first_item, inference_provider_type
|
||||||
):
|
):
|
||||||
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
skip_if_provider_doesnt_support_rerank(inference_provider_type)
|
||||||
pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet.")
|
|
||||||
|
|
||||||
response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
|
response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue