mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
feat: Add rerank API for NVIDIA Inference Provider (#3329)
# What does this PR do? Add rerank API for NVIDIA Inference Provider. <!-- If resolving an issue, uncomment and update the line below --> Closes #3278 ## Test Plan Unit test: ``` pytest tests/unit/providers/nvidia/test_rerank_inference.py ``` Integration test: ``` pytest -s -v tests/integration/inference/test_rerank.py --stack-config="inference=nvidia" --rerank-model=nvidia/nvidia/nv-rerankqa-mistral-4b-v3 --env NVIDIA_API_KEY="" --env NVIDIA_BASE_URL="https://integrate.api.nvidia.com" ```
This commit is contained in:
parent
c396de57a4
commit
fa7699d2c3
8 changed files with 622 additions and 1 deletions
|
|
@ -20,6 +20,7 @@ NVIDIA inference provider for accessing NVIDIA NIM models and AI services.
|
|||
| `url` | `<class 'str'>` | No | https://integrate.api.nvidia.com | A base url for accessing the NVIDIA NIM |
|
||||
| `timeout` | `<class 'int'>` | No | 60 | Timeout for the HTTP requests |
|
||||
| `append_api_version` | `<class 'bool'>` | No | True | When set to false, the API version will not be appended to the base_url. By default, it is true. |
|
||||
| `rerank_model_to_url` | `dict[str, str` | No | `{'nv-rerank-qa-mistral-4b:1': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking', 'nvidia/nv-rerankqa-mistral-4b-v3': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking', 'nvidia/llama-3.2-nv-rerankqa-1b-v2': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking'}` | Mapping of rerank model identifiers to their API endpoints. |
|
||||
|
||||
## Sample Configuration
|
||||
|
||||
|
|
|
|||
|
|
@ -181,3 +181,22 @@ vlm_response = client.chat.completions.create(
|
|||
|
||||
print(f"VLM Response: {vlm_response.choices[0].message.content}")
|
||||
```
|
||||
|
||||
### Rerank Example
|
||||
|
||||
The following example shows how to rerank documents using an NVIDIA NIM.
|
||||
|
||||
```python
|
||||
rerank_response = client.alpha.inference.rerank(
|
||||
model="nvidia/nvidia/llama-3.2-nv-rerankqa-1b-v2",
|
||||
query="query",
|
||||
items=[
|
||||
"item_1",
|
||||
"item_2",
|
||||
"item_3",
|
||||
],
|
||||
)
|
||||
|
||||
for i, result in enumerate(rerank_response):
|
||||
print(f"{i+1}. [Index: {result.index}, " f"Score: {(result.relevance_score):.3f}]")
|
||||
```
|
||||
|
|
@ -28,6 +28,7 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
|
|||
Attributes:
|
||||
url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000
|
||||
api_key (str): The access key for the hosted NIM endpoints
|
||||
rerank_model_to_url (dict[str, str]): Mapping of rerank model identifiers to their API endpoints
|
||||
|
||||
There are two ways to access NVIDIA NIMs -
|
||||
0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com
|
||||
|
|
@ -55,6 +56,14 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
|
|||
default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
|
||||
description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
|
||||
)
|
||||
rerank_model_to_url: dict[str, str] = Field(
|
||||
default_factory=lambda: {
|
||||
"nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking",
|
||||
"nvidia/nv-rerankqa-mistral-4b-v3": "https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking",
|
||||
"nvidia/llama-3.2-nv-rerankqa-1b-v2": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking",
|
||||
},
|
||||
description="Mapping of rerank model identifiers to their API endpoints. ",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
|
|
|
|||
|
|
@ -5,6 +5,19 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
import aiohttp
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
RerankData,
|
||||
RerankResponse,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
|
@ -61,3 +74,101 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
|
|||
:return: The NVIDIA API base URL
|
||||
"""
|
||||
return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
|
||||
|
||||
async def list_provider_model_ids(self) -> Iterable[str]:
|
||||
"""
|
||||
Return both dynamic model IDs and statically configured rerank model IDs.
|
||||
"""
|
||||
dynamic_ids: Iterable[str] = []
|
||||
try:
|
||||
dynamic_ids = await super().list_provider_model_ids()
|
||||
except Exception:
|
||||
# If the dynamic listing fails, proceed with just configured rerank IDs
|
||||
dynamic_ids = []
|
||||
|
||||
configured_rerank_ids = list(self.config.rerank_model_to_url.keys())
|
||||
return list(dict.fromkeys(list(dynamic_ids) + configured_rerank_ids)) # remove duplicates
|
||||
|
||||
def construct_model_from_identifier(self, identifier: str) -> Model:
|
||||
"""
|
||||
Classify rerank models from config; otherwise use the base behavior.
|
||||
"""
|
||||
if identifier in self.config.rerank_model_to_url:
|
||||
return Model(
|
||||
provider_id=self.__provider_id__, # type: ignore[attr-defined]
|
||||
provider_resource_id=identifier,
|
||||
identifier=identifier,
|
||||
model_type=ModelType.rerank,
|
||||
)
|
||||
return super().construct_model_from_identifier(identifier)
|
||||
|
||||
async def rerank(
|
||||
self,
|
||||
model: str,
|
||||
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
|
||||
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
|
||||
max_num_results: int | None = None,
|
||||
) -> RerankResponse:
|
||||
provider_model_id = await self._get_provider_model_id(model)
|
||||
|
||||
ranking_url = self.get_base_url()
|
||||
|
||||
if _is_nvidia_hosted(self.config) and provider_model_id in self.config.rerank_model_to_url:
|
||||
ranking_url = self.config.rerank_model_to_url[provider_model_id]
|
||||
|
||||
logger.debug(f"Using rerank endpoint: {ranking_url} for model: {provider_model_id}")
|
||||
|
||||
# Convert query to text format
|
||||
if isinstance(query, str):
|
||||
query_text = query
|
||||
elif isinstance(query, OpenAIChatCompletionContentPartTextParam):
|
||||
query_text = query.text
|
||||
else:
|
||||
raise ValueError("Query must be a string or text content part")
|
||||
|
||||
# Convert items to text format
|
||||
passages = []
|
||||
for item in items:
|
||||
if isinstance(item, str):
|
||||
passages.append({"text": item})
|
||||
elif isinstance(item, OpenAIChatCompletionContentPartTextParam):
|
||||
passages.append({"text": item.text})
|
||||
else:
|
||||
raise ValueError("Items must be strings or text content parts")
|
||||
|
||||
payload = {
|
||||
"model": provider_model_id,
|
||||
"query": {"text": query_text},
|
||||
"passages": passages,
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.get_api_key()}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(ranking_url, headers=headers, json=payload) as response:
|
||||
if response.status != 200:
|
||||
response_text = await response.text()
|
||||
raise ConnectionError(
|
||||
f"NVIDIA rerank API request failed with status {response.status}: {response_text}"
|
||||
)
|
||||
|
||||
result = await response.json()
|
||||
rankings = result.get("rankings", [])
|
||||
|
||||
# Convert to RerankData format
|
||||
rerank_data = []
|
||||
for ranking in rankings:
|
||||
rerank_data.append(RerankData(index=ranking["index"], relevance_score=ranking["logit"]))
|
||||
|
||||
# Apply max_num_results limit
|
||||
if max_num_results is not None:
|
||||
rerank_data = rerank_data[:max_num_results]
|
||||
|
||||
return RerankResponse(data=rerank_data)
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA rerank API at {ranking_url}: {e}") from e
|
||||
|
|
|
|||
|
|
@ -171,6 +171,10 @@ def pytest_addoption(parser):
|
|||
"--embedding-model",
|
||||
help="comma-separated list of embedding models. Fixture name: embedding_model_id",
|
||||
)
|
||||
parser.addoption(
|
||||
"--rerank-model",
|
||||
help="comma-separated list of rerank models. Fixture name: rerank_model_id",
|
||||
)
|
||||
parser.addoption(
|
||||
"--safety-shield",
|
||||
help="comma-separated list of safety shields. Fixture name: shield_id",
|
||||
|
|
@ -249,6 +253,7 @@ def pytest_generate_tests(metafunc):
|
|||
"shield_id": ("--safety-shield", "shield"),
|
||||
"judge_model_id": ("--judge-model", "judge"),
|
||||
"embedding_dimension": ("--embedding-dimension", "dim"),
|
||||
"rerank_model_id": ("--rerank-model", "rerank"),
|
||||
}
|
||||
|
||||
# Collect all parameters and their values
|
||||
|
|
|
|||
|
|
@ -153,6 +153,7 @@ def client_with_models(
|
|||
vision_model_id,
|
||||
embedding_model_id,
|
||||
judge_model_id,
|
||||
rerank_model_id,
|
||||
):
|
||||
client = llama_stack_client
|
||||
|
||||
|
|
@ -170,6 +171,9 @@ def client_with_models(
|
|||
|
||||
if embedding_model_id and embedding_model_id not in model_ids:
|
||||
raise ValueError(f"embedding_model_id {embedding_model_id} not found")
|
||||
|
||||
if rerank_model_id and rerank_model_id not in model_ids:
|
||||
raise ValueError(f"rerank_model_id {rerank_model_id} not found")
|
||||
return client
|
||||
|
||||
|
||||
|
|
@ -185,7 +189,14 @@ def model_providers(llama_stack_client):
|
|||
|
||||
@pytest.fixture(autouse=True)
|
||||
def skip_if_no_model(request):
|
||||
model_fixtures = ["text_model_id", "vision_model_id", "embedding_model_id", "judge_model_id", "shield_id"]
|
||||
model_fixtures = [
|
||||
"text_model_id",
|
||||
"vision_model_id",
|
||||
"embedding_model_id",
|
||||
"judge_model_id",
|
||||
"shield_id",
|
||||
"rerank_model_id",
|
||||
]
|
||||
test_func = request.node.function
|
||||
|
||||
actual_params = inspect.signature(test_func).parameters.keys()
|
||||
|
|
|
|||
214
tests/integration/inference/test_rerank.py
Normal file
214
tests/integration/inference/test_rerank.py
Normal file
|
|
@ -0,0 +1,214 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
from llama_stack_client import BadRequestError as LlamaStackBadRequestError
|
||||
from llama_stack_client.types.alpha import InferenceRerankResponse
|
||||
from llama_stack_client.types.shared.interleaved_content import (
|
||||
ImageContentItem,
|
||||
ImageContentItemImage,
|
||||
ImageContentItemImageURL,
|
||||
TextContentItem,
|
||||
)
|
||||
|
||||
from llama_stack.core.library_client import LlamaStackAsLibraryClient
|
||||
|
||||
# Test data
|
||||
DUMMY_STRING = "string_1"
|
||||
DUMMY_STRING2 = "string_2"
|
||||
DUMMY_TEXT = TextContentItem(text=DUMMY_STRING, type="text")
|
||||
DUMMY_TEXT2 = TextContentItem(text=DUMMY_STRING2, type="text")
|
||||
DUMMY_IMAGE_URL = ImageContentItem(
|
||||
image=ImageContentItemImage(url=ImageContentItemImageURL(uri="https://example.com/image.jpg")), type="image"
|
||||
)
|
||||
DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
|
||||
|
||||
PROVIDERS_SUPPORTING_MEDIA = {} # Providers that support media input for rerank models
|
||||
|
||||
|
||||
def skip_if_provider_doesnt_support_rerank(inference_provider_type):
|
||||
supported_providers = {"remote::nvidia"}
|
||||
if inference_provider_type not in supported_providers:
|
||||
pytest.skip(f"{inference_provider_type} doesn't support rerank models")
|
||||
|
||||
|
||||
def _validate_rerank_response(response: InferenceRerankResponse, items: list) -> None:
|
||||
"""
|
||||
Validate that a rerank response has the correct structure and ordering.
|
||||
|
||||
Args:
|
||||
response: The InferenceRerankResponse to validate
|
||||
items: The original items list that was ranked
|
||||
|
||||
Raises:
|
||||
AssertionError: If any validation fails
|
||||
"""
|
||||
seen = set()
|
||||
last_score = float("inf")
|
||||
for d in response:
|
||||
assert 0 <= d.index < len(items), f"Index {d.index} out of bounds for {len(items)} items"
|
||||
assert d.index not in seen, f"Duplicate index {d.index} found"
|
||||
seen.add(d.index)
|
||||
assert isinstance(d.relevance_score, float), f"Score must be float, got {type(d.relevance_score)}"
|
||||
assert d.relevance_score <= last_score, f"Scores not in descending order: {d.relevance_score} > {last_score}"
|
||||
last_score = d.relevance_score
|
||||
|
||||
|
||||
def _validate_semantic_ranking(response: InferenceRerankResponse, items: list, expected_first_item: str) -> None:
|
||||
"""
|
||||
Validate that the expected most relevant item ranks first.
|
||||
|
||||
Args:
|
||||
response: The InferenceRerankResponse to validate
|
||||
items: The original items list that was ranked
|
||||
expected_first_item: The expected first item in the ranking
|
||||
|
||||
Raises:
|
||||
AssertionError: If any validation fails
|
||||
"""
|
||||
if not response:
|
||||
raise AssertionError("No ranking data returned in response")
|
||||
|
||||
actual_first_index = response[0].index
|
||||
actual_first_item = items[actual_first_index]
|
||||
assert actual_first_item == expected_first_item, (
|
||||
f"Expected '{expected_first_item}' to rank first, but '{actual_first_item}' ranked first instead."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query,items",
|
||||
[
|
||||
(DUMMY_STRING, [DUMMY_STRING, DUMMY_STRING2]),
|
||||
(DUMMY_TEXT, [DUMMY_TEXT, DUMMY_TEXT2]),
|
||||
(DUMMY_STRING, [DUMMY_STRING2, DUMMY_TEXT]),
|
||||
(DUMMY_TEXT, [DUMMY_STRING, DUMMY_TEXT2]),
|
||||
],
|
||||
ids=[
|
||||
"string-query-string-items",
|
||||
"text-query-text-items",
|
||||
"mixed-content-1",
|
||||
"mixed-content-2",
|
||||
],
|
||||
)
|
||||
def test_rerank_text(client_with_models, rerank_model_id, query, items, inference_provider_type):
|
||||
skip_if_provider_doesnt_support_rerank(inference_provider_type)
|
||||
|
||||
response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
|
||||
assert isinstance(response, list)
|
||||
# TODO: Add type validation for response items once InferenceRerankResponseItem is exported from llama stack client.
|
||||
assert len(response) <= len(items)
|
||||
_validate_rerank_response(response, items)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query,items",
|
||||
[
|
||||
(DUMMY_IMAGE_URL, [DUMMY_STRING]),
|
||||
(DUMMY_IMAGE_BASE64, [DUMMY_TEXT]),
|
||||
(DUMMY_TEXT, [DUMMY_IMAGE_URL]),
|
||||
(DUMMY_IMAGE_BASE64, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]),
|
||||
(DUMMY_TEXT, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]),
|
||||
],
|
||||
ids=[
|
||||
"image-query-url",
|
||||
"image-query-base64",
|
||||
"text-query-image-item",
|
||||
"mixed-content-1",
|
||||
"mixed-content-2",
|
||||
],
|
||||
)
|
||||
def test_rerank_image(client_with_models, rerank_model_id, query, items, inference_provider_type):
|
||||
skip_if_provider_doesnt_support_rerank(inference_provider_type)
|
||||
|
||||
if rerank_model_id not in PROVIDERS_SUPPORTING_MEDIA:
|
||||
error_type = (
|
||||
ValueError if isinstance(client_with_models, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
|
||||
)
|
||||
with pytest.raises(error_type):
|
||||
client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
|
||||
else:
|
||||
response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
|
||||
|
||||
assert isinstance(response, list)
|
||||
assert len(response) <= len(items)
|
||||
_validate_rerank_response(response, items)
|
||||
|
||||
|
||||
def test_rerank_max_results(client_with_models, rerank_model_id, inference_provider_type):
|
||||
skip_if_provider_doesnt_support_rerank(inference_provider_type)
|
||||
|
||||
items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2]
|
||||
max_num_results = 2
|
||||
|
||||
response = client_with_models.alpha.inference.rerank(
|
||||
model=rerank_model_id,
|
||||
query=DUMMY_STRING,
|
||||
items=items,
|
||||
max_num_results=max_num_results,
|
||||
)
|
||||
|
||||
assert isinstance(response, list)
|
||||
assert len(response) == max_num_results
|
||||
_validate_rerank_response(response, items)
|
||||
|
||||
|
||||
def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_id, inference_provider_type):
|
||||
skip_if_provider_doesnt_support_rerank(inference_provider_type)
|
||||
|
||||
items = [DUMMY_STRING, DUMMY_STRING2]
|
||||
response = client_with_models.alpha.inference.rerank(
|
||||
model=rerank_model_id,
|
||||
query=DUMMY_STRING,
|
||||
items=items,
|
||||
max_num_results=10, # Larger than items length
|
||||
)
|
||||
|
||||
assert isinstance(response, list)
|
||||
assert len(response) <= len(items) # Should return at most len(items)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query,items,expected_first_item",
|
||||
[
|
||||
(
|
||||
"What is a reranking model? ",
|
||||
[
|
||||
"A reranking model reranks a list of items based on the query. ",
|
||||
"Machine learning algorithms learn patterns from data. ",
|
||||
"Python is a programming language. ",
|
||||
],
|
||||
"A reranking model reranks a list of items based on the query. ",
|
||||
),
|
||||
(
|
||||
"What is C++?",
|
||||
[
|
||||
"Learning new things is interesting. ",
|
||||
"C++ is a programming language. ",
|
||||
"Books provide knowledge and entertainment. ",
|
||||
],
|
||||
"C++ is a programming language. ",
|
||||
),
|
||||
(
|
||||
"What are good learning habits? ",
|
||||
[
|
||||
"Cooking pasta is a fun activity. ",
|
||||
"Plants need water and sunlight. ",
|
||||
"Good learning habits include reading daily and taking notes. ",
|
||||
],
|
||||
"Good learning habits include reading daily and taking notes. ",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_rerank_semantic_correctness(
|
||||
client_with_models, rerank_model_id, query, items, expected_first_item, inference_provider_type
|
||||
):
|
||||
skip_if_provider_doesnt_support_rerank(inference_provider_type)
|
||||
|
||||
response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
|
||||
|
||||
_validate_rerank_response(response, items)
|
||||
_validate_semantic_ranking(response, items, expected_first_item)
|
||||
251
tests/unit/providers/nvidia/test_rerank_inference.py
Normal file
251
tests/unit/providers/nvidia/test_rerank_inference.py
Normal file
|
|
@ -0,0 +1,251 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
|
||||
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
||||
class MockResponse:
|
||||
def __init__(self, status=200, json_data=None, text_data="OK"):
|
||||
self.status = status
|
||||
self._json_data = json_data or {"rankings": []}
|
||||
self._text_data = text_data
|
||||
|
||||
async def json(self):
|
||||
return self._json_data
|
||||
|
||||
async def text(self):
|
||||
return self._text_data
|
||||
|
||||
|
||||
class MockSession:
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
self.post_calls = []
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
return False
|
||||
|
||||
def post(self, url, **kwargs):
|
||||
self.post_calls.append((url, kwargs))
|
||||
|
||||
class PostContext:
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
|
||||
async def __aenter__(self):
|
||||
return self.response
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
return False
|
||||
|
||||
return PostContext(self.response)
|
||||
|
||||
|
||||
def create_adapter(config=None, rerank_endpoints=None):
|
||||
if config is None:
|
||||
config = NVIDIAConfig(api_key="test-key")
|
||||
|
||||
adapter = NVIDIAInferenceAdapter(config=config)
|
||||
|
||||
class MockModel:
|
||||
provider_resource_id = "test-model"
|
||||
metadata = {}
|
||||
|
||||
adapter.model_store = AsyncMock()
|
||||
adapter.model_store.get_model = AsyncMock(return_value=MockModel())
|
||||
|
||||
if rerank_endpoints is not None:
|
||||
adapter.config.rerank_model_to_url = rerank_endpoints
|
||||
|
||||
return adapter
|
||||
|
||||
|
||||
async def test_rerank_basic_functionality():
|
||||
adapter = create_adapter()
|
||||
mock_response = MockResponse(json_data={"rankings": [{"index": 0, "logit": 0.5}]})
|
||||
mock_session = MockSession(mock_response)
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = await adapter.rerank(model="test-model", query="test query", items=["item1", "item2"])
|
||||
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].index == 0
|
||||
assert result.data[0].relevance_score == 0.5
|
||||
|
||||
url, kwargs = mock_session.post_calls[0]
|
||||
payload = kwargs["json"]
|
||||
assert payload["model"] == "test-model"
|
||||
assert payload["query"] == {"text": "test query"}
|
||||
assert payload["passages"] == [{"text": "item1"}, {"text": "item2"}]
|
||||
|
||||
|
||||
async def test_missing_rankings_key():
|
||||
adapter = create_adapter()
|
||||
mock_session = MockSession(MockResponse(json_data={}))
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
assert len(result.data) == 0
|
||||
|
||||
|
||||
async def test_hosted_with_endpoint():
|
||||
adapter = create_adapter(
|
||||
config=NVIDIAConfig(api_key="key"), rerank_endpoints={"test-model": "https://model.endpoint/rerank"}
|
||||
)
|
||||
mock_session = MockSession(MockResponse())
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
url, _ = mock_session.post_calls[0]
|
||||
assert url == "https://model.endpoint/rerank"
|
||||
|
||||
|
||||
async def test_hosted_without_endpoint():
|
||||
adapter = create_adapter(
|
||||
config=NVIDIAConfig(api_key="key"), # This creates hosted config (integrate.api.nvidia.com).
|
||||
rerank_endpoints={}, # No endpoint mapping for test-model
|
||||
)
|
||||
mock_session = MockSession(MockResponse())
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
url, _ = mock_session.post_calls[0]
|
||||
assert "https://integrate.api.nvidia.com" in url
|
||||
|
||||
|
||||
async def test_hosted_model_not_in_endpoint_mapping():
|
||||
adapter = create_adapter(
|
||||
config=NVIDIAConfig(api_key="key"), rerank_endpoints={"other-model": "https://other.endpoint/rerank"}
|
||||
)
|
||||
mock_session = MockSession(MockResponse())
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
url, _ = mock_session.post_calls[0]
|
||||
assert "https://integrate.api.nvidia.com" in url
|
||||
assert url != "https://other.endpoint/rerank"
|
||||
|
||||
|
||||
async def test_self_hosted_ignores_endpoint():
|
||||
adapter = create_adapter(
|
||||
config=NVIDIAConfig(url="http://localhost:8000", api_key=None),
|
||||
rerank_endpoints={"test-model": "https://model.endpoint/rerank"}, # This should be ignored for self-hosted.
|
||||
)
|
||||
mock_session = MockSession(MockResponse())
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
url, _ = mock_session.post_calls[0]
|
||||
assert "http://localhost:8000" in url
|
||||
assert "model.endpoint/rerank" not in url
|
||||
|
||||
|
||||
async def test_max_num_results():
|
||||
adapter = create_adapter()
|
||||
rankings = [{"index": 0, "logit": 0.8}, {"index": 1, "logit": 0.6}]
|
||||
mock_session = MockSession(MockResponse(json_data={"rankings": rankings}))
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = await adapter.rerank(model="test-model", query="q", items=["a", "b"], max_num_results=1)
|
||||
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].index == 0
|
||||
assert result.data[0].relevance_score == 0.8
|
||||
|
||||
|
||||
async def test_http_error():
|
||||
adapter = create_adapter()
|
||||
mock_session = MockSession(MockResponse(status=500, text_data="Server Error"))
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
with pytest.raises(ConnectionError, match="status 500.*Server Error"):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
|
||||
async def test_client_error():
|
||||
adapter = create_adapter()
|
||||
mock_session = AsyncMock()
|
||||
mock_session.__aenter__.side_effect = aiohttp.ClientError("Network error")
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
with pytest.raises(ConnectionError, match="Failed to connect.*Network error"):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
|
||||
async def test_list_models_includes_configured_rerank_models():
|
||||
"""Test that list_models adds rerank models to the dynamic model list."""
|
||||
adapter = create_adapter()
|
||||
adapter.__provider_id__ = "nvidia"
|
||||
adapter.__provider_spec__ = MagicMock()
|
||||
|
||||
dynamic_ids = ["llm-1", "embedding-1"]
|
||||
with patch.object(OpenAIMixin, "list_provider_model_ids", new=AsyncMock(return_value=dynamic_ids)):
|
||||
result = await adapter.list_models()
|
||||
|
||||
assert result is not None
|
||||
|
||||
# Check that the rerank models are added
|
||||
model_ids = [m.identifier for m in result]
|
||||
assert "nv-rerank-qa-mistral-4b:1" in model_ids
|
||||
assert "nvidia/nv-rerankqa-mistral-4b-v3" in model_ids
|
||||
assert "nvidia/llama-3.2-nv-rerankqa-1b-v2" in model_ids
|
||||
|
||||
rerank_models = [m for m in result if m.model_type == ModelType.rerank]
|
||||
|
||||
assert len(rerank_models) == 3
|
||||
|
||||
for m in rerank_models:
|
||||
assert m.provider_id == "nvidia"
|
||||
assert m.model_type == ModelType.rerank
|
||||
assert m.metadata == {}
|
||||
assert m.identifier in adapter._model_cache
|
||||
|
||||
|
||||
async def test_list_provider_model_ids_has_no_duplicates():
|
||||
adapter = create_adapter()
|
||||
|
||||
dynamic_ids = [
|
||||
"llm-1",
|
||||
"nvidia/nv-rerankqa-mistral-4b-v3", # overlaps configured rerank ids
|
||||
"embedding-1",
|
||||
"llm-1",
|
||||
]
|
||||
|
||||
with patch.object(OpenAIMixin, "list_provider_model_ids", new=AsyncMock(return_value=dynamic_ids)):
|
||||
ids = list(await adapter.list_provider_model_ids())
|
||||
|
||||
assert len(ids) == len(set(ids))
|
||||
assert ids.count("nvidia/nv-rerankqa-mistral-4b-v3") == 1
|
||||
assert "nv-rerank-qa-mistral-4b:1" in ids
|
||||
assert "nvidia/llama-3.2-nv-rerankqa-1b-v2" in ids
|
||||
|
||||
|
||||
async def test_list_provider_model_ids_uses_configured_on_dynamic_failure():
|
||||
adapter = create_adapter()
|
||||
|
||||
# Simulate dynamic listing failure
|
||||
with patch.object(OpenAIMixin, "list_provider_model_ids", new=AsyncMock(side_effect=Exception)):
|
||||
ids = list(await adapter.list_provider_model_ids())
|
||||
|
||||
# Should still return configured rerank ids
|
||||
configured_ids = list(adapter.config.rerank_model_to_url.keys())
|
||||
assert set(ids) == set(configured_ids)
|
||||
Loading…
Add table
Add a link
Reference in a new issue