feat: Add rerank API for NVIDIA Inference Provider (#3329)

# What does this PR do?
Add rerank API for NVIDIA Inference Provider.

<!-- If resolving an issue, uncomment and update the line below -->
Closes #3278 

## Test Plan
Unit test:
```
pytest tests/unit/providers/nvidia/test_rerank_inference.py
```

Integration test: 
```
pytest -s -v tests/integration/inference/test_rerank.py   --stack-config="inference=nvidia"   --rerank-model=nvidia/nvidia/nv-rerankqa-mistral-4b-v3   --env NVIDIA_API_KEY=""   --env NVIDIA_BASE_URL="https://integrate.api.nvidia.com"
```
This commit is contained in:
Jiayi Ni 2025-10-30 21:42:09 -07:00 committed by GitHub
parent c396de57a4
commit fa7699d2c3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 622 additions and 1 deletions

View file

@ -171,6 +171,10 @@ def pytest_addoption(parser):
"--embedding-model",
help="comma-separated list of embedding models. Fixture name: embedding_model_id",
)
parser.addoption(
"--rerank-model",
help="comma-separated list of rerank models. Fixture name: rerank_model_id",
)
parser.addoption(
"--safety-shield",
help="comma-separated list of safety shields. Fixture name: shield_id",
@ -249,6 +253,7 @@ def pytest_generate_tests(metafunc):
"shield_id": ("--safety-shield", "shield"),
"judge_model_id": ("--judge-model", "judge"),
"embedding_dimension": ("--embedding-dimension", "dim"),
"rerank_model_id": ("--rerank-model", "rerank"),
}
# Collect all parameters and their values

View file

