From fa7699d2c3db55f214a794be8139789174e09cb0 Mon Sep 17 00:00:00 2001
From: Jiayi Ni
Date: Thu, 30 Oct 2025 21:42:09 -0700
Subject: [PATCH] feat: Add rerank API for NVIDIA Inference Provider (#3329)

# What does this PR do?
Add rerank API for NVIDIA Inference Provider.

Closes #3278

## Test Plan
Unit test:
```
pytest tests/unit/providers/nvidia/test_rerank_inference.py
```
Integration test:
```
pytest -s -v tests/integration/inference/test_rerank.py --stack-config="inference=nvidia" --rerank-model=nvidia/nvidia/nv-rerankqa-mistral-4b-v3 --env NVIDIA_API_KEY="" --env NVIDIA_BASE_URL="https://integrate.api.nvidia.com"
```
---
 .../providers/inference/remote_nvidia.mdx     |   1 +
 .../remote/inference/nvidia/NVIDIA.md         |  19 ++
 .../remote/inference/nvidia/config.py         |   9 +
 .../remote/inference/nvidia/nvidia.py         | 111 ++++++++
 tests/integration/conftest.py                 |   5 +
 tests/integration/fixtures/common.py          |  13 +-
 tests/integration/inference/test_rerank.py    | 214 +++++++++++++++
 .../providers/nvidia/test_rerank_inference.py | 251 ++++++++++++++++++
 8 files changed, 622 insertions(+), 1 deletion(-)
 create mode 100644 tests/integration/inference/test_rerank.py
 create mode 100644 tests/unit/providers/nvidia/test_rerank_inference.py

diff --git a/docs/docs/providers/inference/remote_nvidia.mdx b/docs/docs/providers/inference/remote_nvidia.mdx
index b4e04176c..57c64ab46 100644
--- a/docs/docs/providers/inference/remote_nvidia.mdx
+++ b/docs/docs/providers/inference/remote_nvidia.mdx
@@ -20,6 +20,7 @@ NVIDIA inference provider for accessing NVIDIA NIM models and AI services.
 | `url` | `` | No | https://integrate.api.nvidia.com | A base url for accessing the NVIDIA NIM |
 | `timeout` | `` | No | 60 | Timeout for the HTTP requests |
 | `append_api_version` | `` | No | True | When set to false, the API version will not be appended to the base_url. By default, it is true. |
+| `rerank_model_to_url` | `dict[str, str]` | No | `{'nv-rerank-qa-mistral-4b:1': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking', 'nvidia/nv-rerankqa-mistral-4b-v3': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking', 'nvidia/llama-3.2-nv-rerankqa-1b-v2': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking'}` | Mapping of rerank model identifiers to their API endpoints. |
 
 ## Sample Configuration
 
diff --git a/src/llama_stack/providers/remote/inference/nvidia/NVIDIA.md b/src/llama_stack/providers/remote/inference/nvidia/NVIDIA.md
index f1a828413..97fa95a1f 100644
--- a/src/llama_stack/providers/remote/inference/nvidia/NVIDIA.md
+++ b/src/llama_stack/providers/remote/inference/nvidia/NVIDIA.md
@@ -181,3 +181,22 @@ vlm_response = client.chat.completions.create(
 
 print(f"VLM Response: {vlm_response.choices[0].message.content}")
 ```
+
+### Rerank Example
+
+The following example shows how to rerank documents using an NVIDIA NIM.
+
+```python
+rerank_response = client.alpha.inference.rerank(
+    model="nvidia/nvidia/llama-3.2-nv-rerankqa-1b-v2",
+    query="query",
+    items=[
+        "item_1",
+        "item_2",
+        "item_3",
+    ],
+)
+
+for i, result in enumerate(rerank_response):
+    print(f"{i+1}. [Index: {result.index}, " f"Score: {(result.relevance_score):.3f}]")
+```
\ No newline at end of file
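The call above returns every item, scored and sorted. Where only the top results are needed, the same API accepts `max_num_results`; a minimal sketch, reusing the `client` setup and model id from the example above:

```python
# Sketch: keep only the two highest-scoring items.
# Assumes the same `client` and model id as the NVIDIA.md example above.
top_two = client.alpha.inference.rerank(
    model="nvidia/nvidia/llama-3.2-nv-rerankqa-1b-v2",
    query="query",
    items=["item_1", "item_2", "item_3"],
    max_num_results=2,  # the ranked list is truncated after scoring
)
```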
diff --git a/src/llama_stack/providers/remote/inference/nvidia/config.py b/src/llama_stack/providers/remote/inference/nvidia/config.py
index 3545d2b11..618bbe078 100644
--- a/src/llama_stack/providers/remote/inference/nvidia/config.py
+++ b/src/llama_stack/providers/remote/inference/nvidia/config.py
@@ -28,6 +28,7 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
     Attributes:
         url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000
         api_key (str): The access key for the hosted NIM endpoints
+        rerank_model_to_url (dict[str, str]): Mapping of rerank model identifiers to their API endpoints
 
     There are two ways to access NVIDIA NIMs -
      0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com
@@ -55,6 +56,14 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
         default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
         description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
     )
+    rerank_model_to_url: dict[str, str] = Field(
+        default_factory=lambda: {
+            "nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking",
+            "nvidia/nv-rerankqa-mistral-4b-v3": "https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking",
+            "nvidia/llama-3.2-nv-rerankqa-1b-v2": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking",
+        },
+        description="Mapping of rerank model identifiers to their API endpoints.",
+    )
 
     @classmethod
     def sample_run_config(
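Since `rerank_model_to_url` is an ordinary pydantic field with a default factory, deployments can remap it. A minimal sketch, assuming `NVIDIAConfig` is constructed directly; the model id and endpoint below are hypothetical placeholders:

```python
from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig

# Hypothetical override: route a custom rerank model id to its own endpoint.
# Passing the field replaces the default mapping entirely, so include any
# hosted defaults you still want to use.
config = NVIDIAConfig(
    api_key="your-api-key",
    rerank_model_to_url={
        "my-org/my-reranker": "https://example.com/v1/retrieval/my-reranker/reranking",
    },
)
```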
diff --git a/src/llama_stack/providers/remote/inference/nvidia/nvidia.py b/src/llama_stack/providers/remote/inference/nvidia/nvidia.py
index ea11b49cd..bc5aa7953 100644
--- a/src/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/src/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -5,6 +5,19 @@
 # the root directory of this source tree.
 
+from collections.abc import Iterable
+
+import aiohttp
+
+from llama_stack.apis.inference import (
+    RerankData,
+    RerankResponse,
+)
+from llama_stack.apis.inference.inference import (
+    OpenAIChatCompletionContentPartImageParam,
+    OpenAIChatCompletionContentPartTextParam,
+)
+from llama_stack.apis.models import Model, ModelType
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 
@@ -61,3 +74,101 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
         :return: The NVIDIA API base URL
         """
         return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
+
+    async def list_provider_model_ids(self) -> Iterable[str]:
+        """
+        Return both dynamic model IDs and statically configured rerank model IDs.
+        """
+        dynamic_ids: Iterable[str] = []
+        try:
+            dynamic_ids = await super().list_provider_model_ids()
+        except Exception:
+            # If the dynamic listing fails, proceed with just configured rerank IDs
+            dynamic_ids = []
+
+        configured_rerank_ids = list(self.config.rerank_model_to_url.keys())
+        return list(dict.fromkeys(list(dynamic_ids) + configured_rerank_ids))  # remove duplicates
+
+    def construct_model_from_identifier(self, identifier: str) -> Model:
+        """
+        Classify rerank models from config; otherwise use the base behavior.
+        """
+        if identifier in self.config.rerank_model_to_url:
+            return Model(
+                provider_id=self.__provider_id__,  # type: ignore[attr-defined]
+                provider_resource_id=identifier,
+                identifier=identifier,
+                model_type=ModelType.rerank,
+            )
+        return super().construct_model_from_identifier(identifier)
+
+    async def rerank(
+        self,
+        model: str,
+        query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
+        items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
+        max_num_results: int | None = None,
+    ) -> RerankResponse:
+        provider_model_id = await self._get_provider_model_id(model)
+
+        ranking_url = self.get_base_url()
+
+        if _is_nvidia_hosted(self.config) and provider_model_id in self.config.rerank_model_to_url:
+            ranking_url = self.config.rerank_model_to_url[provider_model_id]
+
+        logger.debug(f"Using rerank endpoint: {ranking_url} for model: {provider_model_id}")
+
+        # Convert query to text format
+        if isinstance(query, str):
+            query_text = query
+        elif isinstance(query, OpenAIChatCompletionContentPartTextParam):
+            query_text = query.text
+        else:
+            raise ValueError("Query must be a string or text content part")
+
+        # Convert items to text format
+        passages = []
+        for item in items:
+            if isinstance(item, str):
+                passages.append({"text": item})
+            elif isinstance(item, OpenAIChatCompletionContentPartTextParam):
+                passages.append({"text": item.text})
+            else:
+                raise ValueError("Items must be strings or text content parts")
+
+        payload = {
+            "model": provider_model_id,
+            "query": {"text": query_text},
+            "passages": passages,
+        }
+
+        headers = {
+            "Authorization": f"Bearer {self.get_api_key()}",
+            "Content-Type": "application/json",
+        }
+
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.post(ranking_url, headers=headers, json=payload) as response:
+                    if response.status != 200:
+                        response_text = await response.text()
+                        raise ConnectionError(
+                            f"NVIDIA rerank API request failed with status {response.status}: {response_text}"
+                        )
+
+                    result = await response.json()
+                    rankings = result.get("rankings", [])
+
+                    # Convert to RerankData format
+                    rerank_data = []
+                    for ranking in rankings:
+                        rerank_data.append(RerankData(index=ranking["index"], relevance_score=ranking["logit"]))
+
+                    # Apply max_num_results limit
+                    if max_num_results is not None:
+                        rerank_data = rerank_data[:max_num_results]
+
+                    return RerankResponse(data=rerank_data)
+
+        except aiohttp.ClientError as e:
+            raise ConnectionError(f"Failed to connect to NVIDIA rerank API at {ranking_url}: {e}") from e
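To make the wire format concrete: the adapter posts a body with `model`, `query`, and `passages`, and reads each `rankings[*].logit` into `RerankData.relevance_score`. An illustrative exchange, with invented values but shapes taken from the implementation and the unit tests below:

```python
# Illustrative request body, mirroring the `payload` dict built in rerank()
request_body = {
    "model": "nvidia/llama-3.2-nv-rerankqa-1b-v2",
    "query": {"text": "What is a reranking model?"},
    "passages": [
        {"text": "A reranking model reorders items by relevance."},
        {"text": "Python is a programming language."},
    ],
}

# Illustrative NIM response; each `logit` becomes RerankData.relevance_score,
# and the list order is kept exactly as returned by the endpoint.
response_body = {
    "rankings": [
        {"index": 0, "logit": 4.21},
        {"index": 1, "logit": -1.37},
    ]
}
```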
 Fixture name: shield_id",
@@ -249,6 +253,7 @@ def pytest_generate_tests(metafunc):
         "shield_id": ("--safety-shield", "shield"),
         "judge_model_id": ("--judge-model", "judge"),
         "embedding_dimension": ("--embedding-dimension", "dim"),
+        "rerank_model_id": ("--rerank-model", "rerank"),
     }
 
     # Collect all parameters and their values
diff --git a/tests/integration/fixtures/common.py b/tests/integration/fixtures/common.py
index e68f9dc9e..57775ce25 100644
--- a/tests/integration/fixtures/common.py
+++ b/tests/integration/fixtures/common.py
@@ -153,6 +153,7 @@ def client_with_models(
     vision_model_id,
     embedding_model_id,
     judge_model_id,
+    rerank_model_id,
 ):
     client = llama_stack_client
 
@@ -170,6 +171,9 @@ def client_with_models(
 
     if embedding_model_id and embedding_model_id not in model_ids:
         raise ValueError(f"embedding_model_id {embedding_model_id} not found")
+
+    if rerank_model_id and rerank_model_id not in model_ids:
+        raise ValueError(f"rerank_model_id {rerank_model_id} not found")
 
     return client
 
@@ -185,7 +189,14 @@ def model_providers(llama_stack_client):
 
 @pytest.fixture(autouse=True)
 def skip_if_no_model(request):
-    model_fixtures = ["text_model_id", "vision_model_id", "embedding_model_id", "judge_model_id", "shield_id"]
+    model_fixtures = [
+        "text_model_id",
+        "vision_model_id",
+        "embedding_model_id",
+        "judge_model_id",
+        "shield_id",
+        "rerank_model_id",
+    ]
     test_func = request.node.function
 
     actual_params = inspect.signature(test_func).parameters.keys()
diff --git a/tests/integration/inference/test_rerank.py b/tests/integration/inference/test_rerank.py
new file mode 100644
index 000000000..82f35cd27
--- /dev/null
+++ b/tests/integration/inference/test_rerank.py
@@ -0,0 +1,214 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+from llama_stack_client import BadRequestError as LlamaStackBadRequestError
+from llama_stack_client.types.alpha import InferenceRerankResponse
+from llama_stack_client.types.shared.interleaved_content import (
+    ImageContentItem,
+    ImageContentItemImage,
+    ImageContentItemImageURL,
+    TextContentItem,
+)
+
+from llama_stack.core.library_client import LlamaStackAsLibraryClient
+
+# Test data
+DUMMY_STRING = "string_1"
+DUMMY_STRING2 = "string_2"
+DUMMY_TEXT = TextContentItem(text=DUMMY_STRING, type="text")
+DUMMY_TEXT2 = TextContentItem(text=DUMMY_STRING2, type="text")
+DUMMY_IMAGE_URL = ImageContentItem(
+    image=ImageContentItemImage(url=ImageContentItemImageURL(uri="https://example.com/image.jpg")), type="image"
+)
+DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
+
+PROVIDERS_SUPPORTING_MEDIA = set()  # Provider types that support media input for rerank models
+
+
+def skip_if_provider_doesnt_support_rerank(inference_provider_type):
+    supported_providers = {"remote::nvidia"}
+    if inference_provider_type not in supported_providers:
+        pytest.skip(f"{inference_provider_type} doesn't support rerank models")
+
+
+def _validate_rerank_response(response: InferenceRerankResponse, items: list) -> None:
+    """
+    Validate that a rerank response has the correct structure and ordering.
+
+    Args:
+        response: The InferenceRerankResponse to validate
+        items: The original items list that was ranked
+
+    Raises:
+        AssertionError: If any validation fails
+    """
+    seen = set()
+    last_score = float("inf")
+    for d in response:
+        assert 0 <= d.index < len(items), f"Index {d.index} out of bounds for {len(items)} items"
+        assert d.index not in seen, f"Duplicate index {d.index} found"
+        seen.add(d.index)
+        assert isinstance(d.relevance_score, float), f"Score must be float, got {type(d.relevance_score)}"
+        assert d.relevance_score <= last_score, f"Scores not in descending order: {d.relevance_score} > {last_score}"
+        last_score = d.relevance_score
+
+
+def _validate_semantic_ranking(response: InferenceRerankResponse, items: list, expected_first_item: str) -> None:
+    """
+    Validate that the expected most relevant item ranks first.
+
+    Args:
+        response: The InferenceRerankResponse to validate
+        items: The original items list that was ranked
+        expected_first_item: The expected first item in the ranking
+
+    Raises:
+        AssertionError: If any validation fails
+    """
+    if not response:
+        raise AssertionError("No ranking data returned in response")
+
+    actual_first_index = response[0].index
+    actual_first_item = items[actual_first_index]
+    assert actual_first_item == expected_first_item, (
+        f"Expected '{expected_first_item}' to rank first, but '{actual_first_item}' ranked first instead."
+    )
+
+
+@pytest.mark.parametrize(
+    "query,items",
+    [
+        (DUMMY_STRING, [DUMMY_STRING, DUMMY_STRING2]),
+        (DUMMY_TEXT, [DUMMY_TEXT, DUMMY_TEXT2]),
+        (DUMMY_STRING, [DUMMY_STRING2, DUMMY_TEXT]),
+        (DUMMY_TEXT, [DUMMY_STRING, DUMMY_TEXT2]),
+    ],
+    ids=[
+        "string-query-string-items",
+        "text-query-text-items",
+        "mixed-content-1",
+        "mixed-content-2",
+    ],
+)
+def test_rerank_text(client_with_models, rerank_model_id, query, items, inference_provider_type):
+    skip_if_provider_doesnt_support_rerank(inference_provider_type)
+
+    response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
+    assert isinstance(response, list)
+    # TODO: Add type validation for response items once InferenceRerankResponseItem is exported from llama stack client.
+    assert len(response) <= len(items)
+    _validate_rerank_response(response, items)
+
+
+@pytest.mark.parametrize(
+    "query,items",
+    [
+        (DUMMY_IMAGE_URL, [DUMMY_STRING]),
+        (DUMMY_IMAGE_BASE64, [DUMMY_TEXT]),
+        (DUMMY_TEXT, [DUMMY_IMAGE_URL]),
+        (DUMMY_IMAGE_BASE64, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]),
+        (DUMMY_TEXT, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]),
+    ],
+    ids=[
+        "image-query-url",
+        "image-query-base64",
+        "text-query-image-item",
+        "mixed-content-1",
+        "mixed-content-2",
+    ],
+)
+def test_rerank_image(client_with_models, rerank_model_id, query, items, inference_provider_type):
+    skip_if_provider_doesnt_support_rerank(inference_provider_type)
+
+    if inference_provider_type not in PROVIDERS_SUPPORTING_MEDIA:
+        error_type = (
+            ValueError if isinstance(client_with_models, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
+        )
+        with pytest.raises(error_type):
+            client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
+    else:
+        response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
+
+        assert isinstance(response, list)
+        assert len(response) <= len(items)
+        _validate_rerank_response(response, items)
+
+
+def test_rerank_max_results(client_with_models, rerank_model_id, inference_provider_type):
+    skip_if_provider_doesnt_support_rerank(inference_provider_type)
+
+    items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2]
+    max_num_results = 2
+
+    response = client_with_models.alpha.inference.rerank(
+        model=rerank_model_id,
+        query=DUMMY_STRING,
+        items=items,
+        max_num_results=max_num_results,
+    )
+
+    assert isinstance(response, list)
+    assert len(response) == max_num_results
+    _validate_rerank_response(response, items)
+
+
+def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_id, inference_provider_type):
+    skip_if_provider_doesnt_support_rerank(inference_provider_type)
+
+    items = [DUMMY_STRING, DUMMY_STRING2]
+    response = client_with_models.alpha.inference.rerank(
+        model=rerank_model_id,
+        query=DUMMY_STRING,
+        items=items,
+        max_num_results=10,  # Larger than items length
+    )
+
+    assert isinstance(response, list)
+    assert len(response) <= len(items)  # Should return at most len(items)
+
+
+@pytest.mark.parametrize(
+    "query,items,expected_first_item",
+    [
+        (
+            "What is a reranking model? ",
+            [
+                "A reranking model reranks a list of items based on the query. ",
+                "Machine learning algorithms learn patterns from data. ",
+                "Python is a programming language. ",
+            ],
+            "A reranking model reranks a list of items based on the query. ",
+        ),
+        (
+            "What is C++?",
+            [
+                "Learning new things is interesting. ",
+                "C++ is a programming language. ",
+                "Books provide knowledge and entertainment. ",
+            ],
+            "C++ is a programming language. ",
+        ),
+        (
+            "What are good learning habits? ",
+            [
+                "Cooking pasta is a fun activity. ",
+                "Plants need water and sunlight. ",
+                "Good learning habits include reading daily and taking notes. ",
+            ],
+            "Good learning habits include reading daily and taking notes. ",
", + ), + ], +) +def test_rerank_semantic_correctness( + client_with_models, rerank_model_id, query, items, expected_first_item, inference_provider_type +): + skip_if_provider_doesnt_support_rerank(inference_provider_type) + + response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items) + + _validate_rerank_response(response, items) + _validate_semantic_ranking(response, items, expected_first_item) diff --git a/tests/unit/providers/nvidia/test_rerank_inference.py b/tests/unit/providers/nvidia/test_rerank_inference.py new file mode 100644 index 000000000..2793b5f44 --- /dev/null +++ b/tests/unit/providers/nvidia/test_rerank_inference.py @@ -0,0 +1,251 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from unittest.mock import AsyncMock, MagicMock, patch + +import aiohttp +import pytest + +from llama_stack.apis.models import ModelType +from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig +from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter +from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin + + +class MockResponse: + def __init__(self, status=200, json_data=None, text_data="OK"): + self.status = status + self._json_data = json_data or {"rankings": []} + self._text_data = text_data + + async def json(self): + return self._json_data + + async def text(self): + return self._text_data + + +class MockSession: + def __init__(self, response): + self.response = response + self.post_calls = [] + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return False + + def post(self, url, **kwargs): + self.post_calls.append((url, kwargs)) + + class PostContext: + def __init__(self, response): + self.response = response + + async def __aenter__(self): + return self.response + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return False + + return PostContext(self.response) + + +def create_adapter(config=None, rerank_endpoints=None): + if config is None: + config = NVIDIAConfig(api_key="test-key") + + adapter = NVIDIAInferenceAdapter(config=config) + + class MockModel: + provider_resource_id = "test-model" + metadata = {} + + adapter.model_store = AsyncMock() + adapter.model_store.get_model = AsyncMock(return_value=MockModel()) + + if rerank_endpoints is not None: + adapter.config.rerank_model_to_url = rerank_endpoints + + return adapter + + +async def test_rerank_basic_functionality(): + adapter = create_adapter() + mock_response = MockResponse(json_data={"rankings": [{"index": 0, "logit": 0.5}]}) + mock_session = MockSession(mock_response) + + with patch("aiohttp.ClientSession", return_value=mock_session): + result = await adapter.rerank(model="test-model", query="test query", items=["item1", "item2"]) + + assert len(result.data) == 1 + assert result.data[0].index == 0 + assert result.data[0].relevance_score == 0.5 + + url, kwargs = mock_session.post_calls[0] + payload = kwargs["json"] + assert payload["model"] == "test-model" + assert payload["query"] == {"text": "test query"} + assert payload["passages"] == [{"text": "item1"}, {"text": "item2"}] + + +async def test_missing_rankings_key(): + adapter = create_adapter() + mock_session = MockSession(MockResponse(json_data={})) + + with patch("aiohttp.ClientSession", return_value=mock_session): + result = await 
+
+    assert len(result.data) == 0
+
+
+async def test_hosted_with_endpoint():
+    adapter = create_adapter(
+        config=NVIDIAConfig(api_key="key"), rerank_endpoints={"test-model": "https://model.endpoint/rerank"}
+    )
+    mock_session = MockSession(MockResponse())
+
+    with patch("aiohttp.ClientSession", return_value=mock_session):
+        await adapter.rerank(model="test-model", query="q", items=["a"])
+
+    url, _ = mock_session.post_calls[0]
+    assert url == "https://model.endpoint/rerank"
+
+
+async def test_hosted_without_endpoint():
+    adapter = create_adapter(
+        config=NVIDIAConfig(api_key="key"),  # This creates a hosted config (integrate.api.nvidia.com).
+        rerank_endpoints={},  # No endpoint mapping for test-model
+    )
+    mock_session = MockSession(MockResponse())
+
+    with patch("aiohttp.ClientSession", return_value=mock_session):
+        await adapter.rerank(model="test-model", query="q", items=["a"])
+
+    url, _ = mock_session.post_calls[0]
+    assert "https://integrate.api.nvidia.com" in url
+
+
+async def test_hosted_model_not_in_endpoint_mapping():
+    adapter = create_adapter(
+        config=NVIDIAConfig(api_key="key"), rerank_endpoints={"other-model": "https://other.endpoint/rerank"}
+    )
+    mock_session = MockSession(MockResponse())
+
+    with patch("aiohttp.ClientSession", return_value=mock_session):
+        await adapter.rerank(model="test-model", query="q", items=["a"])
+
+    url, _ = mock_session.post_calls[0]
+    assert "https://integrate.api.nvidia.com" in url
+    assert url != "https://other.endpoint/rerank"
+
+
+async def test_self_hosted_ignores_endpoint():
+    adapter = create_adapter(
+        config=NVIDIAConfig(url="http://localhost:8000", api_key=None),
+        rerank_endpoints={"test-model": "https://model.endpoint/rerank"},  # This should be ignored for self-hosted.
+    )
+    mock_session = MockSession(MockResponse())
+
+    with patch("aiohttp.ClientSession", return_value=mock_session):
+        await adapter.rerank(model="test-model", query="q", items=["a"])
+
+    url, _ = mock_session.post_calls[0]
+    assert "http://localhost:8000" in url
+    assert "model.endpoint/rerank" not in url
+
+
+async def test_max_num_results():
+    adapter = create_adapter()
+    rankings = [{"index": 0, "logit": 0.8}, {"index": 1, "logit": 0.6}]
+    mock_session = MockSession(MockResponse(json_data={"rankings": rankings}))
+
+    with patch("aiohttp.ClientSession", return_value=mock_session):
+        result = await adapter.rerank(model="test-model", query="q", items=["a", "b"], max_num_results=1)
+
+    assert len(result.data) == 1
+    assert result.data[0].index == 0
+    assert result.data[0].relevance_score == 0.8
+
+
+async def test_http_error():
+    adapter = create_adapter()
+    mock_session = MockSession(MockResponse(status=500, text_data="Server Error"))
+
+    with patch("aiohttp.ClientSession", return_value=mock_session):
+        with pytest.raises(ConnectionError, match="status 500.*Server Error"):
+            await adapter.rerank(model="test-model", query="q", items=["a"])
+
+
+async def test_client_error():
+    adapter = create_adapter()
+    mock_session = AsyncMock()
+    mock_session.__aenter__.side_effect = aiohttp.ClientError("Network error")
+
+    with patch("aiohttp.ClientSession", return_value=mock_session):
+        with pytest.raises(ConnectionError, match="Failed to connect.*Network error"):
+            await adapter.rerank(model="test-model", query="q", items=["a"])
+
+
+async def test_list_models_includes_configured_rerank_models():
+    """Test that list_models adds rerank models to the dynamic model list."""
+    adapter = create_adapter()
+    adapter.__provider_id__ = "nvidia"
+    adapter.__provider_spec__ = MagicMock()
+
+    dynamic_ids = ["llm-1", "embedding-1"]
+    with patch.object(OpenAIMixin, "list_provider_model_ids", new=AsyncMock(return_value=dynamic_ids)):
+        result = await adapter.list_models()
+
+    assert result is not None
+
+    # Check that the rerank models are added
+    model_ids = [m.identifier for m in result]
+    assert "nv-rerank-qa-mistral-4b:1" in model_ids
+    assert "nvidia/nv-rerankqa-mistral-4b-v3" in model_ids
+    assert "nvidia/llama-3.2-nv-rerankqa-1b-v2" in model_ids
+
+    rerank_models = [m for m in result if m.model_type == ModelType.rerank]
+
+    assert len(rerank_models) == 3
+
+    for m in rerank_models:
+        assert m.provider_id == "nvidia"
+        assert m.model_type == ModelType.rerank
+        assert m.metadata == {}
+        assert m.identifier in adapter._model_cache
+
+
+async def test_list_provider_model_ids_has_no_duplicates():
+    adapter = create_adapter()
+
+    dynamic_ids = [
+        "llm-1",
+        "nvidia/nv-rerankqa-mistral-4b-v3",  # overlaps configured rerank ids
+        "embedding-1",
+        "llm-1",
+    ]
+
+    with patch.object(OpenAIMixin, "list_provider_model_ids", new=AsyncMock(return_value=dynamic_ids)):
+        ids = list(await adapter.list_provider_model_ids())
+
+    assert len(ids) == len(set(ids))
+    assert ids.count("nvidia/nv-rerankqa-mistral-4b-v3") == 1
+    assert "nv-rerank-qa-mistral-4b:1" in ids
+    assert "nvidia/llama-3.2-nv-rerankqa-1b-v2" in ids
+
+
+async def test_list_provider_model_ids_uses_configured_on_dynamic_failure():
+    adapter = create_adapter()
+
+    # Simulate dynamic listing failure
+    with patch.object(OpenAIMixin, "list_provider_model_ids", new=AsyncMock(side_effect=Exception)):
+        ids = list(await adapter.list_provider_model_ids())
+
+    # Should still return configured rerank ids
+    configured_ids = list(adapter.config.rerank_model_to_url.keys())
+    assert set(ids) == set(configured_ids)