Jiayi Ni 2025-10-03 12:27:16 -07:00 committed by GitHub
commit 1e04f105f2
20 changed files with 840 additions and 34 deletions


@@ -1,9 +1,10 @@
---
description: "Llama Stack Inference API for generating completions, chat completions, and embeddings.
This API provides the raw interface to the underlying models. Three kinds of models are supported:
- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic search.
- Rerank models: these models reorder the documents based on their relevance to a query."
sidebar_label: Inference
title: Inference
---
@@ -14,8 +15,9 @@ title: Inference
Llama Stack Inference API for generating completions, chat completions, and embeddings.
This API provides the raw interface to the underlying models. Three kinds of models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic search.
- Rerank models: these models reorder the documents based on their relevance to a query.
This section contains documentation for all available providers for the **inference** API.
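A minimal usage sketch of the new rerank flow, based on the client calls exercised by the integration tests and NVIDIA docs in this change. The base URL, provider id, and model id below are assumptions; depending on the client version the route may also be exposed as `client.inference.rerank`:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local Llama Stack server

# Register a rerank model (identifiers are illustrative assumptions)
client.models.register(
    model_id="nvidia/llama-3.2-nv-rerankqa-1b-v2",
    provider_id="nvidia",
    model_type="rerank",
)

# Reorder documents by relevance to the query
response = client.alpha.inference.rerank(
    model="nvidia/llama-3.2-nv-rerankqa-1b-v2",
    query="What is a reranking model?",
    items=[
        "A reranking model reranks a list of items based on the query.",
        "Python is a programming language.",
    ],
)
for result in response:
    print(result.index, result.relevance_score)
```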


@@ -13335,7 +13335,7 @@
},
{
"name": "Inference",
"description": "This API provides the raw interface to the underlying models. Three kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.\n- Rerank models: these models reorder the documents based on their relevance to a query.",
"x-displayName": "Llama Stack Inference API for generating completions, chat completions, and embeddings."
},
{


@@ -9990,13 +9990,16 @@ tags:
description: ''
- name: Inference
description: >-
This API provides the raw interface to the underlying models. Three kinds of
models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic
search.
- Rerank models: these models reorder the documents based on their relevance
to a query.
x-displayName: >-
Llama Stack Inference API for generating completions, chat completions, and
embeddings.


@@ -4992,7 +4992,7 @@
"properties": {
"model": {
"type": "string",
"description": "The identifier of the reranking model to use. The model must be a reranking model registered with Llama Stack and available via the /models endpoint."
},
"query": {
"oneOf": [


@@ -3657,7 +3657,8 @@ components:
model:
type: string
description: >-
The identifier of the reranking model to use. The model must be a reranking
model registered with Llama Stack and available via the /models endpoint.
query:
oneOf:
- type: string


@@ -6829,7 +6829,8 @@
"type": "string",
"enum": [
"llm",
"embedding",
"rerank"
],
"title": "ModelType",
"description": "Enumeration of supported model types in Llama Stack."
@@ -12883,7 +12884,7 @@
},
{
"name": "Inference",
"description": "This API provides the raw interface to the underlying models. Three kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.\n- Rerank models: these models reorder the documents based on their relevance to a query.",
"x-displayName": "Llama Stack Inference API for generating completions, chat completions, and embeddings."
},
{


@@ -5158,6 +5158,7 @@ components:
enum:
- llm
- embedding
- rerank
title: ModelType
description: >-
Enumeration of supported model types in Llama Stack.
@@ -9728,13 +9729,16 @@ tags:
description: ''
- name: Inference
description: >-
This API provides the raw interface to the underlying models. Three kinds of
models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic
search.
- Rerank models: these models reorder the documents based on their relevance
to a query.
x-displayName: >-
Llama Stack Inference API for generating completions, chat completions, and
embeddings.


@@ -8838,7 +8838,8 @@
"type": "string",
"enum": [
"llm",
"embedding",
"rerank"
],
"title": "ModelType",
"description": "Enumeration of supported model types in Llama Stack."
@@ -17033,7 +17034,7 @@
"properties": {
"model": {
"type": "string",
"description": "The identifier of the reranking model to use. The model must be a reranking model registered with Llama Stack and available via the /models endpoint."
},
"query": {
"oneOf": [
@@ -18456,7 +18457,7 @@
},
{
"name": "Inference",
"description": "This API provides the raw interface to the underlying models. Three kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.\n- Rerank models: these models reorder the documents based on their relevance to a query.",
"x-displayName": "Llama Stack Inference API for generating completions, chat completions, and embeddings."
},
{


@@ -6603,6 +6603,7 @@ components:
enum:
- llm
- embedding
- rerank
title: ModelType
description: >-
Enumeration of supported model types in Llama Stack.
@@ -12693,7 +12694,8 @@ components:
model:
type: string
description: >-
The identifier of the reranking model to use. The model must be a reranking
model registered with Llama Stack and available via the /models endpoint.
query:
oneOf:
- type: string
@@ -13774,13 +13776,16 @@ tags:
description: ''
- name: Inference
description: >-
This API provides the raw interface to the underlying models. Three kinds of
models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic
search.
- Rerank models: these models reorder the documents based on their relevance
to a query.
x-displayName: >-
Llama Stack Inference API for generating completions, chat completions, and
embeddings.


@@ -1016,7 +1016,7 @@ class InferenceProvider(Protocol):
) -> RerankResponse:
"""Rerank a list of documents based on their relevance to a query.
:param model: The identifier of the reranking model to use. The model must be a reranking model registered with Llama Stack and available via the /models endpoint.
:param query: The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length.
:param items: List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length.
:param max_num_results: (Optional) Maximum number of results to return. Default: returns all.
@@ -1159,9 +1159,10 @@ class InferenceProvider(Protocol):
class Inference(InferenceProvider):
"""Llama Stack Inference API for generating completions, chat completions, and embeddings.
This API provides the raw interface to the underlying models. Three kinds of models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic search.
- Rerank models: these models reorder the documents based on their relevance to a query.
"""
@webmethod(route="/openai/v1/chat/completions", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)


@@ -27,10 +27,12 @@ class ModelType(StrEnum):
"""Enumeration of supported model types in Llama Stack.
:cvar llm: Large language model for text generation and completion
:cvar embedding: Embedding model for converting text to vector representations
:cvar rerank: Reranking model for reordering documents based on their relevance to a query
"""
llm = "llm"
embedding = "embedding"
rerank = "rerank"
@json_schema_type
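A brief sketch of how the new enum member is used when a provider registers a rerank model. The identifiers below are placeholder assumptions; the field names follow the `Model` construction used by the NVIDIA adapter later in this change:

```python
from llama_stack.apis.models import Model, ModelType

# Placeholder identifiers for illustration only
rerank_model = Model(
    identifier="example/rerank-model",
    provider_resource_id="example/rerank-model",
    provider_id="example-provider",
    model_type=ModelType.rerank,  # the enum member added above
)
```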


@@ -41,9 +41,14 @@ from llama_stack.apis.inference import (
OpenAIMessageParam,
OpenAIResponseFormatParam,
Order,
RerankResponse,
StopReason,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
from llama_stack.log import get_logger
@@ -179,6 +184,23 @@ class InferenceRouter(Inference):
raise ModelTypeError(model_id, model.model_type, expected_model_type)
return model
async def rerank(
self,
model: str,
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
max_num_results: int | None = None,
) -> RerankResponse:
logger.debug(f"InferenceRouter.rerank: {model}")
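# Look up the registered model and verify it is a rerank model; _get_model raises ModelTypeError otherwise.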
model_obj = await self._get_model(model, ModelType.rerank)
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
return await provider.rerank(
model=model_obj.identifier,
query=query,
items=items,
max_num_results=max_num_results,
)
async def openai_completion(
self,
model: str,


@@ -188,3 +188,22 @@ vlm_response = client.chat.completions.create(
print(f"VLM Response: {vlm_response.choices[0].message.content}")
```
### Rerank Example
The following example shows how to rerank documents using an NVIDIA NIM.
```python
rerank_response = client.inference.rerank(
model="nvidia/llama-3.2-nv-rerankqa-1b-v2",
query="query",
items=[
"item_1",
"item_2",
"item_3",
],
)
for i, result in enumerate(rerank_response):
print(f"{i+1}. [Index: {result.index}, " f"Score: {(result.relevance_score):.3f}]")
```


@@ -5,6 +5,7 @@
# the root directory of this source tree.
import aiohttp
from openai import NOT_GIVEN
from llama_stack.apis.inference import (
@@ -12,7 +13,14 @@ from llama_stack.apis.inference import (
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
RerankData,
RerankResponse,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@@ -44,6 +52,18 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
"snowflake/arctic-embed-l": {"embedding_dimension": 512, "context_length": 1024},
}
rerank_model_list = [
"nv-rerank-qa-mistral-4b:1",
"nvidia/nv-rerankqa-mistral-4b-v3",
"nvidia/llama-3.2-nv-rerankqa-1b-v2",
]
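# Hosted NVIDIA rerank models are served from dedicated retrieval endpoints rather than the shared base URL;
# self-hosted deployments keep using the configured URL instead (see rerank() below).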
_rerank_model_endpoints = {
"nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking",
"nvidia/nv-rerankqa-mistral-4b-v3": "https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking",
"nvidia/llama-3.2-nv-rerankqa-1b-v2": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking",
}
def __init__(self, config: NVIDIAConfig) -> None:
logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")
@@ -62,6 +82,8 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
# "Consider removing the api_key from the configuration."
# )
super().__init__()
self._config = config
def get_api_key(self) -> str:
@@ -80,6 +102,103 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
"""
return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
async def list_models(self) -> list[Model] | None:
"""
List available NVIDIA models by combining:
1. Dynamic models from https://integrate.api.nvidia.com/v1/models
2. Static rerank models (which use different API endpoints)
"""
self._model_cache = {}
models = await super().list_models()
# Add rerank models
existing_ids = {m.identifier for m in models}
for model_id, _ in self._rerank_model_endpoints.items():
if self.allowed_models and model_id not in self.allowed_models:
continue
if model_id not in existing_ids:
model = Model(
provider_id=self.__provider_id__, # type: ignore[attr-defined]
provider_resource_id=model_id,
identifier=model_id,
model_type=ModelType.rerank,
)
models.append(model)
self._model_cache[model_id] = model
return models
async def rerank(
self,
model: str,
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
max_num_results: int | None = None,
) -> RerankResponse:
provider_model_id = await self._get_provider_model_id(model)
ranking_url = self.get_base_url()
if _is_nvidia_hosted(self._config) and provider_model_id in self._rerank_model_endpoints:
ranking_url = self._rerank_model_endpoints[provider_model_id]
logger.debug(f"Using rerank endpoint: {ranking_url} for model: {provider_model_id}")
# Convert query to text format
if isinstance(query, str):
query_text = query
elif isinstance(query, OpenAIChatCompletionContentPartTextParam):
query_text = query.text
else:
raise ValueError("Query must be a string or text content part")
# Convert items to text format
passages = []
for item in items:
if isinstance(item, str):
passages.append({"text": item})
elif isinstance(item, OpenAIChatCompletionContentPartTextParam):
passages.append({"text": item.text})
else:
raise ValueError("Items must be strings or text content parts")
payload = {
"model": provider_model_id,
"query": {"text": query_text},
"passages": passages,
}
headers = {
"Authorization": f"Bearer {self.get_api_key()}",
"Content-Type": "application/json",
}
try:
async with aiohttp.ClientSession() as session:
async with session.post(ranking_url, headers=headers, json=payload) as response:
if response.status != 200:
response_text = await response.text()
raise ConnectionError(
f"NVIDIA rerank API request failed with status {response.status}: {response_text}"
)
result = await response.json()
rankings = result.get("rankings", [])
# Convert to RerankData format
rerank_data = []
for ranking in rankings:
rerank_data.append(RerankData(index=ranking["index"], relevance_score=ranking["logit"]))
# Apply max_num_results limit
if max_num_results is not None:
rerank_data = rerank_data[:max_num_results]
return RerankResponse(data=rerank_data)
except aiohttp.ClientError as e:
raise ConnectionError(f"Failed to connect to NVIDIA rerank API at {ranking_url}: {e}") from e
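# Reference (inferred from this adapter and its unit tests, not from official NIM documentation): the request
# body posted above has the shape {"model": "<id>", "query": {"text": "..."}, "passages": [{"text": "..."}, ...]},
# and the response carries a "rankings" list of {"index": <int>, "logit": <float>} entries used to build RerankData.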
async def openai_embeddings(
self,
model: str,


@@ -63,6 +63,10 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
# Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}}
embedding_model_metadata: dict[str, dict[str, int]] = {}
# List of rerank model IDs for this provider
# Can be set by subclasses or instances to provide rerank models
rerank_model_list: list[str] = []
# Cache of available models keyed by model ID
# This is set in list_models() and used in check_model_availability()
_model_cache: dict[str, Model] = {}
@@ -400,6 +404,14 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
model_type=ModelType.embedding,
metadata=metadata,
)
elif m.id in self.rerank_model_list:
# This is a rerank model
model = Model(
provider_id=self.__provider_id__,  # type: ignore[attr-defined]
provider_resource_id=m.id,
identifier=m.id,
model_type=ModelType.rerank,
)
else:
# This is an LLM
model = Model(


@@ -120,6 +120,10 @@ def pytest_addoption(parser):
"--embedding-model",
help="comma-separated list of embedding models. Fixture name: embedding_model_id",
)
parser.addoption(
"--rerank-model",
help="comma-separated list of rerank models. Fixture name: rerank_model_id",
)
parser.addoption(
"--safety-shield",
help="comma-separated list of safety shields. Fixture name: shield_id",
@@ -198,6 +202,7 @@ def pytest_generate_tests(metafunc):
"shield_id": ("--safety-shield", "shield"),
"judge_model_id": ("--judge-model", "judge"),
"embedding_dimension": ("--embedding-dimension", "dim"),
"rerank_model_id": ("--rerank-model", "rerank"),
}
# Collect all parameters and their values


@@ -119,6 +119,7 @@ def client_with_models(
embedding_model_id,
embedding_dimension,
judge_model_id,
rerank_model_id,
):
client = llama_stack_client
@@ -151,6 +152,20 @@ def client_with_models(
model_type="embedding",
metadata={"embedding_dimension": embedding_dimension or 384},
)
if rerank_model_id and rerank_model_id not in model_ids:
selected_provider = None
for p in providers:
# Currently only NVIDIA inference provider supports reranking
if p.provider_type == "remote::nvidia":
selected_provider = p
break
selected_provider = selected_provider or providers[0]
client.models.register(
model_id=rerank_model_id,
provider_id=selected_provider.provider_id,
model_type="rerank",
)
return client
@@ -166,7 +181,14 @@ def model_providers(llama_stack_client):
@pytest.fixture(autouse=True)
def skip_if_no_model(request):
model_fixtures = [
"text_model_id",
"vision_model_id",
"embedding_model_id",
"judge_model_id",
"shield_id",
"rerank_model_id",
]
test_func = request.node.function
actual_params = inspect.signature(test_func).parameters.keys()


@@ -0,0 +1,214 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack_client import BadRequestError as LlamaStackBadRequestError
from llama_stack_client.types.alpha import InferenceRerankResponse
from llama_stack_client.types.shared.interleaved_content import (
ImageContentItem,
ImageContentItemImage,
ImageContentItemImageURL,
TextContentItem,
)
from llama_stack.core.library_client import LlamaStackAsLibraryClient
# Test data
DUMMY_STRING = "string_1"
DUMMY_STRING2 = "string_2"
DUMMY_TEXT = TextContentItem(text=DUMMY_STRING, type="text")
DUMMY_TEXT2 = TextContentItem(text=DUMMY_STRING2, type="text")
DUMMY_IMAGE_URL = ImageContentItem(
image=ImageContentItemImage(url=ImageContentItemImageURL(uri="https://example.com/image.jpg")), type="image"
)
DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
PROVIDERS_SUPPORTING_MEDIA = {} # Providers that support media input for rerank models
def skip_if_provider_doesnt_support_rerank(inference_provider_type):
supported_providers = {"remote::nvidia"}
if inference_provider_type not in supported_providers:
pytest.skip(f"{inference_provider_type} doesn't support rerank models")
def _validate_rerank_response(response: InferenceRerankResponse, items: list) -> None:
"""
Validate that a rerank response has the correct structure and ordering.
Args:
response: The InferenceRerankResponse to validate
items: The original items list that was ranked
Raises:
AssertionError: If any validation fails
"""
seen = set()
last_score = float("inf")
for d in response:
assert 0 <= d.index < len(items), f"Index {d.index} out of bounds for {len(items)} items"
assert d.index not in seen, f"Duplicate index {d.index} found"
seen.add(d.index)
assert isinstance(d.relevance_score, float), f"Score must be float, got {type(d.relevance_score)}"
assert d.relevance_score <= last_score, f"Scores not in descending order: {d.relevance_score} > {last_score}"
last_score = d.relevance_score
def _validate_semantic_ranking(response: InferenceRerankResponse, items: list, expected_first_item: str) -> None:
"""
Validate that the expected most relevant item ranks first.
Args:
response: The InferenceRerankResponse to validate
items: The original items list that was ranked
expected_first_item: The expected first item in the ranking
Raises:
AssertionError: If any validation fails
"""
if not response:
raise AssertionError("No ranking data returned in response")
actual_first_index = response[0].index
actual_first_item = items[actual_first_index]
assert actual_first_item == expected_first_item, (
f"Expected '{expected_first_item}' to rank first, but '{actual_first_item}' ranked first instead."
)
@pytest.mark.parametrize(
"query,items",
[
(DUMMY_STRING, [DUMMY_STRING, DUMMY_STRING2]),
(DUMMY_TEXT, [DUMMY_TEXT, DUMMY_TEXT2]),
(DUMMY_STRING, [DUMMY_STRING2, DUMMY_TEXT]),
(DUMMY_TEXT, [DUMMY_STRING, DUMMY_TEXT2]),
],
ids=[
"string-query-string-items",
"text-query-text-items",
"mixed-content-1",
"mixed-content-2",
],
)
def test_rerank_text(client_with_models, rerank_model_id, query, items, inference_provider_type):
skip_if_provider_doesnt_support_rerank(inference_provider_type)
response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
assert isinstance(response, list)
# TODO: Add type validation for response items once InferenceRerankResponseItem is exported from llama stack client.
assert len(response) <= len(items)
_validate_rerank_response(response, items)
@pytest.mark.parametrize(
"query,items",
[
(DUMMY_IMAGE_URL, [DUMMY_STRING]),
(DUMMY_IMAGE_BASE64, [DUMMY_TEXT]),
(DUMMY_TEXT, [DUMMY_IMAGE_URL]),
(DUMMY_IMAGE_BASE64, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]),
(DUMMY_TEXT, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]),
],
ids=[
"image-query-url",
"image-query-base64",
"text-query-image-item",
"mixed-content-1",
"mixed-content-2",
],
)
def test_rerank_image(client_with_models, rerank_model_id, query, items, inference_provider_type):
skip_if_provider_doesnt_support_rerank(inference_provider_type)
if rerank_model_id not in PROVIDERS_SUPPORTING_MEDIA:
error_type = (
ValueError if isinstance(client_with_models, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
)
with pytest.raises(error_type):
client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
else:
response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
assert isinstance(response, list)
assert len(response) <= len(items)
_validate_rerank_response(response, items)
def test_rerank_max_results(client_with_models, rerank_model_id, inference_provider_type):
skip_if_provider_doesnt_support_rerank(inference_provider_type)
items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2]
max_num_results = 2
response = client_with_models.alpha.inference.rerank(
model=rerank_model_id,
query=DUMMY_STRING,
items=items,
max_num_results=max_num_results,
)
assert isinstance(response, list)
assert len(response) == max_num_results
_validate_rerank_response(response, items)
def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_id, inference_provider_type):
skip_if_provider_doesnt_support_rerank(inference_provider_type)
items = [DUMMY_STRING, DUMMY_STRING2]
response = client_with_models.alpha.inference.rerank(
model=rerank_model_id,
query=DUMMY_STRING,
items=items,
max_num_results=10, # Larger than items length
)
assert isinstance(response, list)
assert len(response) <= len(items) # Should return at most len(items)
@pytest.mark.parametrize(
"query,items,expected_first_item",
[
(
"What is a reranking model? ",
[
"A reranking model reranks a list of items based on the query. ",
"Machine learning algorithms learn patterns from data. ",
"Python is a programming language. ",
],
"A reranking model reranks a list of items based on the query. ",
),
(
"What is C++?",
[
"Learning new things is interesting. ",
"C++ is a programming language. ",
"Books provide knowledge and entertainment. ",
],
"C++ is a programming language. ",
),
(
"What are good learning habits? ",
[
"Cooking pasta is a fun activity. ",
"Plants need water and sunlight. ",
"Good learning habits include reading daily and taking notes. ",
],
"Good learning habits include reading daily and taking notes. ",
),
],
)
def test_rerank_semantic_correctness(
client_with_models, rerank_model_id, query, items, expected_first_item, inference_provider_type
):
skip_if_provider_doesnt_support_rerank(inference_provider_type)
response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
_validate_rerank_response(response, items)
_validate_semantic_ranking(response, items, expected_first_item)


@@ -0,0 +1,222 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, MagicMock, patch
import aiohttp
import pytest
from llama_stack.apis.models import ModelType
from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter
class MockResponse:
def __init__(self, status=200, json_data=None, text_data="OK"):
self.status = status
self._json_data = json_data or {"rankings": []}
self._text_data = text_data
async def json(self):
return self._json_data
async def text(self):
return self._text_data
class MockSession:
def __init__(self, response):
self.response = response
self.post_calls = []
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
return False
def post(self, url, **kwargs):
self.post_calls.append((url, kwargs))
class PostContext:
def __init__(self, response):
self.response = response
async def __aenter__(self):
return self.response
async def __aexit__(self, exc_type, exc_val, exc_tb):
return False
return PostContext(self.response)
def create_adapter(config=None, rerank_endpoints=None):
if config is None:
config = NVIDIAConfig(api_key="test-key")
adapter = NVIDIAInferenceAdapter(config)
class MockModel:
provider_resource_id = "test-model"
metadata = {}
adapter.model_store = AsyncMock()
adapter.model_store.get_model = AsyncMock(return_value=MockModel())
if rerank_endpoints is not None:
adapter._rerank_model_endpoints = rerank_endpoints
return adapter
async def test_rerank_basic_functionality():
adapter = create_adapter()
mock_response = MockResponse(json_data={"rankings": [{"index": 0, "logit": 0.5}]})
mock_session = MockSession(mock_response)
with patch("aiohttp.ClientSession", return_value=mock_session):
result = await adapter.rerank(model="test-model", query="test query", items=["item1", "item2"])
assert len(result.data) == 1
assert result.data[0].index == 0
assert result.data[0].relevance_score == 0.5
url, kwargs = mock_session.post_calls[0]
payload = kwargs["json"]
assert payload["model"] == "test-model"
assert payload["query"] == {"text": "test query"}
assert payload["passages"] == [{"text": "item1"}, {"text": "item2"}]
async def test_missing_rankings_key():
adapter = create_adapter()
mock_session = MockSession(MockResponse(json_data={}))
with patch("aiohttp.ClientSession", return_value=mock_session):
result = await adapter.rerank(model="test-model", query="q", items=["a"])
assert len(result.data) == 0
async def test_hosted_with_endpoint():
adapter = create_adapter(
config=NVIDIAConfig(api_key="key"), rerank_endpoints={"test-model": "https://model.endpoint/rerank"}
)
mock_session = MockSession(MockResponse())
with patch("aiohttp.ClientSession", return_value=mock_session):
await adapter.rerank(model="test-model", query="q", items=["a"])
url, _ = mock_session.post_calls[0]
assert url == "https://model.endpoint/rerank"
async def test_hosted_without_endpoint():
adapter = create_adapter(
config=NVIDIAConfig(api_key="key"), # This creates hosted config (integrate.api.nvidia.com).
rerank_endpoints={}, # No endpoint mapping for test-model
)
mock_session = MockSession(MockResponse())
with patch("aiohttp.ClientSession", return_value=mock_session):
await adapter.rerank(model="test-model", query="q", items=["a"])
url, _ = mock_session.post_calls[0]
assert "https://integrate.api.nvidia.com" in url
async def test_hosted_model_not_in_endpoint_mapping():
adapter = create_adapter(
config=NVIDIAConfig(api_key="key"), rerank_endpoints={"other-model": "https://other.endpoint/rerank"}
)
mock_session = MockSession(MockResponse())
with patch("aiohttp.ClientSession", return_value=mock_session):
await adapter.rerank(model="test-model", query="q", items=["a"])
url, _ = mock_session.post_calls[0]
assert "https://integrate.api.nvidia.com" in url
assert url != "https://other.endpoint/rerank"
async def test_self_hosted_ignores_endpoint():
adapter = create_adapter(
config=NVIDIAConfig(url="http://localhost:8000", api_key=None),
rerank_endpoints={"test-model": "https://model.endpoint/rerank"}, # This should be ignored for self-hosted.
)
mock_session = MockSession(MockResponse())
with patch("aiohttp.ClientSession", return_value=mock_session):
await adapter.rerank(model="test-model", query="q", items=["a"])
url, _ = mock_session.post_calls[0]
assert "http://localhost:8000" in url
assert "model.endpoint/rerank" not in url
async def test_max_num_results():
adapter = create_adapter()
rankings = [{"index": 0, "logit": 0.8}, {"index": 1, "logit": 0.6}]
mock_session = MockSession(MockResponse(json_data={"rankings": rankings}))
with patch("aiohttp.ClientSession", return_value=mock_session):
result = await adapter.rerank(model="test-model", query="q", items=["a", "b"], max_num_results=1)
assert len(result.data) == 1
assert result.data[0].index == 0
assert result.data[0].relevance_score == 0.8
async def test_http_error():
adapter = create_adapter()
mock_session = MockSession(MockResponse(status=500, text_data="Server Error"))
with patch("aiohttp.ClientSession", return_value=mock_session):
with pytest.raises(ConnectionError, match="status 500.*Server Error"):
await adapter.rerank(model="test-model", query="q", items=["a"])
async def test_client_error():
adapter = create_adapter()
mock_session = AsyncMock()
mock_session.__aenter__.side_effect = aiohttp.ClientError("Network error")
with patch("aiohttp.ClientSession", return_value=mock_session):
with pytest.raises(ConnectionError, match="Failed to connect.*Network error"):
await adapter.rerank(model="test-model", query="q", items=["a"])
async def test_list_models_adds_rerank_models():
"""Test that list_models adds rerank models to the dynamic model list."""
adapter = create_adapter()
adapter.__provider_id__ = "nvidia"
# Mock the list_models from the superclass to return some dynamic models
base_models = [
MagicMock(identifier="llm-1", model_type=ModelType.llm),
MagicMock(identifier="embedding-1", model_type=ModelType.embedding),
]
with patch.object(NVIDIAInferenceAdapter.__bases__[0], "list_models", return_value=base_models):
result = await adapter.list_models()
assert result is not None
# Check that the rerank models are added
model_ids = [m.identifier for m in result]
assert "nv-rerank-qa-mistral-4b:1" in model_ids
assert "nvidia/nv-rerankqa-mistral-4b-v3" in model_ids
assert "nvidia/llama-3.2-nv-rerankqa-1b-v2" in model_ids
rerank_models = [m for m in result if m.model_type == ModelType.rerank]
assert len(rerank_models) == 3
for rerank_model in rerank_models:
assert rerank_model.provider_id == "nvidia"
assert rerank_model.metadata == {}
assert rerank_model.identifier in adapter._model_cache


@@ -35,6 +35,40 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
}
class OpenAIMixinWithRerankImpl(OpenAIMixin):
"""Test implementation with rerank model list"""
rerank_model_list = ["rerank-model-1", "rerank-model-2"]
def __init__(self):
self.__provider_id__ = "test-provider"
def get_api_key(self) -> str:
raise NotImplementedError("This method should be mocked in tests")
def get_base_url(self) -> str:
raise NotImplementedError("This method should be mocked in tests")
class OpenAIMixinWithEmbeddingsAndRerankImpl(OpenAIMixin):
"""Test implementation with both embedding model metadata and rerank model list"""
embedding_model_metadata = {
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
"text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192},
}
rerank_model_list = ["rerank-model-1", "rerank-model-2"]
__provider_id__ = "test-provider"
def get_api_key(self) -> str:
raise NotImplementedError("This method should be mocked in tests")
def get_base_url(self) -> str:
raise NotImplementedError("This method should be mocked in tests")
@pytest.fixture
def mixin():
"""Create a test instance of OpenAIMixin with mocked model_store"""
@@ -56,6 +90,18 @@ def mixin_with_embeddings():
return OpenAIMixinWithEmbeddingsImpl()
@pytest.fixture
def mixin_with_rerank():
"""Create a test instance of OpenAIMixin with rerank model list"""
return OpenAIMixinWithRerankImpl()
@pytest.fixture
def mixin_with_embeddings_and_rerank():
"""Create a test instance of OpenAIMixin with both embedding model metadata and rerank model list"""
return OpenAIMixinWithEmbeddingsAndRerankImpl()
@pytest.fixture
def mock_models():
"""Create multiple mock OpenAI model objects"""
@@ -107,6 +153,19 @@ def mock_client_context():
return _mock_client_context
def _assert_models_match_expected(actual_models, expected_models):
"""Verify the models match expected attributes.
Args:
actual_models: List of models to verify
expected_models: Mapping of model identifier to expected attribute values
"""
for identifier, expected_attrs in expected_models.items():
model = next(m for m in actual_models if m.identifier == identifier)
for attr_name, expected_value in expected_attrs.items():
assert getattr(model, attr_name) == expected_value
class TestOpenAIMixinListModels:
"""Test cases for the list_models method"""
@@ -300,21 +359,113 @@ class TestOpenAIMixinEmbeddingModelMetadata:
assert result is not None
assert len(result) == 2
expected_models = {
"text-embedding-3-small": {
"model_type": ModelType.embedding,
"metadata": {"embedding_dimension": 1536, "context_length": 8192},
"provider_id": "test-provider",
"provider_resource_id": "text-embedding-3-small",
},
"gpt-4": {
"model_type": ModelType.llm,
"metadata": {},
"provider_id": "test-provider",
"provider_resource_id": "gpt-4",
},
}
_assert_models_match_expected(result, expected_models)
class TestOpenAIMixinRerankModelList:
"""Test cases for rerank_model_list attribute functionality"""
async def test_rerank_model_identified(self, mixin_with_rerank, mock_client_context):
"""Test that models in rerank_model_list are correctly identified as rerank models"""
"""Test that models in rerank_model_list are correctly identified as rerank models"""
# Create mock models: 1 rerank model and 1 LLM
mock_rerank_model = MagicMock(id="rerank-model-1")
mock_llm_model = MagicMock(id="gpt-4")
mock_models = [mock_rerank_model, mock_llm_model]
mock_client = MagicMock()
async def mock_models_list():
for model in mock_models:
yield model
mock_client.models.list.return_value = mock_models_list()
with mock_client_context(mixin_with_rerank, mock_client):
result = await mixin_with_rerank.list_models()
assert result is not None
assert len(result) == 2
expected_models = {
"rerank-model-1": {
"model_type": ModelType.rerank,
"metadata": {},
"provider_id": "test-provider",
"provider_resource_id": "rerank-model-1",
},
"gpt-4": {
"model_type": ModelType.llm,
"metadata": {},
"provider_id": "test-provider",
"provider_resource_id": "gpt-4",
},
}
_assert_models_match_expected(result, expected_models)
class TestOpenAIMixinMixedModelTypes:
"""Test cases for mixed model types (LLM, embedding, rerank)"""
async def test_mixed_model_types_identification(self, mixin_with_embeddings_and_rerank, mock_client_context):
"""Test that LLM, embedding, and rerank models are correctly identified with proper types and metadata"""
# Create mock models: 1 embedding, 1 rerank, 1 LLM
mock_embedding_model = MagicMock(id="text-embedding-3-small")
mock_rerank_model = MagicMock(id="rerank-model-1")
mock_llm_model = MagicMock(id="gpt-4")
mock_models = [mock_embedding_model, mock_rerank_model, mock_llm_model]
mock_client = MagicMock()
async def mock_models_list():
for model in mock_models:
yield model
mock_client.models.list.return_value = mock_models_list()
with mock_client_context(mixin_with_embeddings_and_rerank, mock_client):
result = await mixin_with_embeddings_and_rerank.list_models()
assert result is not None
assert len(result) == 3
expected_models = {
"text-embedding-3-small": {
"model_type": ModelType.embedding,
"metadata": {"embedding_dimension": 1536, "context_length": 8192},
"provider_id": "test-provider",
"provider_resource_id": "text-embedding-3-small",
},
"rerank-model-1": {
"model_type": ModelType.rerank,
"metadata": {},
"provider_id": "test-provider",
"provider_resource_id": "rerank-model-1",
},
"gpt-4": {
"model_type": ModelType.llm,
"metadata": {},
"provider_id": "test-provider",
"provider_resource_id": "gpt-4",
},
}
_assert_models_match_expected(result, expected_models)
class TestOpenAIMixinAllowedModels: