diff --git a/tests/integration/inference/test_rerank.py b/tests/integration/inference/test_rerank.py
index 27f3074ad..f1b9311a4 100644
--- a/tests/integration/inference/test_rerank.py
+++ b/tests/integration/inference/test_rerank.py
@@ -52,6 +52,35 @@ def _validate_rerank_response(response: RerankResponse, items: list) -> None:
         last_score = d.relevance_score
 
 
+def _validate_semantic_ranking(response: RerankResponse, items: list, expected_first_item: str) -> None:
+    """
+    Validate that the expected most relevant item ranks first.
+
+    Args:
+        response: The RerankResponse to validate
+        items: The original items list that was ranked
+        expected_first_item: The expected first item in the ranking
+
+    Raises:
+        AssertionError: If the response has no ranking data, if the top result's
+            index does not refer to an entry of ``items``, or if the top-ranked
+            item is not ``expected_first_item``
+    """
+    if not response.data:
+        raise AssertionError("No ranking data returned in response")
+
+    actual_first_index = response.data[0].index
+    # Reject out-of-range indices explicitly: a negative index would otherwise
+    # wrap around silently and compare against the wrong item.
+    assert 0 <= actual_first_index < len(items), (
+        f"Top-ranked index {actual_first_index} is out of range for {len(items)} items."
+    )
+    actual_first_item = items[actual_first_index]
+    assert actual_first_item == expected_first_item, (
+        f"Expected '{expected_first_item}' to rank first, but '{actual_first_item}' ranked first instead."
+    )
+
+
 @pytest.mark.parametrize(
     "query,items",
     [
@@ -145,3 +174,48 @@ def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_i
 
     assert isinstance(response, RerankResponse)
     assert len(response.data) <= len(items)  # Should return at most len(items)
+
+
+@pytest.mark.parametrize(
+    "query,items,expected_first_item",
+    [
+        (
+            "What is a reranking model? ",
+            [
+                "A reranking model reranks a list of items based on the query. ",
+                "Machine learning algorithms learn patterns from data. ",
+                "Python is a programming language. ",
+            ],
+            "A reranking model reranks a list of items based on the query. ",
+        ),
+        (
+            "What is C++?",
+            [
+                "Learning new things is interesting. ",
+                "C++ is a programming language. ",
+                "Books provide knowledge and entertainment. ",
+            ],
+            "C++ is a programming language. ",
+        ),
+        (
+            "What are good learning habits? ",
+            [
+                "Cooking pasta is a fun activity. ",
+                "Plants need water and sunlight. ",
+                "Good learning habits include reading daily and taking notes. ",
+            ],
+            "Good learning habits include reading daily and taking notes. ",
+        ),
+    ],
+)
+def test_rerank_semantic_correctness(
+    client_with_models, rerank_model_id, query, items, expected_first_item, inference_provider_type
+):
+    """Check that the item which directly answers the query is ranked first."""
+    if inference_provider_type not in SUPPORTED_PROVIDERS:
+        pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet.")
+
+    response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
+
+    _validate_rerank_response(response, items)
+    _validate_semantic_ranking(response, items, expected_first_item)