Add rerank API for NVIDIA Inference Provider

This commit is contained in:
Jiayi 2025-09-03 17:34:05 -07:00
parent ce77c27ff8
commit bab9d7aaea
9 changed files with 9213 additions and 1 deletions

View file

@ -12,6 +12,12 @@ from llama_stack.apis.inference import (
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
RerankData,
RerankResponse,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -80,6 +86,80 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
"""
return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
async def rerank(
self,
model: str,
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
max_num_results: int | None = None,
) -> RerankResponse:
provider_model_id = await self._get_provider_model_id(model)
ranking_url = self.get_base_url()
model_obj = await self.model_store.get_model(model)
if _is_nvidia_hosted(self._config) and "endpoint" in model_obj.metadata:
ranking_url = model_obj.metadata["endpoint"]
logger.debug(f"Using rerank endpoint: {ranking_url} for model: {provider_model_id}")
# Convert query to text format
if isinstance(query, str):
query_text = query
elif hasattr(query, "text"):
query_text = query.text
else:
raise ValueError("Query must be a string or text content part")
# Convert items to text format
passages = []
for item in items:
if isinstance(item, str):
passages.append({"text": item})
elif hasattr(item, "text"):
passages.append({"text": item.text})
else:
raise ValueError("Items must be strings or text content parts")
payload = {
"model": provider_model_id,
"query": {"text": query_text},
"passages": passages,
}
headers = {
"Authorization": f"Bearer {self.get_api_key()}",
"Content-Type": "application/json",
}
import aiohttp
try:
async with aiohttp.ClientSession() as session:
async with session.post(ranking_url, headers=headers, json=payload) as response:
if response.status != 200:
response_text = await response.text()
raise ConnectionError(
f"NVIDIA rerank API request failed with status {response.status}: {response_text}"
)
result = await response.json()
rankings = result.get("rankings", [])
# Convert to RerankData format
rerank_data = []
for ranking in rankings:
rerank_data.append(RerankData(index=ranking["index"], relevance_score=ranking["logit"]))
# Apply max_num_results limit if specified
if max_num_results is not None:
rerank_data = rerank_data[:max_num_results]
return RerankResponse(data=rerank_data)
except aiohttp.ClientError as e:
raise ConnectionError(f"Failed to connect to NVIDIA rerank API at {ranking_url}: {e}") from e
async def openai_embeddings(
self,
model: str,