mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00

Add rerank API for NVIDIA Inference Provider

This commit is contained in:
parent ce77c27ff8
commit bab9d7aaea

9 changed files with 9213 additions and 1 deletion
@@ -4,6 +4,7 @@ description: "Llama Stack Inference API for generating completions, chat complet
 This API provides the raw interface to the underlying models. Two kinds of models are supported:
 - LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.
 - Embedding models: these models generate embeddings to be used for semantic search."
+- Rerank models: these models rerank the documents by relevance."
 sidebar_label: Inference
 title: Inference
 ---
@@ -17,5 +18,6 @@ Llama Stack Inference API for generating completions, chat completions, and embe
 This API provides the raw interface to the underlying models. Two kinds of models are supported:
 - LLM models: these models generate "raw" and "chat" (conversational) completions.
 - Embedding models: these models generate embeddings to be used for semantic search.
+- Rerank models: these models rerank the documents by relevance.

 This section contains documentation for all available providers for the **inference** API.
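For context, a minimal sketch of how the new rerank model type is exercised end to end, mirroring the example.py added later in this commit (the query and item strings are placeholders, not part of the diff):

# Sketch: rerank a few documents against a query via the NVIDIA distribution.
from llama_stack.core.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient("nvidia")
client.initialize()

rerank_response = client.inference.rerank(
    model="nvidia/llama-3.2-nv-rerankqa-1b-v2",
    query="Which passage talks about reranking?",
    items=["passage one", "passage two", "passage three"],
)
for result in rerank_response:
    # Each result carries the original item index and a relevance score.
    print(result.index, result.relevance_score)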
docs/static/llama-stack-spec.html (vendored, 4992 lines changed): file diff suppressed because it is too large.
docs/static/llama-stack-spec.yaml (vendored, 3724 lines changed): file diff suppressed because it is too large.
example.py (new file, 257 lines)
@@ -0,0 +1,257 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os

os.environ["NVIDIA_API_KEY"] = "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"
# Option 1: Use default NIM URL (will auto-switch to ai.api.nvidia.com for rerank)
# os.environ["NVIDIA_BASE_URL"] = "https://ai.api.nvidia.com"
# Option 2: Use AI Foundation URL directly for rerank models
# os.environ["NVIDIA_BASE_URL"] = "https://ai.api.nvidia.com/v1"
os.environ["NVIDIA_BASE_URL"] = "https://integrate.api.nvidia.com"

import base64
import io
from PIL import Image

from llama_stack.core.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient("nvidia")
client.initialize()

# # response = client.inference.completion(
# #     model_id="meta/llama-3.1-8b-instruct",
# #     content="Complete the sentence using one word: Roses are red, violets are :",
# #     stream=False,
# #     sampling_params={
# #         "max_tokens": 50,
# #     },
# # )
# # print(f"Response: {response.content}")


# response = client.inference.chat_completion(
#     model_id="nvidia/nvidia-nemotron-nano-9b-v2",
#     messages=[
#         {
#             "role": "system",
#             "content": "/think",
#         },
#         {
#             "role": "user",
#             "content": "How are you?",
#         },
#     ],
#     stream=False,
#     sampling_params={
#         "max_tokens": 1024,
#     },
# )
# print(f"Response: {response}")


print(client.models.list())
rerank_response = client.inference.rerank(
    model="nvidia/llama-3.2-nv-rerankqa-1b-v2",
    query="query",
    items=[
        "item_1",
        "item_2",
        "item_3",
    ]
)

print(rerank_response)
for i, result in enumerate(rerank_response):
    print(f"{i+1}. [Index: {result.index}, "
          f"Score: {(result.relevance_score):.3f}]")

# # from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition

# # tool_definition = ToolDefinition(
# #     tool_name="get_weather",
# #     description="Get current weather information for a location",
# #     parameters={
# #         "location": ToolParamDefinition(
# #             param_type="string",
# #             description="The city and state, e.g. San Francisco, CA",
# #             required=True
# #         ),
# #         "unit": ToolParamDefinition(
# #             param_type="string",
# #             description="Temperature unit (celsius or fahrenheit)",
# #             required=False,
# #             default="celsius"
# #         )
# #     }
# # )

# # # tool_response = client.inference.chat_completion(
# # #     model_id="meta-llama/Llama-3.1-8B-Instruct",
# # #     messages=[
# # #         {"role": "user", "content": "What's the weather like in San Francisco?"}
# # #     ],
# # #     tools=[tool_definition],
# # # )

# # # print(f"Tool Response: {tool_response.completion_message.content}")
# # # if tool_response.completion_message.tool_calls:
# # #     for tool_call in tool_response.completion_message.tool_calls:
# # #         print(f"Tool Called: {tool_call.tool_name}")
# # #         print(f"Arguments: {tool_call.arguments}")


# # # from llama_stack.apis.inference import JsonSchemaResponseFormat, ResponseFormatType

# # # person_schema = {
# # #     "type": "object",
# # #     "properties": {
# # #         "name": {"type": "string"},
# # #         "age": {"type": "integer"},
# # #         "occupation": {"type": "string"},
# # #     },
# # #     "required": ["name", "age", "occupation"]
# # # }

# # # response_format = JsonSchemaResponseFormat(
# # #     type=ResponseFormatType.json_schema,
# # #     json_schema=person_schema
# # # )

# # # structured_response = client.inference.chat_completion(
# # #     model_id="meta-llama/Llama-3.1-8B-Instruct",
# # #     messages=[
# # #         {
# # #             "role": "user",
# # #             "content": "Create a profile for a fictional person named Alice who is 30 years old and is a software engineer. "
# # #         }
# # #     ],
# # #     response_format=response_format,
# # # )

# # # print(f"Structured Response: {structured_response.completion_message.content}")

# # # print("\n" + "="*50)
# # # print("VISION LANGUAGE MODEL (VLM) EXAMPLE")
# # # print("="*50)

# # def load_image_as_base64(image_path):
# #     with open(image_path, "rb") as image_file:
# #         img_bytes = image_file.read()
# #         return base64.b64encode(img_bytes).decode("utf-8")

# # image_path = "/home/jiayin/llama-stack/docs/dog.jpg"
# # demo_image_b64 = load_image_as_base64(image_path)

# # vlm_response = client.inference.chat_completion(
# #     model_id="nvidia/vila",
# #     messages=[
# #         {
# #             "role": "user",
# #             "content": [
# #                 {
# #                     "type": "image",
# #                     "image": {
# #                         "data": demo_image_b64,
# #                     },
# #                 },
# #                 {
# #                     "type": "text",
# #                     "text": "Please describe what you see in this image in detail.",
# #                 },
# #             ],
# #         }
# #     ],
# # )

# # print(f"VLM Response: {vlm_response.completion_message.content}")

# # # print("\n" + "="*50)
# # # print("EMBEDDING EXAMPLE")
# # # print("="*50)

# # # # Embedding example
# # # embedding_response = client.inference.embeddings(
# # #     model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",
# # #     contents=["Hello world", "How are you today?"],
# # #     task_type="query"
# # # )

# # # print(f"Number of embeddings: {len(embedding_response.embeddings)}")
# # # print(f"Embedding dimension: {len(embedding_response.embeddings[0])}")
# # # print(f"First few values: {embedding_response.embeddings[0][:5]}")

# # # # from openai import OpenAI

# # # # client = OpenAI(
# # # #     base_url = "http://10.176.230.61:8000/v1",
# # # #     api_key = "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"
# # # # )

# # # # # completion = client.completions.create(
# # # # #     model="meta/llama-3.1-405b-instruct",
# # # # #     prompt="How are you?",
# # # # #     temperature=0.2,
# # # # #     top_p=0.7,
# # # # #     max_tokens=1024,
# # # # #     stream=False
# # # # # )

# # # # # # completion = client.chat.completions.create(
# # # # # #     model="meta/llama-3.1-8b-instruct",
# # # # # #     messages=[{"role":"user","content":"hi"}],
# # # # # #     temperature=0.2,
# # # # # #     top_p=0.7,
# # # # # #     max_tokens=1024,
# # # # # #     stream=True
# # # # # # )

# # # # # for chunk in completion:
# # # # #     if chunk.choices[0].delta.content is not None:
# # # # #         print(chunk.choices[0].delta.content, end="")


# # # # # response = client.inference.completion(
# # # # #     model_id="meta/llama-3.1-8b-instruct",
# # # # #     content="Complete the sentence using one word: Roses are red, violets are :",
# # # # #     stream=False,
# # # # #     sampling_params={
# # # # #         "max_tokens": 50,
# # # # #     },
# # # # # )
# # # # # print(f"Response: {response.content}")




# from openai import OpenAI

# client = OpenAI(
#     base_url = "https://integrate.api.nvidia.com/v1",
#     api_key = "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"
# )

# completion = client.chat.completions.create(
#     model="nvidia/nvidia-nemotron-nano-9b-v2",
#     messages=[{"role":"system","content":"/think"}],
#     temperature=0.6,
#     top_p=0.95,
#     max_tokens=2048,
#     frequency_penalty=0,
#     presence_penalty=0,
#     stream=True,
#     extra_body={
#         "min_thinking_tokens": 1024,
#         "max_thinking_tokens": 2048
#     }
# )

# for chunk in completion:
#     reasoning = getattr(chunk.choices[0].delta, "reasoning_content", None)
#     if reasoning:
#         print(reasoning, end="")
#     if chunk.choices[0].delta.content is not None:
#         print(chunk.choices[0].delta.content, end="")
@@ -1016,7 +1016,7 @@ class InferenceProvider(Protocol):
     ) -> RerankResponse:
         """Rerank a list of documents based on their relevance to a query.

-        :param model: The identifier of the reranking model to use.
+        :param model: The identifier of the reranking model to use. The model must be a reranking model registered with Llama Stack and available via the /models endpoint.
         :param query: The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length.
         :param items: List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length.
         :param max_num_results: (Optional) Maximum number of results to return. Default: returns all.
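For readers skimming the API change: judging from the NVIDIA adapter later in this diff, a RerankResponse wraps a list of RerankData entries, each holding the original item index plus a relevance score. A purely illustrative construction, with invented scores:

# Illustrative only; the scores are made up, the types come from this commit.
from llama_stack.apis.inference import RerankData, RerankResponse

response = RerankResponse(
    data=[
        RerankData(index=2, relevance_score=7.1),   # most relevant of the submitted items
        RerankData(index=0, relevance_score=-1.3),
    ]
)
best_item_index = response.data[0].index  # refers back into the caller's `items` list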
@@ -27,10 +27,12 @@ class ModelType(StrEnum):
     """Enumeration of supported model types in Llama Stack.
     :cvar llm: Large language model for text generation and completion
     :cvar embedding: Embedding model for converting text to vector representations
+    :cvar rerank: Reranking model for reordering documents by relevance
     """

     llm = "llm"
     embedding = "embedding"
+    rerank = "rerank"


 @json_schema_type
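Once a provider registers a model with the new type, it shows up in the regular model listing; a small sketch of how a client might discover rerank-capable models (this assumes the listed model objects expose model_type and identifier, as the registry entries later in this diff do):

# Sketch: find registered rerank models. Assumes `client` is an initialized Llama Stack client.
from llama_stack.apis.models import ModelType

rerank_models = [m for m in client.models.list() if m.model_type == ModelType.rerank]
for m in rerank_models:
    print(m.identifier)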
@@ -41,9 +41,14 @@ from llama_stack.apis.inference import (
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
     Order,
+    RerankResponse,
     StopReason,
     ToolPromptFormat,
 )
+from llama_stack.apis.inference.inference import (
+    OpenAIChatCompletionContentPartImageParam,
+    OpenAIChatCompletionContentPartTextParam,
+)
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
 from llama_stack.log import get_logger

@@ -179,6 +184,25 @@ class InferenceRouter(Inference):
             raise ModelTypeError(model_id, model.model_type, expected_model_type)
         return model

+    async def rerank(
+        self,
+        model: str,
+        query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
+        items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
+        max_num_results: int | None = None,
+    ) -> RerankResponse:
+        """Route rerank requests to the appropriate provider based on the model."""
+        logger.debug(f"InferenceRouter.rerank: {model}")
+        model_obj = await self._get_model(model, ModelType.rerank)
+        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
+        return await provider.rerank(
+            model=model_obj.identifier,
+            query=query,
+            items=items,
+            max_num_results=max_num_results,
+        )
+
     async def openai_completion(
         self,
         model: str,
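The router only validates the model type and forwards the call, so max_num_results flows through to the provider untouched; a hypothetical caller-side use of it (model id and values are illustrative):

# Hypothetical: cap the number of rankings returned; truncation happens in the provider adapter.
top_two = client.inference.rerank(
    model="nvidia/llama-3.2-nv-rerankqa-1b-v2",
    query="Which passage answers the question?",
    items=["passage a", "passage b", "passage c"],
    max_num_results=2,
)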
llama_stack/providers/remote/inference/nvidia/models.py (new file, 131 lines)
@@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.models import ModelType
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
    ProviderModelEntry,
    build_hf_repo_model_entry,
)

SAFETY_MODELS_ENTRIES = []

# https://docs.nvidia.com/nim/large-language-models/latest/supported-llm-agnostic-architectures.html
MODEL_ENTRIES = [
    build_hf_repo_model_entry(
        "meta/llama3-8b-instruct",
        CoreModelId.llama3_8b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta/llama3-70b-instruct",
        CoreModelId.llama3_70b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta/llama-3.1-8b-instruct",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta/llama-3.1-70b-instruct",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta/llama-3.1-405b-instruct",
        CoreModelId.llama3_1_405b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta/llama-3.2-1b-instruct",
        CoreModelId.llama3_2_1b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta/llama-3.2-3b-instruct",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta/llama-3.2-11b-vision-instruct",
        CoreModelId.llama3_2_11b_vision_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta/llama-3.2-90b-vision-instruct",
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta/llama-3.3-70b-instruct",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
    ProviderModelEntry(
        provider_model_id="nvidia/vila",
        model_type=ModelType.llm,
    ),
    # NeMo Retriever Text Embedding models -
    #
    # https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
    #
    # +-----------------------------------+--------+-----------+-----------+------------+
    # | Model ID                          | Max    | Publisher | Embedding | Dynamic    |
    # |                                   | Tokens |           | Dimension | Embeddings |
    # +-----------------------------------+--------+-----------+-----------+------------+
    # | nvidia/llama-3.2-nv-embedqa-1b-v2 | 8192   | NVIDIA    | 2048      | Yes        |
    # | nvidia/nv-embedqa-e5-v5           | 512    | NVIDIA    | 1024      | No         |
    # | nvidia/nv-embedqa-mistral-7b-v2   | 512    | NVIDIA    | 4096      | No         |
    # | snowflake/arctic-embed-l          | 512    | Snowflake | 1024      | No         |
    # +-----------------------------------+--------+-----------+-----------+------------+
    ProviderModelEntry(
        provider_model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",
        model_type=ModelType.embedding,
        metadata={
            "embedding_dimension": 2048,
            "context_length": 8192,
        },
    ),
    ProviderModelEntry(
        provider_model_id="nvidia/nv-embedqa-e5-v5",
        model_type=ModelType.embedding,
        metadata={
            "embedding_dimension": 1024,
            "context_length": 512,
        },
    ),
    ProviderModelEntry(
        provider_model_id="nvidia/nv-embedqa-mistral-7b-v2",
        model_type=ModelType.embedding,
        metadata={
            "embedding_dimension": 4096,
            "context_length": 512,
        },
    ),
    ProviderModelEntry(
        provider_model_id="snowflake/arctic-embed-l",
        model_type=ModelType.embedding,
        metadata={
            "embedding_dimension": 1024,
            "context_length": 512,
        },
    ),
    # NVIDIA Reranking models
    ProviderModelEntry(
        provider_model_id="nv-rerank-qa-mistral-4b:1",
        model_type=ModelType.rerank,
        metadata={
            "endpoint": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking",
        },
    ),
    ProviderModelEntry(
        provider_model_id="nvidia/nv-rerankqa-mistral-4b-v3",
        model_type=ModelType.rerank,
        metadata={
            "endpoint": "https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking",
        },
    ),
    ProviderModelEntry(
        provider_model_id="nvidia/llama-3.2-nv-rerankqa-1b-v2",
        model_type=ModelType.rerank,
        metadata={
            "endpoint": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking",
        },
    ),
    # TODO(mf): how do we handle Nemotron models?
    # "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
] + SAFETY_MODELS_ENTRIES
@@ -12,6 +12,12 @@ from llama_stack.apis.inference import (
     OpenAIEmbeddingData,
     OpenAIEmbeddingsResponse,
     OpenAIEmbeddingUsage,
+    RerankData,
+    RerankResponse,
+)
+from llama_stack.apis.inference.inference import (
+    OpenAIChatCompletionContentPartImageParam,
+    OpenAIChatCompletionContentPartTextParam,
 )
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

@@ -80,6 +86,80 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
         """
         return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url

+    async def rerank(
+        self,
+        model: str,
+        query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
+        items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
+        max_num_results: int | None = None,
+    ) -> RerankResponse:
+        provider_model_id = await self._get_provider_model_id(model)
+
+        ranking_url = self.get_base_url()
+        model_obj = await self.model_store.get_model(model)
+
+        if _is_nvidia_hosted(self._config) and "endpoint" in model_obj.metadata:
+            ranking_url = model_obj.metadata["endpoint"]
+
+        logger.debug(f"Using rerank endpoint: {ranking_url} for model: {provider_model_id}")
+
+        # Convert query to text format
+        if isinstance(query, str):
+            query_text = query
+        elif hasattr(query, "text"):
+            query_text = query.text
+        else:
+            raise ValueError("Query must be a string or text content part")
+
+        # Convert items to text format
+        passages = []
+        for item in items:
+            if isinstance(item, str):
+                passages.append({"text": item})
+            elif hasattr(item, "text"):
+                passages.append({"text": item.text})
+            else:
+                raise ValueError("Items must be strings or text content parts")
+
+        payload = {
+            "model": provider_model_id,
+            "query": {"text": query_text},
+            "passages": passages,
+        }
+
+        headers = {
+            "Authorization": f"Bearer {self.get_api_key()}",
+            "Content-Type": "application/json",
+        }
+
+        import aiohttp
+
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.post(ranking_url, headers=headers, json=payload) as response:
+                    if response.status != 200:
+                        response_text = await response.text()
+                        raise ConnectionError(
+                            f"NVIDIA rerank API request failed with status {response.status}: {response_text}"
+                        )
+
+                    result = await response.json()
+                    rankings = result.get("rankings", [])
+
+                    # Convert to RerankData format
+                    rerank_data = []
+                    for ranking in rankings:
+                        rerank_data.append(RerankData(index=ranking["index"], relevance_score=ranking["logit"]))
+
+                    # Apply max_num_results limit if specified
+                    if max_num_results is not None:
+                        rerank_data = rerank_data[:max_num_results]
+
+                    return RerankResponse(data=rerank_data)
+
+        except aiohttp.ClientError as e:
+            raise ConnectionError(f"Failed to connect to NVIDIA rerank API at {ranking_url}: {e}") from e
+
     async def openai_embeddings(
         self,
         model: str,
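For reference, the wire format this adapter speaks to the hosted NVIDIA reranking endpoints follows directly from the code above; a hypothetical request/response pair, with invented scores, would look roughly like this:

# Request body posted to the model's "endpoint" URL (Authorization: Bearer <NVIDIA_API_KEY>):
request = {
    "model": "nvidia/llama-3.2-nv-rerankqa-1b-v2",
    "query": {"text": "Which passage answers the question?"},
    "passages": [{"text": "passage a"}, {"text": "passage b"}],
}

# Response body the adapter expects; each "logit" becomes RerankData.relevance_score.
response = {
    "rankings": [
        {"index": 1, "logit": 4.2},
        {"index": 0, "logit": -0.7},
    ]
}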