mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
Add rerank models to the dynamic model list; Fix integration tests
This commit is contained in:
parent
3538477070
commit
816b68fdc7
8 changed files with 247 additions and 25 deletions
|
@ -6,7 +6,7 @@
|
|||
|
||||
import pytest
|
||||
from llama_stack_client import BadRequestError as LlamaStackBadRequestError
|
||||
from llama_stack_client.types import RerankResponse
|
||||
from llama_stack_client.types import InferenceRerankResponse
|
||||
from llama_stack_client.types.shared.interleaved_content import (
|
||||
ImageContentItem,
|
||||
ImageContentItemImage,
|
||||
|
@ -30,12 +30,12 @@ SUPPORTED_PROVIDERS = {"remote::nvidia"}
|
|||
# Providers that support media input for rerank models.
# NOTE: an empty set must be spelled set(); a bare {} literal creates a dict,
# which would be inconsistent with the SUPPORTED_PROVIDERS set.
PROVIDERS_SUPPORTING_MEDIA = set()
|
||||
|
||||
|
||||
def _validate_rerank_response(response: RerankResponse, items: list) -> None:
|
||||
def _validate_rerank_response(response: InferenceRerankResponse, items: list) -> None:
|
||||
"""
|
||||
Validate that a rerank response has the correct structure and ordering.
|
||||
|
||||
Args:
|
||||
response: The RerankResponse to validate
|
||||
response: The InferenceRerankResponse to validate
|
||||
items: The original items list that was ranked
|
||||
|
||||
Raises:
|
||||
|
@ -43,7 +43,7 @@ def _validate_rerank_response(response: RerankResponse, items: list) -> None:
|
|||
"""
|
||||
seen = set()
|
||||
last_score = float("inf")
|
||||
for d in response.data:
|
||||
for d in response:
|
||||
assert 0 <= d.index < len(items), f"Index {d.index} out of bounds for {len(items)} items"
|
||||
assert d.index not in seen, f"Duplicate index {d.index} found"
|
||||
seen.add(d.index)
|
||||
|
@ -52,22 +52,22 @@ def _validate_rerank_response(response: RerankResponse, items: list) -> None:
|
|||
last_score = d.relevance_score
|
||||
|
||||
|
||||
def _validate_semantic_ranking(response: RerankResponse, items: list, expected_first_item: str) -> None:
|
||||
def _validate_semantic_ranking(response: InferenceRerankResponse, items: list, expected_first_item: str) -> None:
|
||||
"""
|
||||
Validate that the expected most relevant item ranks first.
|
||||
|
||||
Args:
|
||||
response: The RerankResponse to validate
|
||||
response: The InferenceRerankResponse to validate
|
||||
items: The original items list that was ranked
|
||||
expected_first_item: The expected first item in the ranking
|
||||
|
||||
Raises:
|
||||
AssertionError: If any validation fails
|
||||
"""
|
||||
if not response.data:
|
||||
if not response:
|
||||
raise AssertionError("No ranking data returned in response")
|
||||
|
||||
actual_first_index = response.data[0].index
|
||||
actual_first_index = response[0].index
|
||||
actual_first_item = items[actual_first_index]
|
||||
assert actual_first_item == expected_first_item, (
|
||||
f"Expected '{expected_first_item}' to rank first, but '{actual_first_item}' ranked first instead."
|
||||
|
@ -94,8 +94,9 @@ def test_rerank_text(client_with_models, rerank_model_id, query, items, inferenc
|
|||
pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")
|
||||
|
||||
response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
|
||||
assert isinstance(response, RerankResponse)
|
||||
assert len(response.data) <= len(items)
|
||||
assert isinstance(response, list)
|
||||
# TODO: Add type validation for response items once InferenceRerankResponseItem is exported from llama stack client.
|
||||
assert len(response) <= len(items)
|
||||
_validate_rerank_response(response, items)
|
||||
|
||||
|
||||
|
@ -129,8 +130,8 @@ def test_rerank_image(client_with_models, rerank_model_id, query, items, inferen
|
|||
else:
|
||||
response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
|
||||
|
||||
assert isinstance(response, RerankResponse)
|
||||
assert len(response.data) <= len(items)
|
||||
assert isinstance(response, list)
|
||||
assert len(response) <= len(items)
|
||||
_validate_rerank_response(response, items)
|
||||
|
||||
|
||||
|
@ -148,8 +149,8 @@ def test_rerank_max_results(client_with_models, rerank_model_id, inference_provi
|
|||
max_num_results=max_num_results,
|
||||
)
|
||||
|
||||
assert isinstance(response, RerankResponse)
|
||||
assert len(response.data) == max_num_results
|
||||
assert isinstance(response, list)
|
||||
assert len(response) == max_num_results
|
||||
_validate_rerank_response(response, items)
|
||||
|
||||
|
||||
|
@ -165,8 +166,8 @@ def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_i
|
|||
max_num_results=10, # Larger than items length
|
||||
)
|
||||
|
||||
assert isinstance(response, RerankResponse)
|
||||
assert len(response.data) <= len(items) # Should return at most len(items)
|
||||
assert isinstance(response, list)
|
||||
assert len(response) <= len(items) # Should return at most len(items)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue