From d7cbeb4b8c5942cfda4f096dd1fd45eeb35d1349 Mon Sep 17 00:00:00 2001 From: Jiayi Date: Thu, 4 Sep 2025 18:08:35 -0700 Subject: [PATCH] Add tests --- docs/docs/providers/inference/index.mdx | 6 +- .../remote/inference/nvidia/nvidia.py | 2 +- tests/integration/conftest.py | 5 + tests/integration/fixtures/common.py | 10 +- tests/integration/inference/test_rerank.py | 147 ++++++++++++++ .../providers/nvidia/test_rerank_inference.py | 180 ++++++++++++++++++ 6 files changed, 345 insertions(+), 5 deletions(-) create mode 100644 tests/integration/inference/test_rerank.py create mode 100644 tests/unit/providers/nvidia/test_rerank_inference.py diff --git a/docs/docs/providers/inference/index.mdx b/docs/docs/providers/inference/index.mdx index e96169cad..d9d30ab78 100644 --- a/docs/docs/providers/inference/index.mdx +++ b/docs/docs/providers/inference/index.mdx @@ -1,9 +1,9 @@ --- description: "Llama Stack Inference API for generating completions, chat completions, and embeddings. - This API provides the raw interface to the underlying models. Two kinds of models are supported: + This API provides the raw interface to the underlying models. Three kinds of models are supported: - LLM models: these models generate \"raw\" and \"chat\" (conversational) completions. - - Embedding models: these models generate embeddings to be used for semantic search." + - Embedding models: these models generate embeddings to be used for semantic search. - Rerank models: these models rerank the documents by relevance." sidebar_label: Inference title: Inference @@ -15,7 +15,7 @@ title: Inference Llama Stack Inference API for generating completions, chat completions, and embeddings. - This API provides the raw interface to the underlying models. Two kinds of models are supported: + This API provides the raw interface to the underlying models. Three kinds of models are supported: - LLM models: these models generate "raw" and "chat" (conversational) completions. - Embedding models: these models generate embeddings to be used for semantic search. - Rerank models: these models rerank the documents by relevance. diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index b2fdec61f..8dc5e0a11 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -151,7 +151,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference): for ranking in rankings: rerank_data.append(RerankData(index=ranking["index"], relevance_score=ranking["logit"])) - # Apply max_num_results limit if specified + # Apply max_num_results limit if max_num_results is not None: rerank_data = rerank_data[:max_num_results] diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 4735264c3..2ad4f7e4c 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -120,6 +120,10 @@ def pytest_addoption(parser): "--embedding-model", help="comma-separated list of embedding models. Fixture name: embedding_model_id", ) + parser.addoption( + "--rerank-model", + help="comma-separated list of rerank models. Fixture name: rerank_model_id", + ) parser.addoption( "--safety-shield", help="comma-separated list of safety shields. Fixture name: shield_id", @@ -198,6 +202,7 @@ def pytest_generate_tests(metafunc): "shield_id": ("--safety-shield", "shield"), "judge_model_id": ("--judge-model", "judge"), "embedding_dimension": ("--embedding-dimension", "dim"), + "rerank_model_id": ("--rerank-model", "rerank"), } # Collect all parameters and their values diff --git a/tests/integration/fixtures/common.py b/tests/integration/fixtures/common.py index 68aa2b60b..27283afe7 100644 --- a/tests/integration/fixtures/common.py +++ b/tests/integration/fixtures/common.py @@ -119,6 +119,7 @@ def client_with_models( embedding_model_id, embedding_dimension, judge_model_id, + rerank_model_id, ): client = llama_stack_client @@ -151,6 +152,13 @@ def client_with_models( model_type="embedding", metadata={"embedding_dimension": embedding_dimension or 384}, ) + if rerank_model_id and rerank_model_id not in model_ids: + rerank_provider = providers[0] + client.models.register( + model_id=rerank_model_id, + provider_id=rerank_provider.provider_id, + model_type="rerank", + ) return client @@ -166,7 +174,7 @@ def model_providers(llama_stack_client): @pytest.fixture(autouse=True) def skip_if_no_model(request): - model_fixtures = ["text_model_id", "vision_model_id", "embedding_model_id", "judge_model_id", "shield_id"] + model_fixtures = ["text_model_id", "vision_model_id", "embedding_model_id", "judge_model_id", "shield_id", "rerank_model_id"] test_func = request.node.function actual_params = inspect.signature(test_func).parameters.keys() diff --git a/tests/integration/inference/test_rerank.py b/tests/integration/inference/test_rerank.py new file mode 100644 index 000000000..0c536b539 --- /dev/null +++ b/tests/integration/inference/test_rerank.py @@ -0,0 +1,147 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest +from llama_stack_client import BadRequestError as LlamaStackBadRequestError +from llama_stack_client.types import RerankResponse +from llama_stack_client.types.shared.interleaved_content import ( + ImageContentItem, + ImageContentItemImage, + ImageContentItemImageURL, + TextContentItem, +) + +from llama_stack.core.library_client import LlamaStackAsLibraryClient + +# Test data +DUMMY_STRING = "string_1" +DUMMY_STRING2 = "string_2" +DUMMY_TEXT = TextContentItem(text=DUMMY_STRING, type="text") +DUMMY_TEXT2 = TextContentItem(text=DUMMY_STRING2, type="text") +DUMMY_IMAGE_URL = ImageContentItem( + image=ImageContentItemImage(url=ImageContentItemImageURL(uri="https://example.com/image.jpg")), type="image" +) +DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image") + +SUPPORTED_PROVIDERS = {"remote::nvidia"} +PROVIDERS_SUPPORTING_MEDIA = {} # Providers that support media input for rerank models + + +def _validate_rerank_response(response: RerankResponse, items: list) -> None: + """ + Validate that a rerank response has the correct structure and ordering. + + Args: + response: The RerankResponse to validate + items: The original items list that was ranked + + Raises: + AssertionError: If any validation fails + """ + seen = set() + last_score = float("inf") + for d in response.data: + assert 0 <= d.index < len(items), f"Index {d.index} out of bounds for {len(items)} items" + assert d.index not in seen, f"Duplicate index {d.index} found" + seen.add(d.index) + assert isinstance(d.relevance_score, float), f"Score must be float, got {type(d.relevance_score)}" + assert d.relevance_score <= last_score, f"Scores not in descending order: {d.relevance_score} > {last_score}" + last_score = d.relevance_score + + +@pytest.mark.parametrize( + "query,items", + [ + (DUMMY_STRING, [DUMMY_STRING, DUMMY_STRING2]), + (DUMMY_TEXT, [DUMMY_TEXT, DUMMY_TEXT2]), + (DUMMY_STRING, [DUMMY_STRING2, DUMMY_TEXT]), + (DUMMY_TEXT, [DUMMY_STRING, DUMMY_TEXT2]), + ], + ids=[ + "string-query-string-items", + "text-query-text-items", + "mixed-content-1", + "mixed-content-2", + ], +) +def test_rerank_text(llama_stack_client, rerank_model_id, query, items, inference_provider_type): + if inference_provider_type not in SUPPORTED_PROVIDERS: + pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ") + + response = llama_stack_client.inference.rerank(model=rerank_model_id, query=query, items=items) + assert isinstance(response, RerankResponse) + assert len(response.data) <= len(items) + _validate_rerank_response(response, items) + + +@pytest.mark.parametrize( + "query,items", + [ + (DUMMY_IMAGE_URL, [DUMMY_STRING]), + (DUMMY_IMAGE_BASE64, [DUMMY_TEXT]), + (DUMMY_TEXT, [DUMMY_IMAGE_URL]), + (DUMMY_IMAGE_BASE64, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]), + (DUMMY_TEXT, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]), + ], + ids=[ + "image-query-url", + "image-query-base64", + "text-query-image-item", + "mixed-content-1", + "mixed-content-2", + ], +) +def test_rerank_image(llama_stack_client, rerank_model_id, query, items, inference_provider_type): + if inference_provider_type not in SUPPORTED_PROVIDERS: + pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ") + + if rerank_model_id not in PROVIDERS_SUPPORTING_MEDIA: + error_type = ( + ValueError if isinstance(llama_stack_client, LlamaStackAsLibraryClient) else LlamaStackBadRequestError + ) + with pytest.raises(error_type): + llama_stack_client.inference.rerank(model=rerank_model_id, query=query, items=items) + else: + response = llama_stack_client.inference.rerank(model=rerank_model_id, query=query, items=items) + + assert isinstance(response, RerankResponse) + assert len(response.data) <= len(items) + _validate_rerank_response(response, items) + + +def test_rerank_max_results(llama_stack_client, rerank_model_id, inference_provider_type): + if inference_provider_type not in SUPPORTED_PROVIDERS: + pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ") + + items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2] + max_num_results = 2 + + response = llama_stack_client.inference.rerank( + model=rerank_model_id, + query=DUMMY_STRING, + items=items, + max_num_results=max_num_results, + ) + + assert isinstance(response, RerankResponse) + assert len(response.data) == max_num_results + _validate_rerank_response(response, items) + + +def test_rerank_max_results_larger_than_items(llama_stack_client, rerank_model_id, inference_provider_type): + if inference_provider_type not in SUPPORTED_PROVIDERS: + pytest.xfail(f"{inference_provider_type} doesn't support rerank yet") + + items = [DUMMY_STRING, DUMMY_STRING2] + response = llama_stack_client.inference.rerank( + model=rerank_model_id, + query=DUMMY_STRING, + items=items, + max_num_results=10, # Larger than items length + ) + + assert isinstance(response, RerankResponse) + assert len(response.data) <= len(items) # Should return at most len(items) diff --git a/tests/unit/providers/nvidia/test_rerank_inference.py b/tests/unit/providers/nvidia/test_rerank_inference.py new file mode 100644 index 000000000..03c54a732 --- /dev/null +++ b/tests/unit/providers/nvidia/test_rerank_inference.py @@ -0,0 +1,180 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from unittest.mock import AsyncMock, patch + +import aiohttp +import pytest + +from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig +from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter + + +class MockResponse: + def __init__(self, status=200, json_data=None, text_data="OK"): + self.status = status + self._json_data = json_data or {"rankings": []} + self._text_data = text_data + + async def json(self): + return self._json_data + + async def text(self): + return self._text_data + + +class MockSession: + def __init__(self, response): + self.response = response + self.post_calls = [] + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return False + + def post(self, url, **kwargs): + self.post_calls.append((url, kwargs)) + + class PostContext: + def __init__(self, response): + self.response = response + + async def __aenter__(self): + return self.response + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return False + + return PostContext(self.response) + + +def create_adapter(config=None, model_metadata=None): + if config is None: + config = NVIDIAConfig(api_key="test-key") + + adapter = NVIDIAInferenceAdapter(config) + + class MockModel: + provider_resource_id = "test-model" + metadata = model_metadata or {} + + adapter.model_store = AsyncMock() + adapter.model_store.get_model = AsyncMock(return_value=MockModel()) + + return adapter + + +@pytest.mark.asyncio +async def test_rerank_basic_functionality(): + adapter = create_adapter() + mock_response = MockResponse(json_data={"rankings": [{"index": 0, "logit": 0.5}]}) + mock_session = MockSession(mock_response) + + with patch("aiohttp.ClientSession", return_value=mock_session): + result = await adapter.rerank(model="test-model", query="test query", items=["item1", "item2"]) + + assert len(result.data) == 1 + assert result.data[0].index == 0 + assert result.data[0].relevance_score == 0.5 + + url, kwargs = mock_session.post_calls[0] + payload = kwargs["json"] + assert payload["model"] == "test-model" + assert payload["query"] == {"text": "test query"} + assert payload["passages"] == [{"text": "item1"}, {"text": "item2"}] + + +@pytest.mark.asyncio +async def test_missing_rankings_key(): + adapter = create_adapter() + mock_session = MockSession(MockResponse(json_data={})) + + with patch("aiohttp.ClientSession", return_value=mock_session): + result = await adapter.rerank(model="test-model", query="q", items=["a"]) + + assert len(result.data) == 0 + + +@pytest.mark.asyncio +async def test_hosted_with_endpoint(): + adapter = create_adapter( + config=NVIDIAConfig(api_key="key"), model_metadata={"endpoint": "https://model.endpoint/rerank"} + ) + mock_session = MockSession(MockResponse()) + + with patch("aiohttp.ClientSession", return_value=mock_session): + await adapter.rerank(model="test-model", query="q", items=["a"]) + + url, _ = mock_session.post_calls[0] + assert url == "https://model.endpoint/rerank" + + +@pytest.mark.asyncio +async def test_hosted_without_endpoint(): + adapter = create_adapter( + config=NVIDIAConfig(api_key="key"), # This creates hosted config (integrate.api.nvidia.com). + model_metadata={}, # No "endpoint" key + ) + mock_session = MockSession(MockResponse()) + + with patch("aiohttp.ClientSession", return_value=mock_session): + await adapter.rerank(model="test-model", query="q", items=["a"]) + + url, _ = mock_session.post_calls[0] + assert "https://integrate.api.nvidia.com" in url + + +@pytest.mark.asyncio +async def test_self_hosted_ignores_endpoint(): + adapter = create_adapter( + config=NVIDIAConfig(url="http://localhost:8000", api_key=None), + model_metadata={"endpoint": "https://model.endpoint/rerank"}, # This should be ignored. + ) + mock_session = MockSession(MockResponse()) + + with patch("aiohttp.ClientSession", return_value=mock_session): + await adapter.rerank(model="test-model", query="q", items=["a"]) + + url, _ = mock_session.post_calls[0] + assert "http://localhost:8000" in url + assert "model.endpoint/rerank" not in url + + +@pytest.mark.asyncio +async def test_max_num_results(): + adapter = create_adapter() + rankings = [{"index": 0, "logit": 0.8}, {"index": 1, "logit": 0.6}] + mock_session = MockSession(MockResponse(json_data={"rankings": rankings})) + + with patch("aiohttp.ClientSession", return_value=mock_session): + result = await adapter.rerank(model="test-model", query="q", items=["a", "b"], max_num_results=1) + + assert len(result.data) == 1 + assert result.data[0].index == 0 + assert result.data[0].relevance_score == 0.8 + + +@pytest.mark.asyncio +async def test_http_error(): + adapter = create_adapter() + mock_session = MockSession(MockResponse(status=500, text_data="Server Error")) + + with patch("aiohttp.ClientSession", return_value=mock_session): + with pytest.raises(ConnectionError, match="status 500.*Server Error"): + await adapter.rerank(model="test-model", query="q", items=["a"]) + + +@pytest.mark.asyncio +async def test_client_error(): + adapter = create_adapter() + mock_session = AsyncMock() + mock_session.__aenter__.side_effect = aiohttp.ClientError("Network error") + + with patch("aiohttp.ClientSession", return_value=mock_session): + with pytest.raises(ConnectionError, match="Failed to connect.*Network error"): + await adapter.rerank(model="test-model", query="q", items=["a"])