feat: Add rerank API for NVIDIA Inference Provider (#3329)

# What does this PR do?
Add rerank API for NVIDIA Inference Provider.

<!-- If resolving an issue, uncomment and update the line below -->
Closes #3278 

## Test Plan
Unit test:
```
pytest tests/unit/providers/nvidia/test_rerank_inference.py
```

Integration test: 
```
pytest -s -v tests/integration/inference/test_rerank.py   --stack-config="inference=nvidia"   --rerank-model=nvidia/nvidia/nv-rerankqa-mistral-4b-v3   --env NVIDIA_API_KEY=""   --env NVIDIA_BASE_URL="https://integrate.api.nvidia.com"
```
This commit is contained in:
Jiayi Ni 2025-10-30 21:42:09 -07:00 committed by GitHub
parent c396de57a4
commit fa7699d2c3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 622 additions and 1 deletions

View file

@ -5,6 +5,19 @@
# the root directory of this source tree.
from collections.abc import Iterable
import aiohttp
from llama_stack.apis.inference import (
RerankData,
RerankResponse,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -61,3 +74,101 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
:return: The NVIDIA API base URL
"""
return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
async def list_provider_model_ids(self) -> Iterable[str]:
"""
Return both dynamic model IDs and statically configured rerank model IDs.
"""
dynamic_ids: Iterable[str] = []
try:
dynamic_ids = await super().list_provider_model_ids()
except Exception:
# If the dynamic listing fails, proceed with just configured rerank IDs
dynamic_ids = []
configured_rerank_ids = list(self.config.rerank_model_to_url.keys())
return list(dict.fromkeys(list(dynamic_ids) + configured_rerank_ids)) # remove duplicates
def construct_model_from_identifier(self, identifier: str) -> Model:
"""
Classify rerank models from config; otherwise use the base behavior.
"""
if identifier in self.config.rerank_model_to_url:
return Model(
provider_id=self.__provider_id__, # type: ignore[attr-defined]
provider_resource_id=identifier,
identifier=identifier,
model_type=ModelType.rerank,
)
return super().construct_model_from_identifier(identifier)
async def rerank(
self,
model: str,
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
max_num_results: int | None = None,
) -> RerankResponse:
provider_model_id = await self._get_provider_model_id(model)
ranking_url = self.get_base_url()
if _is_nvidia_hosted(self.config) and provider_model_id in self.config.rerank_model_to_url:
ranking_url = self.config.rerank_model_to_url[provider_model_id]
logger.debug(f"Using rerank endpoint: {ranking_url} for model: {provider_model_id}")
# Convert query to text format
if isinstance(query, str):
query_text = query
elif isinstance(query, OpenAIChatCompletionContentPartTextParam):
query_text = query.text
else:
raise ValueError("Query must be a string or text content part")
# Convert items to text format
passages = []
for item in items:
if isinstance(item, str):
passages.append({"text": item})
elif isinstance(item, OpenAIChatCompletionContentPartTextParam):
passages.append({"text": item.text})
else:
raise ValueError("Items must be strings or text content parts")
payload = {
"model": provider_model_id,
"query": {"text": query_text},
"passages": passages,
}
headers = {
"Authorization": f"Bearer {self.get_api_key()}",
"Content-Type": "application/json",
}
try:
async with aiohttp.ClientSession() as session:
async with session.post(ranking_url, headers=headers, json=payload) as response:
if response.status != 200:
response_text = await response.text()
raise ConnectionError(
f"NVIDIA rerank API request failed with status {response.status}: {response_text}"
)
result = await response.json()
rankings = result.get("rankings", [])
# Convert to RerankData format
rerank_data = []
for ranking in rankings:
rerank_data.append(RerankData(index=ranking["index"], relevance_score=ranking["logit"]))
# Apply max_num_results limit
if max_num_results is not None:
rerank_data = rerank_data[:max_num_results]
return RerankResponse(data=rerank_data)
except aiohttp.ClientError as e:
raise ConnectionError(f"Failed to connect to NVIDIA rerank API at {ranking_url}: {e}") from e