Fix rerank integration test based on client side changes

Jiayi 2025-10-01 10:37:58 -07:00
parent bb2eb33fc3
commit 6b4940806f
8 changed files with 27 additions and 276 deletions
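The client-side change this commit adapts to: the rerank call and its response type now live under an alpha namespace (`client.alpha.inference.rerank`, `llama_stack_client.types.alpha`). A minimal sketch of the updated call pattern, assuming a running Llama Stack server; the base URL, model id, query, and items below are placeholders, not values from this diff:

from llama_stack_client import LlamaStackClient
from llama_stack_client.types.alpha import InferenceRerankResponse  # import moved to types.alpha

client = LlamaStackClient(base_url="http://localhost:8321")  # placeholder server URL

# rerank() is now reached through the alpha namespace on the client
response: InferenceRerankResponse = client.alpha.inference.rerank(
    model="your-rerank-model-id",  # placeholder rerank model id
    query="What is the capital of France?",
    items=["Paris is the capital of France.", "Berlin is the capital of Germany."],
)
print(len(response))  # one scored entry per ranked item, at most len(items)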


@@ -6,7 +6,7 @@
 import pytest
 from llama_stack_client import BadRequestError as LlamaStackBadRequestError
-from llama_stack_client.types import InferenceRerankResponse
+from llama_stack_client.types.alpha import InferenceRerankResponse
 from llama_stack_client.types.shared.interleaved_content import (
     ImageContentItem,
     ImageContentItemImage,
@@ -97,7 +97,7 @@ def _validate_semantic_ranking(response: InferenceRerankResponse, items: list, e
 def test_rerank_text(client_with_models, rerank_model_id, query, items, inference_provider_type):
     skip_if_provider_doesnt_support_rerank(inference_provider_type)
-    response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
+    response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
     assert isinstance(response, list)
     # TODO: Add type validation for response items once InferenceRerankResponseItem is exported from llama stack client.
     assert len(response) <= len(items)
@@ -129,9 +129,9 @@ def test_rerank_image(client_with_models, rerank_model_id, query, items, inferen
             ValueError if isinstance(client_with_models, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
         )
         with pytest.raises(error_type):
-            client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
+            client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
     else:
-        response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
+        response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
         assert isinstance(response, list)
         assert len(response) <= len(items)
@@ -144,7 +144,7 @@ def test_rerank_max_results(client_with_models, rerank_model_id, inference_provi
     items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2]
     max_num_results = 2
-    response = client_with_models.inference.rerank(
+    response = client_with_models.alpha.inference.rerank(
         model=rerank_model_id,
         query=DUMMY_STRING,
         items=items,
@@ -160,7 +160,7 @@ def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_i
     skip_if_provider_doesnt_support_rerank(inference_provider_type)
     items = [DUMMY_STRING, DUMMY_STRING2]
-    response = client_with_models.inference.rerank(
+    response = client_with_models.alpha.inference.rerank(
         model=rerank_model_id,
         query=DUMMY_STRING,
         items=items,
@@ -208,7 +208,7 @@ def test_rerank_semantic_correctness(
 ):
     skip_if_provider_doesnt_support_rerank(inference_provider_type)
-    response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
+    response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
     _validate_rerank_response(response, items)
     _validate_semantic_ranking(response, items, expected_first_item)