mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
feat: Add rerank API for NVIDIA Inference Provider (#3329)
# What does this PR do? Add rerank API for NVIDIA Inference Provider. <!-- If resolving an issue, uncomment and update the line below --> Closes #3278 ## Test Plan Unit test: ``` pytest tests/unit/providers/nvidia/test_rerank_inference.py ``` Integration test: ``` pytest -s -v tests/integration/inference/test_rerank.py --stack-config="inference=nvidia" --rerank-model=nvidia/nvidia/nv-rerankqa-mistral-4b-v3 --env NVIDIA_API_KEY="" --env NVIDIA_BASE_URL="https://integrate.api.nvidia.com" ```
This commit is contained in:
parent
c396de57a4
commit
fa7699d2c3
8 changed files with 622 additions and 1 deletions
|
|
@ -181,3 +181,22 @@ vlm_response = client.chat.completions.create(
|
|||
|
||||
print(f"VLM Response: {vlm_response.choices[0].message.content}")
|
||||
```
|
||||
|
||||
### Rerank Example
|
||||
|
||||
The following example shows how to rerank documents using an NVIDIA NIM.
|
||||
|
||||
```python
|
||||
# Ask the NVIDIA NIM to score each item against the query.
rerank_response = client.alpha.inference.rerank(
    model="nvidia/nvidia/llama-3.2-nv-rerankqa-1b-v2",
    query="query",
    items=[
        "item_1",
        "item_2",
        "item_3",
    ],
)

# Print each result with a 1-based rank, the index of the item in `items`,
# and its relevance score.
for rank, result in enumerate(rerank_response, start=1):
    print(f"{rank}. [Index: {result.index}, Score: {result.relevance_score:.3f}]")
|
||||
```
|
||||
|
|
@ -28,6 +28,7 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
|
|||
Attributes:
|
||||
url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000
|
||||
api_key (str): The access key for the hosted NIM endpoints
|
||||
rerank_model_to_url (dict[str, str]): Mapping of rerank model identifiers to their API endpoints
|
||||
|
||||
There are two ways to access NVIDIA NIMs -
|
||||
0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com
|
||||
|
|
@ -55,6 +56,14 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
|
|||
default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
|
||||
description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
|
||||
)
|
||||
rerank_model_to_url: dict[str, str] = Field(
|
||||
default_factory=lambda: {
|
||||
"nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking",
|
||||
"nvidia/nv-rerankqa-mistral-4b-v3": "https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking",
|
||||
"nvidia/llama-3.2-nv-rerankqa-1b-v2": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking",
|
||||
},
|
||||
description="Mapping of rerank model identifiers to their API endpoints. ",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
|
|
|
|||
|
|
@ -5,6 +5,19 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
import aiohttp
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
RerankData,
|
||||
RerankResponse,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
|
@ -61,3 +74,101 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
|
|||
:return: The NVIDIA API base URL
|
||||
"""
|
||||
return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
|
||||
|
||||
async def list_provider_model_ids(self) -> Iterable[str]:
    """
    Return both dynamic model IDs and statically configured rerank model IDs.

    The dynamic listing is best-effort: if it fails, the failure is logged and
    only the rerank model IDs from ``config.rerank_model_to_url`` are returned.

    :return: Model IDs with duplicates removed; dynamically listed IDs come first
    """
    try:
        # Materialize inside the try so a lazily failing iterable is handled too.
        dynamic_ids = list(await super().list_provider_model_ids())
    except Exception as e:
        # Deliberately best-effort, but record why the dynamic listing is missing
        # instead of hiding the failure.
        logger.warning(f"Failed to list NVIDIA models dynamically, using configured rerank models only: {e}")
        dynamic_ids = []

    # dict.fromkeys removes duplicates while preserving first-seen order.
    return list(dict.fromkeys([*dynamic_ids, *self.config.rerank_model_to_url]))
|
||||
|
||||
def construct_model_from_identifier(self, identifier: str) -> Model:
    """
    Build the Model entry for a provider model identifier.

    Identifiers found in ``config.rerank_model_to_url`` are returned as rerank
    models; every other identifier is delegated to the parent implementation.

    :param identifier: Provider model identifier to classify
    :return: The Model describing the identifier
    """
    # Anything not configured as a rerank model keeps the base behavior.
    if identifier not in self.config.rerank_model_to_url:
        return super().construct_model_from_identifier(identifier)

    rerank_model = Model(
        provider_id=self.__provider_id__,  # type: ignore[attr-defined]
        provider_resource_id=identifier,
        identifier=identifier,
        model_type=ModelType.rerank,
    )
    return rerank_model
|
||||
|
||||
async def rerank(
    self,
    model: str,
    query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
    items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
    max_num_results: int | None = None,
) -> RerankResponse:
    """
    Rerank text items against a query using an NVIDIA reranking endpoint.

    :param model: Identifier of the rerank model; resolved to the provider model ID
    :param query: Query as a string or text content part (image parts are rejected)
    :param items: Candidates to score, each a string or text content part
    :param max_num_results: Optional cap on the number of results returned
    :return: RerankResponse whose entries are ordered by descending relevance score
    :raises ValueError: If the query or an item is not text, or max_num_results is negative
    :raises ConnectionError: If the request fails or returns a non-200 status
    """

    def _as_text(content: object, error_message: str) -> str:
        # Only plain strings and text content parts can be sent to the reranker.
        if isinstance(content, str):
            return content
        if isinstance(content, OpenAIChatCompletionContentPartTextParam):
            return content.text
        raise ValueError(error_message)

    provider_model_id = await self._get_provider_model_id(model)

    # Use the per-model endpoint from the config only when talking to the hosted
    # service and the model has a mapping; otherwise fall back to the base URL.
    ranking_url = self.get_base_url()
    if _is_nvidia_hosted(self.config) and provider_model_id in self.config.rerank_model_to_url:
        ranking_url = self.config.rerank_model_to_url[provider_model_id]

    logger.debug(f"Using rerank endpoint: {ranking_url} for model: {provider_model_id}")

    # A negative limit would silently slice from the end of the result list
    # (e.g. -1 drops the last entry), so reject it before calling the service.
    if max_num_results is not None and max_num_results < 0:
        raise ValueError("max_num_results must be non-negative")

    query_text = _as_text(query, "Query must be a string or text content part")
    passages = [{"text": _as_text(item, "Items must be strings or text content parts")} for item in items]

    payload = {
        "model": provider_model_id,
        "query": {"text": query_text},
        "passages": passages,
    }

    headers = {
        "Authorization": f"Bearer {self.get_api_key()}",
        "Content-Type": "application/json",
    }

    # Keep the try body limited to the HTTP exchange so only transport errors
    # are translated into ConnectionError.
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(ranking_url, headers=headers, json=payload) as response:
                if response.status != 200:
                    response_text = await response.text()
                    raise ConnectionError(
                        f"NVIDIA rerank API request failed with status {response.status}: {response_text}"
                    )

                result = await response.json()
    except aiohttp.ClientError as e:
        raise ConnectionError(f"Failed to connect to NVIDIA rerank API at {ranking_url}: {e}") from e

    # Convert to RerankData format; a missing "rankings" key yields no results.
    rerank_data = [
        RerankData(index=ranking["index"], relevance_score=ranking["logit"])
        for ranking in result.get("rankings", [])
    ]

    # Truncation keeps the best matches only if the list is ordered best-first,
    # so order explicitly instead of relying on the service's ordering. The sort
    # is stable, so an already-sorted response is left unchanged.
    rerank_data.sort(key=lambda entry: entry.relevance_score, reverse=True)

    # Apply max_num_results limit
    if max_num_results is not None:
        rerank_data = rerank_data[:max_num_results]

    return RerankResponse(data=rerank_data)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue