Add rerank semantic validation tests

This commit is contained in:
Jiayi 2025-09-10 17:18:59 -07:00
parent d1b4e090ef
commit a0e6e82c1e

View file

@ -52,6 +52,28 @@ def _validate_rerank_response(response: RerankResponse, items: list) -> None:
last_score = d.relevance_score
def _validate_semantic_ranking(response: RerankResponse, items: list, expected_first_item: str) -> None:
"""
Validate that the expected most relevant item ranks first.
Args:
response: The RerankResponse to validate
items: The original items list that was ranked
expected_first_item: The expected first item in the ranking
Raises:
AssertionError: If any validation fails
"""
if not response.data:
raise AssertionError("No ranking data returned in response")
actual_first_index = response.data[0].index
actual_first_item = items[actual_first_index]
assert actual_first_item == expected_first_item, (
f"Expected '{expected_first_item}' to rank first, but '{actual_first_item}' ranked first instead."
)
@pytest.mark.parametrize(
"query,items",
[
@ -145,3 +167,47 @@ def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_i
assert isinstance(response, RerankResponse)
assert len(response.data) <= len(items) # Should return at most len(items)
@pytest.mark.parametrize(
"query,items,expected_first_item",
[
(
"What is a reranking model? ",
[
"A reranking model reranks a list of items based on the query. ",
"Machine learning algorithms learn patterns from data. ",
"Python is a programming language. ",
],
"A reranking model reranks a list of items based on the query. ",
),
(
"What is C++?",
[
"Learning new things is interesting. ",
"C++ is a programming language. ",
"Books provide knowledge and entertainment. ",
],
"C++ is a programming language. ",
),
(
"What are good learning habits? ",
[
"Cooking pasta is a fun activity. ",
"Plants need water and sunlight. ",
"Good learning habits include reading daily and taking notes. ",
],
"Good learning habits include reading daily and taking notes. ",
),
],
)
def test_rerank_semantic_correctness(
client_with_models, rerank_model_id, query, items, expected_first_item, inference_provider_type
):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet.")
response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
_validate_rerank_response(response, items)
_validate_semantic_ranking(response, items, expected_first_item)