Add rerank models to the dynamic model list; Fix integration tests

Jiayi 2025-09-28 14:45:16 -07:00
parent 3538477070
commit 816b68fdc7
8 changed files with 247 additions and 25 deletions

@@ -6,7 +6,7 @@
 import pytest
 from llama_stack_client import BadRequestError as LlamaStackBadRequestError
-from llama_stack_client.types import RerankResponse
+from llama_stack_client.types import InferenceRerankResponse
 from llama_stack_client.types.shared.interleaved_content import (
     ImageContentItem,
     ImageContentItemImage,
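
For readers following the API change: the old RerankResponse wrapped its results in a .data attribute, while the new InferenceRerankResponse behaves as a plain list of scored items. A minimal sketch of the difference, using hypothetical stand-in objects rather than real client types:

    # Sketch of the shape change (SimpleNamespace stands in for the real
    # client item type, which this diff only names in a TODO below).
    from types import SimpleNamespace

    item = SimpleNamespace(index=0, relevance_score=0.97)

    old_style = SimpleNamespace(data=[item])  # RerankResponse: results under .data
    new_style = [item]                        # InferenceRerankResponse: the list itself

    for d in new_style:                       # iterate the response directly
        print(d.index, d.relevance_score)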
@@ -30,12 +30,12 @@ SUPPORTED_PROVIDERS = {"remote::nvidia"}
 PROVIDERS_SUPPORTING_MEDIA = {} # Providers that support media input for rerank models
 
 
-def _validate_rerank_response(response: RerankResponse, items: list) -> None:
+def _validate_rerank_response(response: InferenceRerankResponse, items: list) -> None:
     """
     Validate that a rerank response has the correct structure and ordering.
 
     Args:
-        response: The RerankResponse to validate
+        response: The InferenceRerankResponse to validate
         items: The original items list that was ranked
 
     Raises:
@@ -43,7 +43,7 @@ def _validate_rerank_response(response: RerankResponse, items: list) -> None:
     """
     seen = set()
     last_score = float("inf")
-    for d in response.data:
+    for d in response:
         assert 0 <= d.index < len(items), f"Index {d.index} out of bounds for {len(items)} items"
         assert d.index not in seen, f"Duplicate index {d.index} found"
         seen.add(d.index)
@@ -52,22 +52,22 @@ def _validate_rerank_response(response: RerankResponse, items: list) -> None:
         last_score = d.relevance_score
 
 
-def _validate_semantic_ranking(response: RerankResponse, items: list, expected_first_item: str) -> None:
+def _validate_semantic_ranking(response: InferenceRerankResponse, items: list, expected_first_item: str) -> None:
     """
     Validate that the expected most relevant item ranks first.
 
     Args:
-        response: The RerankResponse to validate
+        response: The InferenceRerankResponse to validate
         items: The original items list that was ranked
         expected_first_item: The expected first item in the ranking
 
     Raises:
         AssertionError: If any validation fails
     """
-    if not response.data:
+    if not response:
         raise AssertionError("No ranking data returned in response")
 
-    actual_first_index = response.data[0].index
+    actual_first_index = response[0].index
     actual_first_item = items[actual_first_index]
 
     assert actual_first_item == expected_first_item, (
         f"Expected '{expected_first_item}' to rank first, but '{actual_first_item}' ranked first instead."
@@ -94,8 +94,9 @@ def test_rerank_text(client_with_models, rerank_model_id, query, items, inference_provider_type):
         pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
     response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
-    assert isinstance(response, RerankResponse)
-    assert len(response.data) <= len(items)
+    assert isinstance(response, list)
+    # TODO: Add type validation for response items once InferenceRerankResponseItem is exported from llama stack client.
+    assert len(response) <= len(items)
     _validate_rerank_response(response, items)
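
The TODO above points at a per-item type check that becomes possible once the item type ships in the client package. A sketch of what that follow-up might look like; the name InferenceRerankResponseItem is taken from the TODO text and is not importable today, so the lines stay commented out:

    # Possible follow-up once the client exports the item type (not yet real):
    # from llama_stack_client.types import InferenceRerankResponseItem
    # assert all(isinstance(d, InferenceRerankResponseItem) for d in response)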
@@ -129,8 +130,8 @@ def test_rerank_image(client_with_models, rerank_model_id, query, items, inference_provider_type):
     else:
         response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
-    assert isinstance(response, RerankResponse)
-    assert len(response.data) <= len(items)
+    assert isinstance(response, list)
+    assert len(response) <= len(items)
     _validate_rerank_response(response, items)
@@ -148,8 +149,8 @@ def test_rerank_max_results(client_with_models, rerank_model_id, inference_provider_type):
         max_num_results=max_num_results,
     )
-    assert isinstance(response, RerankResponse)
-    assert len(response.data) == max_num_results
+    assert isinstance(response, list)
+    assert len(response) == max_num_results
     _validate_rerank_response(response, items)
@@ -165,8 +166,8 @@ def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_id, inference_provider_type):
         max_num_results=10, # Larger than items length
     )
-    assert isinstance(response, RerankResponse)
-    assert len(response.data) <= len(items)  # Should return at most len(items)
+    assert isinstance(response, list)
+    assert len(response) <= len(items)  # Should return at most len(items)
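
Taken together, the last two hunks pin down a length contract for max_num_results; restated compactly (assuming providers truncate, rather than raise, when the cap exceeds the item count):

    # Length contract exercised by the two tests above (assumption: providers
    # truncate, rather than raise, when max_num_results exceeds len(items)):
    def expected_result_count(max_num_results: int, num_items: int) -> int:
        return min(max_num_results, num_items)

    assert expected_result_count(3, 5) == 3    # cap below item count: exact
    assert expected_result_count(10, 4) == 4   # cap above item count: clamped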
@pytest.mark.parametrize(