Add rerank API for NVIDIA Inference Provider

Jiayi 2025-09-03 17:34:05 -07:00
parent ce77c27ff8
commit bab9d7aaea
9 changed files with 9213 additions and 1 deletion


@ -4,6 +4,7 @@ description: "Llama Stack Inference API for generating completions, chat complet
This API provides the raw interface to the underlying models. Two kinds of models are supported:
- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic search."
- Rerank models: these models rerank the documents by relevance."
sidebar_label: Inference
title: Inference
---
@ -17,5 +18,6 @@ Llama Stack Inference API for generating completions, chat completions, and embe
This API provides the raw interface to the underlying models. Two kinds of models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic search.
- Rerank models: these models rerank the documents by relevance.
This section contains documentation for all available providers for the **inference** API.

File diff suppressed because it is too large

File diff suppressed because it is too large

example.py (new file, 257 lines)

@ -0,0 +1,257 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
os.environ["NVIDIA_API_KEY"] = "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"
# Option 1: Use default NIM URL (will auto-switch to ai.api.nvidia.com for rerank)
# os.environ["NVIDIA_BASE_URL"] = "https://ai.api.nvidia.com"
# Option 2: Use AI Foundation URL directly for rerank models
# os.environ["NVIDIA_BASE_URL"] = "https://ai.api.nvidia.com/v1"
os.environ["NVIDIA_BASE_URL"] = "https://integrate.api.nvidia.com"
import base64
import io
from PIL import Image
from llama_stack.core.library_client import LlamaStackAsLibraryClient
client = LlamaStackAsLibraryClient("nvidia")
client.initialize()
# # response = client.inference.completion(
# # model_id="meta/llama-3.1-8b-instruct",
# # content="Complete the sentence using one word: Roses are red, violets are :",
# # stream=False,
# # sampling_params={
# # "max_tokens": 50,
# # },
# # )
# # print(f"Response: {response.content}")
# response = client.inference.chat_completion(
# model_id="nvidia/nvidia-nemotron-nano-9b-v2",
# messages=[
# {
# "role": "system",
# "content": "/think",
# },
# {
# "role": "user",
# "content": "How are you?",
# },
# ],
# stream=False,
# sampling_params={
# "max_tokens": 1024,
# },
# )
# print(f"Response: {response}")
print(client.models.list())

rerank_response = client.inference.rerank(
    model="nvidia/llama-3.2-nv-rerankqa-1b-v2",
    query="query",
    items=[
        "item_1",
        "item_2",
        "item_3",
    ],
)
print(rerank_response)

# Iterate the RerankData entries on the response's `data` field.
for i, result in enumerate(rerank_response.data):
    print(f"{i+1}. [Index: {result.index}, "
          f"Score: {result.relevance_score:.3f}]")
# # from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
# # tool_definition = ToolDefinition(
# # tool_name="get_weather",
# # description="Get current weather information for a location",
# # parameters={
# # "location": ToolParamDefinition(
# # param_type="string",
# # description="The city and state, e.g. San Francisco, CA",
# # required=True
# # ),
# # "unit": ToolParamDefinition(
# # param_type="string",
# # description="Temperature unit (celsius or fahrenheit)",
# # required=False,
# # default="celsius"
# # )
# # }
# # )
# # # tool_response = client.inference.chat_completion(
# # # model_id="meta-llama/Llama-3.1-8B-Instruct",
# # # messages=[
# # # {"role": "user", "content": "What's the weather like in San Francisco?"}
# # # ],
# # # tools=[tool_definition],
# # # )
# # # print(f"Tool Response: {tool_response.completion_message.content}")
# # # if tool_response.completion_message.tool_calls:
# # # for tool_call in tool_response.completion_message.tool_calls:
# # # print(f"Tool Called: {tool_call.tool_name}")
# # # print(f"Arguments: {tool_call.arguments}")
# # # from llama_stack.apis.inference import JsonSchemaResponseFormat, ResponseFormatType
# # # person_schema = {
# # # "type": "object",
# # # "properties": {
# # # "name": {"type": "string"},
# # # "age": {"type": "integer"},
# # # "occupation": {"type": "string"},
# # # },
# # # "required": ["name", "age", "occupation"]
# # # }
# # # response_format = JsonSchemaResponseFormat(
# # # type=ResponseFormatType.json_schema,
# # # json_schema=person_schema
# # # )
# # # structured_response = client.inference.chat_completion(
# # # model_id="meta-llama/Llama-3.1-8B-Instruct",
# # # messages=[
# # # {
# # # "role": "user",
# # # "content": "Create a profile for a fictional person named Alice who is 30 years old and is a software engineer. "
# # # }
# # # ],
# # # response_format=response_format,
# # # )
# # # print(f"Structured Response: {structured_response.completion_message.content}")
# # # print("\n" + "="*50)
# # # print("VISION LANGUAGE MODEL (VLM) EXAMPLE")
# # # print("="*50)
# # def load_image_as_base64(image_path):
# # with open(image_path, "rb") as image_file:
# # img_bytes = image_file.read()
# # return base64.b64encode(img_bytes).decode("utf-8")
# # image_path = "/home/jiayin/llama-stack/docs/dog.jpg"
# # demo_image_b64 = load_image_as_base64(image_path)
# # vlm_response = client.inference.chat_completion(
# # model_id="nvidia/vila",
# # messages=[
# # {
# # "role": "user",
# # "content": [
# # {
# # "type": "image",
# # "image": {
# # "data": demo_image_b64,
# # },
# # },
# # {
# # "type": "text",
# # "text": "Please describe what you see in this image in detail.",
# # },
# # ],
# # }
# # ],
# # )
# # print(f"VLM Response: {vlm_response.completion_message.content}")
# # # print("\n" + "="*50)
# # # print("EMBEDDING EXAMPLE")
# # # print("="*50)
# # # # Embedding example
# # # embedding_response = client.inference.embeddings(
# # # model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",
# # # contents=["Hello world", "How are you today?"],
# # # task_type="query"
# # # )
# # # print(f"Number of embeddings: {len(embedding_response.embeddings)}")
# # # print(f"Embedding dimension: {len(embedding_response.embeddings[0])}")
# # # print(f"First few values: {embedding_response.embeddings[0][:5]}")
# # # # from openai import OpenAI
# # # # client = OpenAI(
# # # # base_url = "http://10.176.230.61:8000/v1",
# # # # api_key = "nvapi-..."  # redacted
# # # # )
# # # # # completion = client.completions.create(
# # # # # model="meta/llama-3.1-405b-instruct",
# # # # # prompt="How are you?",
# # # # # temperature=0.2,
# # # # # top_p=0.7,
# # # # # max_tokens=1024,
# # # # # stream=False
# # # # # )
# # # # # # completion = client.chat.completions.create(
# # # # # # model="meta/llama-3.1-8b-instruct",
# # # # # # messages=[{"role":"user","content":"hi"}],
# # # # # # temperature=0.2,
# # # # # # top_p=0.7,
# # # # # # max_tokens=1024,
# # # # # # stream=True
# # # # # # )
# # # # # for chunk in completion:
# # # # # if chunk.choices[0].delta.content is not None:
# # # # # print(chunk.choices[0].delta.content, end="")
# # # # # response = client.inference.completion(
# # # # # model_id="meta/llama-3.1-8b-instruct",
# # # # # content="Complete the sentence using one word: Roses are red, violets are :",
# # # # # stream=False,
# # # # # sampling_params={
# # # # # "max_tokens": 50,
# # # # # },
# # # # # )
# # # # # print(f"Response: {response.content}")
# from openai import OpenAI
# client = OpenAI(
# base_url = "https://integrate.api.nvidia.com/v1",
# api_key = "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"
# )
# completion = client.chat.completions.create(
# model="nvidia/nvidia-nemotron-nano-9b-v2",
# messages=[{"role":"system","content":"/think"}],
# temperature=0.6,
# top_p=0.95,
# max_tokens=2048,
# frequency_penalty=0,
# presence_penalty=0,
# stream=True,
# extra_body={
# "min_thinking_tokens": 1024,
# "max_thinking_tokens": 2048
# }
# )
# for chunk in completion:
# reasoning = getattr(chunk.choices[0].delta, "reasoning_content", None)
# if reasoning:
# print(reasoning, end="")
# if chunk.choices[0].delta.content is not None:
# print(chunk.choices[0].delta.content, end="")


@ -1016,7 +1016,7 @@ class InferenceProvider(Protocol):
    ) -> RerankResponse:
        """Rerank a list of documents based on their relevance to a query.

        :param model: The identifier of the reranking model to use. The model must be a reranking model registered with Llama Stack and available via the /models endpoint.
        :param query: The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length.
        :param items: List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length.
        :param max_num_results: (Optional) Maximum number of results to return. Default: returns all.
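For illustration only (not part of the commit): a minimal sketch of calling the rerank API with the optional max_num_results limit, reusing the library-client setup from example.py above. The model id is one of the rerank entries registered later in this commit; the query and item strings are made-up placeholders.

from llama_stack.core.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient("nvidia")  # same setup as example.py above
client.initialize()

response = client.inference.rerank(
    model="nvidia/llama-3.2-nv-rerankqa-1b-v2",
    query="Which passage talks about rain?",
    items=[
        "The stock market closed higher today.",
        "It rained heavily in San Francisco this morning.",
        "The release adds a rerank endpoint.",
    ],
    max_num_results=2,  # optional; when omitted, all items are returned ranked
)
for result in response.data:  # RerankData: index into `items` plus relevance_score
    print(result.index, result.relevance_score)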


@ -27,10 +27,12 @@ class ModelType(StrEnum):
"""Enumeration of supported model types in Llama Stack. """Enumeration of supported model types in Llama Stack.
:cvar llm: Large language model for text generation and completion :cvar llm: Large language model for text generation and completion
:cvar embedding: Embedding model for converting text to vector representations :cvar embedding: Embedding model for converting text to vector representations
:cvar rerank: Reranking model for reordering documents by relevance
""" """
llm = "llm" llm = "llm"
embedding = "embedding" embedding = "embedding"
rerank = "rerank"
@json_schema_type @json_schema_type
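As a quick illustration of the new enum value (a sketch only, reusing the example.py setup above and assuming the Model objects returned by models.list() expose identifier and model_type, and that the NVIDIA environment variables are set as in example.py):

from llama_stack.apis.models import ModelType
from llama_stack.core.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient("nvidia")
client.initialize()

# Filter the full model listing down to models registered with the new rerank type.
rerank_models = [m for m in client.models.list() if m.model_type == ModelType.rerank]
for m in rerank_models:
    print(m.identifier)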


@ -41,9 +41,14 @@ from llama_stack.apis.inference import (
    OpenAIMessageParam,
    OpenAIResponseFormatParam,
    Order,
    RerankResponse,
    StopReason,
    ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
    OpenAIChatCompletionContentPartImageParam,
    OpenAIChatCompletionContentPartTextParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
from llama_stack.log import get_logger
@ -179,6 +184,25 @@ class InferenceRouter(Inference):
            raise ModelTypeError(model_id, model.model_type, expected_model_type)
        return model
    async def rerank(
        self,
        model: str,
        query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
        items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
        max_num_results: int | None = None,
    ) -> RerankResponse:
        """Route rerank requests to the appropriate provider based on the model."""
        logger.debug(f"InferenceRouter.rerank: {model}")
        model_obj = await self._get_model(model, ModelType.rerank)
        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
        return await provider.rerank(
            model=model_obj.identifier,
            query=query,
            items=items,
            max_num_results=max_num_results,
        )

    async def openai_completion(
        self,
        model: str,


@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models import ModelType
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = []
# https://docs.nvidia.com/nim/large-language-models/latest/supported-llm-agnostic-architectures.html
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta/llama3-8b-instruct",
CoreModelId.llama3_8b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama3-70b-instruct",
CoreModelId.llama3_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.1-8b-instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.2-1b-instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.2-3b-instruct",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.2-11b-vision-instruct",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.2-90b-vision-instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta/llama-3.3-70b-instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
ProviderModelEntry(
provider_model_id="nvidia/vila",
model_type=ModelType.llm,
),
# NeMo Retriever Text Embedding models -
#
# https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
#
# +-----------------------------------+--------+-----------+-----------+------------+
# | Model ID | Max | Publisher | Embedding | Dynamic |
# | | Tokens | | Dimension | Embeddings |
# +-----------------------------------+--------+-----------+-----------+------------+
# | nvidia/llama-3.2-nv-embedqa-1b-v2 | 8192 | NVIDIA | 2048 | Yes |
# | nvidia/nv-embedqa-e5-v5 | 512 | NVIDIA | 1024 | No |
# | nvidia/nv-embedqa-mistral-7b-v2 | 512 | NVIDIA | 4096 | No |
# | snowflake/arctic-embed-l | 512 | Snowflake | 1024 | No |
# +-----------------------------------+--------+-----------+-----------+------------+
ProviderModelEntry(
provider_model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 2048,
"context_length": 8192,
},
),
ProviderModelEntry(
provider_model_id="nvidia/nv-embedqa-e5-v5",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 1024,
"context_length": 512,
},
),
ProviderModelEntry(
provider_model_id="nvidia/nv-embedqa-mistral-7b-v2",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 4096,
"context_length": 512,
},
),
ProviderModelEntry(
provider_model_id="snowflake/arctic-embed-l",
model_type=ModelType.embedding,
metadata={
"embedding_dimension": 1024,
"context_length": 512,
},
),
# NVIDIA Reranking models
ProviderModelEntry(
provider_model_id="nv-rerank-qa-mistral-4b:1",
model_type=ModelType.rerank,
metadata={
"endpoint": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking",
},
),
ProviderModelEntry(
provider_model_id="nvidia/nv-rerankqa-mistral-4b-v3",
model_type=ModelType.rerank,
metadata={
"endpoint": "https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking",
},
),
ProviderModelEntry(
provider_model_id="nvidia/llama-3.2-nv-rerankqa-1b-v2",
model_type=ModelType.rerank,
metadata={
"endpoint": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking",
},
),
# TODO(mf): how do we handle Nemotron models?
# "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
] + SAFETY_MODELS_ENTRIES
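Should another hosted rerank NIM need to be wired up later, an additional entry would follow the same pattern; the sketch below is illustrative only, and both the model id and the endpoint are placeholders rather than real NVIDIA values:

from llama_stack.apis.models import ModelType
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry

# Hypothetical extra rerank entry (placeholder id and endpoint) that would be
# appended to MODEL_ENTRIES; the "endpoint" metadata is the per-model override
# that the adapter's rerank() implementation (below) prefers over the base URL.
extra_rerank_entry = ProviderModelEntry(
    provider_model_id="nvidia/example-rerank-model",  # placeholder
    model_type=ModelType.rerank,
    metadata={
        "endpoint": "https://ai.api.nvidia.com/v1/retrieval/nvidia/example/reranking",  # placeholder
    },
)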


@ -12,6 +12,12 @@ from llama_stack.apis.inference import (
    OpenAIEmbeddingData,
    OpenAIEmbeddingsResponse,
    OpenAIEmbeddingUsage,
    RerankData,
    RerankResponse,
)
from llama_stack.apis.inference.inference import (
    OpenAIChatCompletionContentPartImageParam,
    OpenAIChatCompletionContentPartTextParam,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -80,6 +86,80 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
""" """
return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
async def rerank(
self,
model: str,
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
max_num_results: int | None = None,
) -> RerankResponse:
provider_model_id = await self._get_provider_model_id(model)
ranking_url = self.get_base_url()
model_obj = await self.model_store.get_model(model)
if _is_nvidia_hosted(self._config) and "endpoint" in model_obj.metadata:
ranking_url = model_obj.metadata["endpoint"]
logger.debug(f"Using rerank endpoint: {ranking_url} for model: {provider_model_id}")
# Convert query to text format
if isinstance(query, str):
query_text = query
elif hasattr(query, "text"):
query_text = query.text
else:
raise ValueError("Query must be a string or text content part")
# Convert items to text format
passages = []
for item in items:
if isinstance(item, str):
passages.append({"text": item})
elif hasattr(item, "text"):
passages.append({"text": item.text})
else:
raise ValueError("Items must be strings or text content parts")
payload = {
"model": provider_model_id,
"query": {"text": query_text},
"passages": passages,
}
headers = {
"Authorization": f"Bearer {self.get_api_key()}",
"Content-Type": "application/json",
}
import aiohttp
try:
async with aiohttp.ClientSession() as session:
async with session.post(ranking_url, headers=headers, json=payload) as response:
if response.status != 200:
response_text = await response.text()
raise ConnectionError(
f"NVIDIA rerank API request failed with status {response.status}: {response_text}"
)
result = await response.json()
rankings = result.get("rankings", [])
# Convert to RerankData format
rerank_data = []
for ranking in rankings:
rerank_data.append(RerankData(index=ranking["index"], relevance_score=ranking["logit"]))
# Apply max_num_results limit if specified
if max_num_results is not None:
rerank_data = rerank_data[:max_num_results]
return RerankResponse(data=rerank_data)
except aiohttp.ClientError as e:
raise ConnectionError(f"Failed to connect to NVIDIA rerank API at {ranking_url}: {e}") from e
async def openai_embeddings( async def openai_embeddings(
self, self,
model: str, model: str,
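For reference, the request and response bodies implied by the payload construction and parsing in rerank() above look roughly like the following; this is a sketch of the shapes only, and the concrete values are illustrative.

# Request body built by rerank() and POSTed to the ranking endpoint.
request_body = {
    "model": "nvidia/llama-3.2-nv-rerankqa-1b-v2",
    "query": {"text": "Which passage talks about rain?"},
    "passages": [
        {"text": "The stock market closed higher today."},
        {"text": "It rained heavily in San Francisco this morning."},
    ],
}

# Response body the adapter expects: a "rankings" list whose entries carry the
# original passage index and a raw "logit" score, which rerank() maps to
# RerankData(index=..., relevance_score=...).
response_body = {
    "rankings": [
        {"index": 1, "logit": 7.3},
        {"index": 0, "logit": -2.1},
    ],
}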