@ -153,6 +153,7 @@ def client_with_models(
vision_model_id,
embedding_model_id,
judge_model_id,
rerank_model_id,
):
client = llama_stack_client
@ -170,6 +171,9 @@ def client_with_models(
if embedding_model_id and embedding_model_id not in model_ids:
raise ValueError(f"embedding_model_id {embedding_model_id} not found")
if rerank_model_id and rerank_model_id not in model_ids:
raise ValueError(f"rerank_model_id {rerank_model_id} not found")
return client
@ -185,7 +189,14 @@ def model_providers(llama_stack_client):
@pytest.fixture(autouse=True)
def skip_if_no_model(request):
model_fixtures = ["text_model_id", "vision_model_id", "embedding_model_id", "judge_model_id", "shield_id"]
model_fixtures = [
"text_model_id",
"vision_model_id",
"embedding_model_id",
"judge_model_id",
"shield_id",
"rerank_model_id",
]
test_func = request.node.function
actual_params = inspect.signature(test_func).parameters.keys()

View file

@ -0,0 +1,214 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack_client import BadRequestError as LlamaStackBadRequestError
from llama_stack_client.types.alpha import InferenceRerankResponse
from llama_stack_client.types.shared.interleaved_content import (
ImageContentItem,
ImageContentItemImage,
ImageContentItemImageURL,
TextContentItem,
)
from llama_stack.core.library_client import LlamaStackAsLibraryClient
# Test data
DUMMY_STRING = "string_1"
DUMMY_STRING2 = "string_2"
DUMMY_TEXT = TextContentItem(text=DUMMY_STRING, type="text")
DUMMY_TEXT2 = TextContentItem(text=DUMMY_STRING2, type="text")
DUMMY_IMAGE_URL = ImageContentItem(
image=ImageContentItemImage(url=ImageContentItemImageURL(uri="https://example.com/image.jpg")), type="image"
)
DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
PROVIDERS_SUPPORTING_MEDIA = {} # Providers that support media input for rerank models
def skip_if_provider_doesnt_support_rerank(inference_provider_type):
supported_providers = {"remote::nvidia"}
if inference_provider_type not in supported_providers:
pytest.skip(f"{inference_provider_type} doesn't support rerank models")
def _validate_rerank_response(response: InferenceRerankResponse, items: list) -> None:
"""
Validate that a rerank response has the correct structure and ordering.
Args:
response: The InferenceRerankResponse to validate
items: The original items list that was ranked
Raises:
AssertionError: If any validation fails
"""
seen = set()
last_score = float("inf")
for d in response:
assert 0 <= d.index < len(items), f"Index {d.index} out of bounds for {len(items)} items"
assert d.index not in seen, f"Duplicate index {d.index} found"
seen.add(d.index)
assert isinstance(d.relevance_score, float), f"Score must be float, got {type(d.relevance_score)}"
assert d.relevance_score <= last_score, f"Scores not in descending order: {d.relevance_score} > {last_score}"
last_score = d.relevance_score
def _validate_semantic_ranking(response: InferenceRerankResponse, items: list, expected_first_item: str) -> None:
"""
Validate that the expected most relevant item ranks first.
Args:
response: The InferenceRerankResponse to validate
items: The original items list that was ranked
expected_first_item: The expected first item in the ranking
Raises:
AssertionError: If any validation fails
"""
if not response:
raise AssertionError("No ranking data returned in response")
actual_first_index = response[0].index
actual_first_item = items[actual_first_index]
assert actual_first_item == expected_first_item, (
f"Expected '{expected_first_item}' to rank first, but '{actual_first_item}' ranked first instead."
)
@pytest.mark.parametrize(
"query,items",
[
(DUMMY_STRING, [DUMMY_STRING, DUMMY_STRING2]),
(DUMMY_TEXT, [DUMMY_TEXT, DUMMY_TEXT2]),
(DUMMY_STRING, [DUMMY_STRING2, DUMMY_TEXT]),
(DUMMY_TEXT, [DUMMY_STRING, DUMMY_TEXT2]),
],
ids=[
"string-query-string-items",
"text-query-text-items",
"mixed-content-1",
"mixed-content-2",
],
)
def test_rerank_text(client_with_models, rerank_model_id, query, items, inference_provider_type):
skip_if_provider_doesnt_support_rerank(inference_provider_type)
response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
assert isinstance(response, list)
# TODO: Add type validation for response items once InferenceRerankResponseItem is exported from llama stack client.
assert len(response) <= len(items)
_validate_rerank_response(response, items)
@pytest.mark.parametrize(
"query,items",
[
(DUMMY_IMAGE_URL, [DUMMY_STRING]),
(DUMMY_IMAGE_BASE64, [DUMMY_TEXT]),
(DUMMY_TEXT, [DUMMY_IMAGE_URL]),
(DUMMY_IMAGE_BASE64, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]),
(DUMMY_TEXT, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]),
],
ids=[
"image-query-url",
"image-query-base64",
"text-query-image-item",
"mixed-content-1",
"mixed-content-2",
],
)
def test_rerank_image(client_with_models, rerank_model_id, query, items, inference_provider_type):
skip_if_provider_doesnt_support_rerank(inference_provider_type)
if rerank_model_id not in PROVIDERS_SUPPORTING_MEDIA:
error_type = (
ValueError if isinstance(client_with_models, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
)
with pytest.raises(error_type):
client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
else:
response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
assert isinstance(response, list)
assert len(response) <= len(items)
_validate_rerank_response(response, items)
def test_rerank_max_results(client_with_models, rerank_model_id, inference_provider_type):
skip_if_provider_doesnt_support_rerank(inference_provider_type)
items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2]
max_num_results = 2
response = client_with_models.alpha.inference.rerank(
model=rerank_model_id,
query=DUMMY_STRING,
items=items,
max_num_results=max_num_results,
)
assert isinstance(response, list)
assert len(response) == max_num_results
_validate_rerank_response(response, items)
def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_id, inference_provider_type):
skip_if_provider_doesnt_support_rerank(inference_provider_type)
items = [DUMMY_STRING, DUMMY_STRING2]
response = client_with_models.alpha.inference.rerank(
model=rerank_model_id,
query=DUMMY_STRING,
items=items,
max_num_results=10, # Larger than items length
)
assert isinstance(response, list)
assert len(response) <= len(items) # Should return at most len(items)
@pytest.mark.parametrize(
"query,items,expected_first_item",
[
(
"What is a reranking model? ",
[
"A reranking model reranks a list of items based on the query. ",
"Machine learning algorithms learn patterns from data. ",
"Python is a programming language. ",
],
"A reranking model reranks a list of items based on the query. ",
),
(
"What is C++?",
[
"Learning new things is interesting. ",
"C++ is a programming language. ",
"Books provide knowledge and entertainment. ",
],
"C++ is a programming language. ",
),
(
"What are good learning habits? ",
[
"Cooking pasta is a fun activity. ",
"Plants need water and sunlight. ",
"Good learning habits include reading daily and taking notes. ",
],
"Good learning habits include reading daily and taking notes. ",
),
],
)
def test_rerank_semantic_correctness(
client_with_models, rerank_model_id, query, items, expected_first_item, inference_provider_type
):
skip_if_provider_doesnt_support_rerank(inference_provider_type)
response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
_validate_rerank_response(response, items)
_validate_semantic_ranking(response, items, expected_first_item)

View file

@ -0,0 +1,251 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, MagicMock, patch
import aiohttp
import pytest
from llama_stack.apis.models import ModelType
from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
class MockResponse:
def __init__(self, status=200, json_data=None, text_data="OK"):
self.status = status
self._json_data = json_data or {"rankings": []}
self._text_data = text_data
async def json(self):
return self._json_data
async def text(self):
return self._text_data
class MockSession:
def __init__(self, response):
self.response = response
self.post_calls = []
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
return False
def post(self, url, **kwargs):
self.post_calls.append((url, kwargs))
class PostContext:
def __init__(self, response):
self.response = response
async def __aenter__(self):
return self.response
async def __aexit__(self, exc_type, exc_val, exc_tb):
return False
return PostContext(self.response)
def create_adapter(config=None, rerank_endpoints=None):
if config is None:
config = NVIDIAConfig(api_key="test-key")
adapter = NVIDIAInferenceAdapter(config=config)
class MockModel:
provider_resource_id = "test-model"
metadata = {}
adapter.model_store = AsyncMock()
adapter.model_store.get_model = AsyncMock(return_value=MockModel())
if rerank_endpoints is not None:
adapter.config.rerank_model_to_url = rerank_endpoints
return adapter
async def test_rerank_basic_functionality():
adapter = create_adapter()
mock_response = MockResponse(json_data={"rankings": [{"index": 0, "logit": 0.5}]})
mock_session = MockSession(mock_response)
with patch("aiohttp.ClientSession", return_value=mock_session):
result = await adapter.rerank(model="test-model", query="test query", items=["item1", "item2"])
assert len(result.data) == 1
assert result.data[0].index == 0
assert result.data[0].relevance_score == 0.5
url, kwargs = mock_session.post_calls[0]
payload = kwargs["json"]
assert payload["model"] == "test-model"
assert payload["query"] == {"text": "test query"}
assert payload["passages"] == [{"text": "item1"}, {"text": "item2"}]
async def test_missing_rankings_key():
adapter = create_adapter()
mock_session = MockSession(MockResponse(json_data={}))
with patch("aiohttp.ClientSession", return_value=mock_session):
result = await adapter.rerank(model="test-model", query="q", items=["a"])
assert len(result.data) == 0
async def test_hosted_with_endpoint():
adapter = create_adapter(
config=NVIDIAConfig(api_key="key"), rerank_endpoints={"test-model": "https://model.endpoint/rerank"}
)
mock_session = MockSession(MockResponse())
with patch("aiohttp.ClientSession", return_value=mock_session):
await adapter.rerank(model="test-model", query="q", items=["a"])
url, _ = mock_session.post_calls[0]
assert url == "https://model.endpoint/rerank"
async def test_hosted_without_endpoint():
adapter = create_adapter(
config=NVIDIAConfig(api_key="key"), # This creates hosted config (integrate.api.nvidia.com).
rerank_endpoints={}, # No endpoint mapping for test-model
)
mock_session = MockSession(MockResponse())
with patch("aiohttp.ClientSession", return_value=mock_session):
await adapter.rerank(model="test-model", query="q", items=["a"])
url, _ = mock_session.post_calls[0]
assert "https://integrate.api.nvidia.com" in url
async def test_hosted_model_not_in_endpoint_mapping():
adapter = create_adapter(
config=NVIDIAConfig(api_key="key"), rerank_endpoints={"other-model": "https://other.endpoint/rerank"}
)
mock_session = MockSession(MockResponse())
with patch("aiohttp.ClientSession", return_value=mock_session):
await adapter.rerank(model="test-model", query="q", items=["a"])
url, _ = mock_session.post_calls[0]
assert "https://integrate.api.nvidia.com" in url
assert url != "https://other.endpoint/rerank"
async def test_self_hosted_ignores_endpoint():
adapter = create_adapter(
config=NVIDIAConfig(url="http://localhost:8000", api_key=None),
rerank_endpoints={"test-model": "https://model.endpoint/rerank"}, # This should be ignored for self-hosted.
)
mock_session = MockSession(MockResponse())
with patch("aiohttp.ClientSession", return_value=mock_session):
await adapter.rerank(model="test-model", query="q", items=["a"])
url, _ = mock_session.post_calls[0]
assert "http://localhost:8000" in url
assert "model.endpoint/rerank" not in url
async def test_max_num_results():
adapter = create_adapter()
rankings = [{"index": 0, "logit": 0.8}, {"index": 1, "logit": 0.6}]
mock_session = MockSession(MockResponse(json_data={"rankings": rankings}))
with patch("aiohttp.ClientSession", return_value=mock_session):
result = await adapter.rerank(model="test-model", query="q", items=["a", "b"], max_num_results=1)
assert len(result.data) == 1
assert result.data[0].index == 0
assert result.data[0].relevance_score == 0.8
async def test_http_error():
adapter = create_adapter()
mock_session = MockSession(MockResponse(status=500, text_data="Server Error"))
with patch("aiohttp.ClientSession", return_value=mock_session):
with pytest.raises(ConnectionError, match="status 500.*Server Error"):
await adapter.rerank(model="test-model", query="q", items=["a"])
async def test_client_error():
adapter = create_adapter()
mock_session = AsyncMock()
mock_session.__aenter__.side_effect = aiohttp.ClientError("Network error")
with patch("aiohttp.ClientSession", return_value=mock_session):
with pytest.raises(ConnectionError, match="Failed to connect.*Network error"):
await adapter.rerank(model="test-model", query="q", items=["a"])
async def test_list_models_includes_configured_rerank_models():
"""Test that list_models adds rerank models to the dynamic model list."""
adapter = create_adapter()
adapter.__provider_id__ = "nvidia"
adapter.__provider_spec__ = MagicMock()
dynamic_ids = ["llm-1", "embedding-1"]
with patch.object(OpenAIMixin, "list_provider_model_ids", new=AsyncMock(return_value=dynamic_ids)):
result = await adapter.list_models()
assert result is not None
# Check that the rerank models are added
model_ids = [m.identifier for m in result]
assert "nv-rerank-qa-mistral-4b:1" in model_ids
assert "nvidia/nv-rerankqa-mistral-4b-v3" in model_ids
assert "nvidia/llama-3.2-nv-rerankqa-1b-v2" in model_ids
rerank_models = [m for m in result if m.model_type == ModelType.rerank]
assert len(rerank_models) == 3
for m in rerank_models:
assert m.provider_id == "nvidia"
assert m.model_type == ModelType.rerank
assert m.metadata == {}
assert m.identifier in adapter._model_cache
async def test_list_provider_model_ids_has_no_duplicates():
adapter = create_adapter()
dynamic_ids = [
"llm-1",
"nvidia/nv-rerankqa-mistral-4b-v3", # overlaps configured rerank ids
"embedding-1",
"llm-1",
]
with patch.object(OpenAIMixin, "list_provider_model_ids", new=AsyncMock(return_value=dynamic_ids)):
ids = list(await adapter.list_provider_model_ids())
assert len(ids) == len(set(ids))
assert ids.count("nvidia/nv-rerankqa-mistral-4b-v3") == 1
assert "nv-rerank-qa-mistral-4b:1" in ids
assert "nvidia/llama-3.2-nv-rerankqa-1b-v2" in ids
async def test_list_provider_model_ids_uses_configured_on_dynamic_failure():
adapter = create_adapter()
# Simulate dynamic listing failure
with patch.object(OpenAIMixin, "list_provider_model_ids", new=AsyncMock(side_effect=Exception)):
ids = list(await adapter.list_provider_model_ids())
# Should still return configured rerank ids
configured_ids = list(adapter.config.rerank_model_to_url.keys())
assert set(ids) == set(configured_ids)