diff --git a/docs/docs/providers/inference/index.mdx b/docs/docs/providers/inference/index.mdx index ebbaf1be1..63741f202 100644 --- a/docs/docs/providers/inference/index.mdx +++ b/docs/docs/providers/inference/index.mdx @@ -1,9 +1,10 @@ --- description: "Llama Stack Inference API for generating completions, chat completions, and embeddings. - This API provides the raw interface to the underlying models. Two kinds of models are supported: + This API provides the raw interface to the underlying models. Three kinds of models are supported: - LLM models: these models generate \"raw\" and \"chat\" (conversational) completions. - - Embedding models: these models generate embeddings to be used for semantic search." + - Embedding models: these models generate embeddings to be used for semantic search. + - Rerank models: these models reorder the documents based on their relevance to a query." sidebar_label: Inference title: Inference --- @@ -14,8 +15,9 @@ title: Inference Llama Stack Inference API for generating completions, chat completions, and embeddings. - This API provides the raw interface to the underlying models. Two kinds of models are supported: + This API provides the raw interface to the underlying models. Three kinds of models are supported: - LLM models: these models generate "raw" and "chat" (conversational) completions. - Embedding models: these models generate embeddings to be used for semantic search. + - Rerank models: these models reorder the documents based on their relevance to a query. This section contains documentation for all available providers for the **inference** API. diff --git a/docs/static/deprecated-llama-stack-spec.html b/docs/static/deprecated-llama-stack-spec.html index 7edfe3f5d..f0dd903a6 100644 --- a/docs/static/deprecated-llama-stack-spec.html +++ b/docs/static/deprecated-llama-stack-spec.html @@ -13335,7 +13335,7 @@ }, { "name": "Inference", - "description": "This API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.", + "description": "This API provides the raw interface to the underlying models. Three kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.\n- Rerank models: these models reorder the documents based on their relevance to a query.", "x-displayName": "Llama Stack Inference API for generating completions, chat completions, and embeddings." }, { diff --git a/docs/static/deprecated-llama-stack-spec.yaml b/docs/static/deprecated-llama-stack-spec.yaml index ca832d46b..48863025f 100644 --- a/docs/static/deprecated-llama-stack-spec.yaml +++ b/docs/static/deprecated-llama-stack-spec.yaml @@ -9990,13 +9990,16 @@ tags: description: '' - name: Inference description: >- - This API provides the raw interface to the underlying models. Two kinds of models - are supported: + This API provides the raw interface to the underlying models. Three kinds of + models are supported: - LLM models: these models generate "raw" and "chat" (conversational) completions. - Embedding models: these models generate embeddings to be used for semantic search. + + - Rerank models: these models reorder the documents based on their relevance + to a query. x-displayName: >- Llama Stack Inference API for generating completions, chat completions, and embeddings. diff --git a/docs/static/experimental-llama-stack-spec.html b/docs/static/experimental-llama-stack-spec.html index a84226c05..574107a6d 100644 --- a/docs/static/experimental-llama-stack-spec.html +++ b/docs/static/experimental-llama-stack-spec.html @@ -4992,7 +4992,7 @@ "properties": { "model": { "type": "string", - "description": "The identifier of the reranking model to use." + "description": "The identifier of the reranking model to use. The model must be a reranking model registered with Llama Stack and available via the /models endpoint." }, "query": { "oneOf": [ diff --git a/docs/static/experimental-llama-stack-spec.yaml b/docs/static/experimental-llama-stack-spec.yaml index a08c0cc87..aae356d6d 100644 --- a/docs/static/experimental-llama-stack-spec.yaml +++ b/docs/static/experimental-llama-stack-spec.yaml @@ -3657,7 +3657,8 @@ components: model: type: string description: >- - The identifier of the reranking model to use. + The identifier of the reranking model to use. The model must be a reranking + model registered with Llama Stack and available via the /models endpoint. query: oneOf: - type: string diff --git a/docs/static/llama-stack-spec.html b/docs/static/llama-stack-spec.html index 96e97035f..2ee665123 100644 --- a/docs/static/llama-stack-spec.html +++ b/docs/static/llama-stack-spec.html @@ -6829,7 +6829,8 @@ "type": "string", "enum": [ "llm", - "embedding" + "embedding", + "rerank" ], "title": "ModelType", "description": "Enumeration of supported model types in Llama Stack." @@ -12883,7 +12884,7 @@ }, { "name": "Inference", - "description": "This API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.", + "description": "This API provides the raw interface to the underlying models. Three kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.\n- Rerank models: these models reorder the documents based on their relevance to a query.", "x-displayName": "Llama Stack Inference API for generating completions, chat completions, and embeddings." }, { diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml index b9e03d614..566ac7de9 100644 --- a/docs/static/llama-stack-spec.yaml +++ b/docs/static/llama-stack-spec.yaml @@ -5158,6 +5158,7 @@ components: enum: - llm - embedding + - rerank title: ModelType description: >- Enumeration of supported model types in Llama Stack. @@ -9728,13 +9729,16 @@ tags: description: '' - name: Inference description: >- - This API provides the raw interface to the underlying models. Two kinds of models - are supported: + This API provides the raw interface to the underlying models. Three kinds of + models are supported: - LLM models: these models generate "raw" and "chat" (conversational) completions. - Embedding models: these models generate embeddings to be used for semantic search. + + - Rerank models: these models reorder the documents based on their relevance + to a query. x-displayName: >- Llama Stack Inference API for generating completions, chat completions, and embeddings. diff --git a/docs/static/stainless-llama-stack-spec.html b/docs/static/stainless-llama-stack-spec.html index 7ec48ef74..6bc67536d 100644 --- a/docs/static/stainless-llama-stack-spec.html +++ b/docs/static/stainless-llama-stack-spec.html @@ -8838,7 +8838,8 @@ "type": "string", "enum": [ "llm", - "embedding" + "embedding", + "rerank" ], "title": "ModelType", "description": "Enumeration of supported model types in Llama Stack." @@ -17033,7 +17034,7 @@ "properties": { "model": { "type": "string", - "description": "The identifier of the reranking model to use." + "description": "The identifier of the reranking model to use. The model must be a reranking model registered with Llama Stack and available via the /models endpoint." }, "query": { "oneOf": [ @@ -18456,7 +18457,7 @@ }, { "name": "Inference", - "description": "This API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.", + "description": "This API provides the raw interface to the underlying models. Three kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.\n- Rerank models: these models reorder the documents based on their relevance to a query.", "x-displayName": "Llama Stack Inference API for generating completions, chat completions, and embeddings." }, { diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index 3bede159b..8fc70a5cd 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -6603,6 +6603,7 @@ components: enum: - llm - embedding + - rerank title: ModelType description: >- Enumeration of supported model types in Llama Stack. @@ -12693,7 +12694,8 @@ components: model: type: string description: >- - The identifier of the reranking model to use. + The identifier of the reranking model to use. The model must be a reranking + model registered with Llama Stack and available via the /models endpoint. query: oneOf: - type: string @@ -13774,13 +13776,16 @@ tags: description: '' - name: Inference description: >- - This API provides the raw interface to the underlying models. Two kinds of models - are supported: + This API provides the raw interface to the underlying models. Three kinds of + models are supported: - LLM models: these models generate "raw" and "chat" (conversational) completions. - Embedding models: these models generate embeddings to be used for semantic search. + + - Rerank models: these models reorder the documents based on their relevance + to a query. x-displayName: >- Llama Stack Inference API for generating completions, chat completions, and embeddings. diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index e88a16315..6260ba552 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -1016,7 +1016,7 @@ class InferenceProvider(Protocol): ) -> RerankResponse: """Rerank a list of documents based on their relevance to a query. - :param model: The identifier of the reranking model to use. + :param model: The identifier of the reranking model to use. The model must be a reranking model registered with Llama Stack and available via the /models endpoint. :param query: The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length. :param items: List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length. :param max_num_results: (Optional) Maximum number of results to return. Default: returns all. @@ -1159,9 +1159,10 @@ class InferenceProvider(Protocol): class Inference(InferenceProvider): """Llama Stack Inference API for generating completions, chat completions, and embeddings. - This API provides the raw interface to the underlying models. Two kinds of models are supported: + This API provides the raw interface to the underlying models. Three kinds of models are supported: - LLM models: these models generate "raw" and "chat" (conversational) completions. - Embedding models: these models generate embeddings to be used for semantic search. + - Rerank models: these models reorder the documents based on their relevance to a query. """ @webmethod(route="/openai/v1/chat/completions", method="GET", level=LLAMA_STACK_API_V1, deprecated=True) diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index 210ed9246..1275e90e3 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -27,10 +27,12 @@ class ModelType(StrEnum): """Enumeration of supported model types in Llama Stack. :cvar llm: Large language model for text generation and completion :cvar embedding: Embedding model for converting text to vector representations + :cvar rerank: Reranking model for reordering documents based on their relevance to a query """ llm = "llm" embedding = "embedding" + rerank = "rerank" @json_schema_type diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py index c4338e614..fcc16332f 100644 --- a/llama_stack/core/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -41,9 +41,14 @@ from llama_stack.apis.inference import ( OpenAIMessageParam, OpenAIResponseFormatParam, Order, + RerankResponse, StopReason, ToolPromptFormat, ) +from llama_stack.apis.inference.inference import ( + OpenAIChatCompletionContentPartImageParam, + OpenAIChatCompletionContentPartTextParam, +) from llama_stack.apis.models import Model, ModelType from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry from llama_stack.log import get_logger @@ -179,6 +184,23 @@ class InferenceRouter(Inference): raise ModelTypeError(model_id, model.model_type, expected_model_type) return model + async def rerank( + self, + model: str, + query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam, + items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam], + max_num_results: int | None = None, + ) -> RerankResponse: + logger.debug(f"InferenceRouter.rerank: {model}") + model_obj = await self._get_model(model, ModelType.rerank) + provider = await self.routing_table.get_provider_impl(model_obj.identifier) + return await provider.rerank( + model=model_obj.identifier, + query=query, + items=items, + max_num_results=max_num_results, + ) + async def openai_completion( self, model: str, diff --git a/llama_stack/providers/remote/inference/nvidia/NVIDIA.md b/llama_stack/providers/remote/inference/nvidia/NVIDIA.md index 625be6088..dcc9d3909 100644 --- a/llama_stack/providers/remote/inference/nvidia/NVIDIA.md +++ b/llama_stack/providers/remote/inference/nvidia/NVIDIA.md @@ -188,3 +188,22 @@ vlm_response = client.chat.completions.create( print(f"VLM Response: {vlm_response.choices[0].message.content}") ``` + +### Rerank Example + +The following example shows how to rerank documents using an NVIDIA NIM. + +```python +rerank_response = client.inference.rerank( + model="nvidia/llama-3.2-nv-rerankqa-1b-v2", + query="query", + items=[ + "item_1", + "item_2", + "item_3", + ], +) + +for i, result in enumerate(rerank_response): + print(f"{i+1}. [Index: {result.index}, " f"Score: {(result.relevance_score):.3f}]") +``` \ No newline at end of file diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 2e6c3d769..15e50ff97 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -5,6 +5,7 @@ # the root directory of this source tree. +import aiohttp from openai import NOT_GIVEN from llama_stack.apis.inference import ( @@ -12,7 +13,14 @@ from llama_stack.apis.inference import ( OpenAIEmbeddingData, OpenAIEmbeddingsResponse, OpenAIEmbeddingUsage, + RerankData, + RerankResponse, ) +from llama_stack.apis.inference.inference import ( + OpenAIChatCompletionContentPartImageParam, + OpenAIChatCompletionContentPartTextParam, +) +from llama_stack.apis.models import Model, ModelType from llama_stack.log import get_logger from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin @@ -44,6 +52,18 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference): "snowflake/arctic-embed-l": {"embedding_dimension": 512, "context_length": 1024}, } + rerank_model_list = [ + "nv-rerank-qa-mistral-4b:1", + "nvidia/nv-rerankqa-mistral-4b-v3", + "nvidia/llama-3.2-nv-rerankqa-1b-v2", + ] + + _rerank_model_endpoints = { + "nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking", + "nvidia/nv-rerankqa-mistral-4b-v3": "https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking", + "nvidia/llama-3.2-nv-rerankqa-1b-v2": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking", + } + def __init__(self, config: NVIDIAConfig) -> None: logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...") @@ -62,6 +82,8 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference): # "Consider removing the api_key from the configuration." # ) + super().__init__() + self._config = config def get_api_key(self) -> str: @@ -80,6 +102,103 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference): """ return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url + async def list_models(self) -> list[Model] | None: + """ + List available NVIDIA models by combining: + 1. Dynamic models from https://integrate.api.nvidia.com/v1/models + 2. Static rerank models (which use different API endpoints) + """ + self._model_cache = {} + models = await super().list_models() + + # Add rerank models + existing_ids = {m.identifier for m in models} + for model_id, _ in self._rerank_model_endpoints.items(): + if self.allowed_models and model_id not in self.allowed_models: + continue + if model_id not in existing_ids: + model = Model( + provider_id=self.__provider_id__, # type: ignore[attr-defined] + provider_resource_id=model_id, + identifier=model_id, + model_type=ModelType.rerank, + ) + models.append(model) + self._model_cache[model_id] = model + + return models + + async def rerank( + self, + model: str, + query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam, + items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam], + max_num_results: int | None = None, + ) -> RerankResponse: + provider_model_id = await self._get_provider_model_id(model) + + ranking_url = self.get_base_url() + + if _is_nvidia_hosted(self._config) and provider_model_id in self._rerank_model_endpoints: + ranking_url = self._rerank_model_endpoints[provider_model_id] + + logger.debug(f"Using rerank endpoint: {ranking_url} for model: {provider_model_id}") + + # Convert query to text format + if isinstance(query, str): + query_text = query + elif isinstance(query, OpenAIChatCompletionContentPartTextParam): + query_text = query.text + else: + raise ValueError("Query must be a string or text content part") + + # Convert items to text format + passages = [] + for item in items: + if isinstance(item, str): + passages.append({"text": item}) + elif isinstance(item, OpenAIChatCompletionContentPartTextParam): + passages.append({"text": item.text}) + else: + raise ValueError("Items must be strings or text content parts") + + payload = { + "model": provider_model_id, + "query": {"text": query_text}, + "passages": passages, + } + + headers = { + "Authorization": f"Bearer {self.get_api_key()}", + "Content-Type": "application/json", + } + + try: + async with aiohttp.ClientSession() as session: + async with session.post(ranking_url, headers=headers, json=payload) as response: + if response.status != 200: + response_text = await response.text() + raise ConnectionError( + f"NVIDIA rerank API request failed with status {response.status}: {response_text}" + ) + + result = await response.json() + rankings = result.get("rankings", []) + + # Convert to RerankData format + rerank_data = [] + for ranking in rankings: + rerank_data.append(RerankData(index=ranking["index"], relevance_score=ranking["logit"])) + + # Apply max_num_results limit + if max_num_results is not None: + rerank_data = rerank_data[:max_num_results] + + return RerankResponse(data=rerank_data) + + except aiohttp.ClientError as e: + raise ConnectionError(f"Failed to connect to NVIDIA rerank API at {ranking_url}: {e}") from e + async def openai_embeddings( self, model: str, diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py index 4354b067e..da56374c5 100644 --- a/llama_stack/providers/utils/inference/openai_mixin.py +++ b/llama_stack/providers/utils/inference/openai_mixin.py @@ -63,6 +63,10 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC): # Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}} embedding_model_metadata: dict[str, dict[str, int]] = {} + # List of rerank model IDs for this provider + # Can be set by subclasses or instances to provide rerank models + rerank_model_list: list[str] = [] + # Cache of available models keyed by model ID # This is set in list_models() and used in check_model_availability() _model_cache: dict[str, Model] = {} @@ -400,6 +404,14 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC): model_type=ModelType.embedding, metadata=metadata, ) + elif m.id in self.rerank_model_list: + # This is a rerank model + model = Model( + provider_id=self.__provider_id__, # type: ignore[attr-defined] + provider_resource_id=m.id, + identifier=m.id, + model_type=ModelType.rerank, + ) else: # This is an LLM model = Model( diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 4735264c3..2ad4f7e4c 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -120,6 +120,10 @@ def pytest_addoption(parser): "--embedding-model", help="comma-separated list of embedding models. Fixture name: embedding_model_id", ) + parser.addoption( + "--rerank-model", + help="comma-separated list of rerank models. Fixture name: rerank_model_id", + ) parser.addoption( "--safety-shield", help="comma-separated list of safety shields. Fixture name: shield_id", @@ -198,6 +202,7 @@ def pytest_generate_tests(metafunc): "shield_id": ("--safety-shield", "shield"), "judge_model_id": ("--judge-model", "judge"), "embedding_dimension": ("--embedding-dimension", "dim"), + "rerank_model_id": ("--rerank-model", "rerank"), } # Collect all parameters and their values diff --git a/tests/integration/fixtures/common.py b/tests/integration/fixtures/common.py index 68aa2b60b..dfbcf476d 100644 --- a/tests/integration/fixtures/common.py +++ b/tests/integration/fixtures/common.py @@ -119,6 +119,7 @@ def client_with_models( embedding_model_id, embedding_dimension, judge_model_id, + rerank_model_id, ): client = llama_stack_client @@ -151,6 +152,20 @@ def client_with_models( model_type="embedding", metadata={"embedding_dimension": embedding_dimension or 384}, ) + if rerank_model_id and rerank_model_id not in model_ids: + selected_provider = None + for p in providers: + # Currently only NVIDIA inference provider supports reranking + if p.provider_type == "remote::nvidia": + selected_provider = p + break + + selected_provider = selected_provider or providers[0] + client.models.register( + model_id=rerank_model_id, + provider_id=selected_provider.provider_id, + model_type="rerank", + ) return client @@ -166,7 +181,14 @@ def model_providers(llama_stack_client): @pytest.fixture(autouse=True) def skip_if_no_model(request): - model_fixtures = ["text_model_id", "vision_model_id", "embedding_model_id", "judge_model_id", "shield_id"] + model_fixtures = [ + "text_model_id", + "vision_model_id", + "embedding_model_id", + "judge_model_id", + "shield_id", + "rerank_model_id", + ] test_func = request.node.function actual_params = inspect.signature(test_func).parameters.keys() diff --git a/tests/integration/inference/test_rerank.py b/tests/integration/inference/test_rerank.py new file mode 100644 index 000000000..82f35cd27 --- /dev/null +++ b/tests/integration/inference/test_rerank.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest +from llama_stack_client import BadRequestError as LlamaStackBadRequestError +from llama_stack_client.types.alpha import InferenceRerankResponse +from llama_stack_client.types.shared.interleaved_content import ( + ImageContentItem, + ImageContentItemImage, + ImageContentItemImageURL, + TextContentItem, +) + +from llama_stack.core.library_client import LlamaStackAsLibraryClient + +# Test data +DUMMY_STRING = "string_1" +DUMMY_STRING2 = "string_2" +DUMMY_TEXT = TextContentItem(text=DUMMY_STRING, type="text") +DUMMY_TEXT2 = TextContentItem(text=DUMMY_STRING2, type="text") +DUMMY_IMAGE_URL = ImageContentItem( + image=ImageContentItemImage(url=ImageContentItemImageURL(uri="https://example.com/image.jpg")), type="image" +) +DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image") + +PROVIDERS_SUPPORTING_MEDIA = {} # Providers that support media input for rerank models + + +def skip_if_provider_doesnt_support_rerank(inference_provider_type): + supported_providers = {"remote::nvidia"} + if inference_provider_type not in supported_providers: + pytest.skip(f"{inference_provider_type} doesn't support rerank models") + + +def _validate_rerank_response(response: InferenceRerankResponse, items: list) -> None: + """ + Validate that a rerank response has the correct structure and ordering. + + Args: + response: The InferenceRerankResponse to validate + items: The original items list that was ranked + + Raises: + AssertionError: If any validation fails + """ + seen = set() + last_score = float("inf") + for d in response: + assert 0 <= d.index < len(items), f"Index {d.index} out of bounds for {len(items)} items" + assert d.index not in seen, f"Duplicate index {d.index} found" + seen.add(d.index) + assert isinstance(d.relevance_score, float), f"Score must be float, got {type(d.relevance_score)}" + assert d.relevance_score <= last_score, f"Scores not in descending order: {d.relevance_score} > {last_score}" + last_score = d.relevance_score + + +def _validate_semantic_ranking(response: InferenceRerankResponse, items: list, expected_first_item: str) -> None: + """ + Validate that the expected most relevant item ranks first. + + Args: + response: The InferenceRerankResponse to validate + items: The original items list that was ranked + expected_first_item: The expected first item in the ranking + + Raises: + AssertionError: If any validation fails + """ + if not response: + raise AssertionError("No ranking data returned in response") + + actual_first_index = response[0].index + actual_first_item = items[actual_first_index] + assert actual_first_item == expected_first_item, ( + f"Expected '{expected_first_item}' to rank first, but '{actual_first_item}' ranked first instead." + ) + + +@pytest.mark.parametrize( + "query,items", + [ + (DUMMY_STRING, [DUMMY_STRING, DUMMY_STRING2]), + (DUMMY_TEXT, [DUMMY_TEXT, DUMMY_TEXT2]), + (DUMMY_STRING, [DUMMY_STRING2, DUMMY_TEXT]), + (DUMMY_TEXT, [DUMMY_STRING, DUMMY_TEXT2]), + ], + ids=[ + "string-query-string-items", + "text-query-text-items", + "mixed-content-1", + "mixed-content-2", + ], +) +def test_rerank_text(client_with_models, rerank_model_id, query, items, inference_provider_type): + skip_if_provider_doesnt_support_rerank(inference_provider_type) + + response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items) + assert isinstance(response, list) + # TODO: Add type validation for response items once InferenceRerankResponseItem is exported from llama stack client. + assert len(response) <= len(items) + _validate_rerank_response(response, items) + + +@pytest.mark.parametrize( + "query,items", + [ + (DUMMY_IMAGE_URL, [DUMMY_STRING]), + (DUMMY_IMAGE_BASE64, [DUMMY_TEXT]), + (DUMMY_TEXT, [DUMMY_IMAGE_URL]), + (DUMMY_IMAGE_BASE64, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]), + (DUMMY_TEXT, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]), + ], + ids=[ + "image-query-url", + "image-query-base64", + "text-query-image-item", + "mixed-content-1", + "mixed-content-2", + ], +) +def test_rerank_image(client_with_models, rerank_model_id, query, items, inference_provider_type): + skip_if_provider_doesnt_support_rerank(inference_provider_type) + + if rerank_model_id not in PROVIDERS_SUPPORTING_MEDIA: + error_type = ( + ValueError if isinstance(client_with_models, LlamaStackAsLibraryClient) else LlamaStackBadRequestError + ) + with pytest.raises(error_type): + client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items) + else: + response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items) + + assert isinstance(response, list) + assert len(response) <= len(items) + _validate_rerank_response(response, items) + + +def test_rerank_max_results(client_with_models, rerank_model_id, inference_provider_type): + skip_if_provider_doesnt_support_rerank(inference_provider_type) + + items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2] + max_num_results = 2 + + response = client_with_models.alpha.inference.rerank( + model=rerank_model_id, + query=DUMMY_STRING, + items=items, + max_num_results=max_num_results, + ) + + assert isinstance(response, list) + assert len(response) == max_num_results + _validate_rerank_response(response, items) + + +def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_id, inference_provider_type): + skip_if_provider_doesnt_support_rerank(inference_provider_type) + + items = [DUMMY_STRING, DUMMY_STRING2] + response = client_with_models.alpha.inference.rerank( + model=rerank_model_id, + query=DUMMY_STRING, + items=items, + max_num_results=10, # Larger than items length + ) + + assert isinstance(response, list) + assert len(response) <= len(items) # Should return at most len(items) + + +@pytest.mark.parametrize( + "query,items,expected_first_item", + [ + ( + "What is a reranking model? ", + [ + "A reranking model reranks a list of items based on the query. ", + "Machine learning algorithms learn patterns from data. ", + "Python is a programming language. ", + ], + "A reranking model reranks a list of items based on the query. ", + ), + ( + "What is C++?", + [ + "Learning new things is interesting. ", + "C++ is a programming language. ", + "Books provide knowledge and entertainment. ", + ], + "C++ is a programming language. ", + ), + ( + "What are good learning habits? ", + [ + "Cooking pasta is a fun activity. ", + "Plants need water and sunlight. ", + "Good learning habits include reading daily and taking notes. ", + ], + "Good learning habits include reading daily and taking notes. ", + ), + ], +) +def test_rerank_semantic_correctness( + client_with_models, rerank_model_id, query, items, expected_first_item, inference_provider_type +): + skip_if_provider_doesnt_support_rerank(inference_provider_type) + + response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items) + + _validate_rerank_response(response, items) + _validate_semantic_ranking(response, items, expected_first_item) diff --git a/tests/unit/providers/nvidia/test_rerank_inference.py b/tests/unit/providers/nvidia/test_rerank_inference.py new file mode 100644 index 000000000..60891e496 --- /dev/null +++ b/tests/unit/providers/nvidia/test_rerank_inference.py @@ -0,0 +1,222 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from unittest.mock import AsyncMock, MagicMock, patch + +import aiohttp +import pytest + +from llama_stack.apis.models import ModelType +from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig +from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter + + +class MockResponse: + def __init__(self, status=200, json_data=None, text_data="OK"): + self.status = status + self._json_data = json_data or {"rankings": []} + self._text_data = text_data + + async def json(self): + return self._json_data + + async def text(self): + return self._text_data + + +class MockSession: + def __init__(self, response): + self.response = response + self.post_calls = [] + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return False + + def post(self, url, **kwargs): + self.post_calls.append((url, kwargs)) + + class PostContext: + def __init__(self, response): + self.response = response + + async def __aenter__(self): + return self.response + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return False + + return PostContext(self.response) + + +def create_adapter(config=None, rerank_endpoints=None): + if config is None: + config = NVIDIAConfig(api_key="test-key") + + adapter = NVIDIAInferenceAdapter(config) + + class MockModel: + provider_resource_id = "test-model" + metadata = {} + + adapter.model_store = AsyncMock() + adapter.model_store.get_model = AsyncMock(return_value=MockModel()) + + if rerank_endpoints is not None: + adapter._rerank_model_endpoints = rerank_endpoints + + return adapter + + +async def test_rerank_basic_functionality(): + adapter = create_adapter() + mock_response = MockResponse(json_data={"rankings": [{"index": 0, "logit": 0.5}]}) + mock_session = MockSession(mock_response) + + with patch("aiohttp.ClientSession", return_value=mock_session): + result = await adapter.rerank(model="test-model", query="test query", items=["item1", "item2"]) + + assert len(result.data) == 1 + assert result.data[0].index == 0 + assert result.data[0].relevance_score == 0.5 + + url, kwargs = mock_session.post_calls[0] + payload = kwargs["json"] + assert payload["model"] == "test-model" + assert payload["query"] == {"text": "test query"} + assert payload["passages"] == [{"text": "item1"}, {"text": "item2"}] + + +async def test_missing_rankings_key(): + adapter = create_adapter() + mock_session = MockSession(MockResponse(json_data={})) + + with patch("aiohttp.ClientSession", return_value=mock_session): + result = await adapter.rerank(model="test-model", query="q", items=["a"]) + + assert len(result.data) == 0 + + +async def test_hosted_with_endpoint(): + adapter = create_adapter( + config=NVIDIAConfig(api_key="key"), rerank_endpoints={"test-model": "https://model.endpoint/rerank"} + ) + mock_session = MockSession(MockResponse()) + + with patch("aiohttp.ClientSession", return_value=mock_session): + await adapter.rerank(model="test-model", query="q", items=["a"]) + + url, _ = mock_session.post_calls[0] + assert url == "https://model.endpoint/rerank" + + +async def test_hosted_without_endpoint(): + adapter = create_adapter( + config=NVIDIAConfig(api_key="key"), # This creates hosted config (integrate.api.nvidia.com). + rerank_endpoints={}, # No endpoint mapping for test-model + ) + mock_session = MockSession(MockResponse()) + + with patch("aiohttp.ClientSession", return_value=mock_session): + await adapter.rerank(model="test-model", query="q", items=["a"]) + + url, _ = mock_session.post_calls[0] + assert "https://integrate.api.nvidia.com" in url + + +async def test_hosted_model_not_in_endpoint_mapping(): + adapter = create_adapter( + config=NVIDIAConfig(api_key="key"), rerank_endpoints={"other-model": "https://other.endpoint/rerank"} + ) + mock_session = MockSession(MockResponse()) + + with patch("aiohttp.ClientSession", return_value=mock_session): + await adapter.rerank(model="test-model", query="q", items=["a"]) + + url, _ = mock_session.post_calls[0] + assert "https://integrate.api.nvidia.com" in url + assert url != "https://other.endpoint/rerank" + + +async def test_self_hosted_ignores_endpoint(): + adapter = create_adapter( + config=NVIDIAConfig(url="http://localhost:8000", api_key=None), + rerank_endpoints={"test-model": "https://model.endpoint/rerank"}, # This should be ignored for self-hosted. + ) + mock_session = MockSession(MockResponse()) + + with patch("aiohttp.ClientSession", return_value=mock_session): + await adapter.rerank(model="test-model", query="q", items=["a"]) + + url, _ = mock_session.post_calls[0] + assert "http://localhost:8000" in url + assert "model.endpoint/rerank" not in url + + +async def test_max_num_results(): + adapter = create_adapter() + rankings = [{"index": 0, "logit": 0.8}, {"index": 1, "logit": 0.6}] + mock_session = MockSession(MockResponse(json_data={"rankings": rankings})) + + with patch("aiohttp.ClientSession", return_value=mock_session): + result = await adapter.rerank(model="test-model", query="q", items=["a", "b"], max_num_results=1) + + assert len(result.data) == 1 + assert result.data[0].index == 0 + assert result.data[0].relevance_score == 0.8 + + +async def test_http_error(): + adapter = create_adapter() + mock_session = MockSession(MockResponse(status=500, text_data="Server Error")) + + with patch("aiohttp.ClientSession", return_value=mock_session): + with pytest.raises(ConnectionError, match="status 500.*Server Error"): + await adapter.rerank(model="test-model", query="q", items=["a"]) + + +async def test_client_error(): + adapter = create_adapter() + mock_session = AsyncMock() + mock_session.__aenter__.side_effect = aiohttp.ClientError("Network error") + + with patch("aiohttp.ClientSession", return_value=mock_session): + with pytest.raises(ConnectionError, match="Failed to connect.*Network error"): + await adapter.rerank(model="test-model", query="q", items=["a"]) + + +async def test_list_models_adds_rerank_models(): + """Test that list_models adds rerank models to the dynamic model list.""" + adapter = create_adapter() + adapter.__provider_id__ = "nvidia" + + # Mock the list_models from the superclass to return some dynamic models + base_models = [ + MagicMock(identifier="llm-1", model_type=ModelType.llm), + MagicMock(identifier="embedding-1", model_type=ModelType.embedding), + ] + + with patch.object(NVIDIAInferenceAdapter.__bases__[0], "list_models", return_value=base_models): + result = await adapter.list_models() + + assert result is not None + + # Check that the rerank models are added + model_ids = [m.identifier for m in result] + assert "nv-rerank-qa-mistral-4b:1" in model_ids + assert "nvidia/nv-rerankqa-mistral-4b-v3" in model_ids + assert "nvidia/llama-3.2-nv-rerankqa-1b-v2" in model_ids + + rerank_models = [m for m in result if m.model_type == ModelType.rerank] + + assert len(rerank_models) == 3 + + for rerank_model in rerank_models: + assert rerank_model.provider_id == "nvidia" + assert rerank_model.metadata == {} + assert rerank_model.identifier in adapter._model_cache diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py index 4856f510b..937caa1c0 100644 --- a/tests/unit/providers/utils/inference/test_openai_mixin.py +++ b/tests/unit/providers/utils/inference/test_openai_mixin.py @@ -35,6 +35,40 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl): } +class OpenAIMixinWithRerankImpl(OpenAIMixin): + """Test implementation with rerank model list""" + + rerank_model_list = ["rerank-model-1", "rerank-model-2"] + + def __init__(self): + self.__provider_id__ = "test-provider" + + def get_api_key(self) -> str: + raise NotImplementedError("This method should be mocked in tests") + + def get_base_url(self) -> str: + raise NotImplementedError("This method should be mocked in tests") + + +class OpenAIMixinWithEmbeddingsAndRerankImpl(OpenAIMixin): + """Test implementation with both embedding model metadata and rerank model list""" + + embedding_model_metadata = { + "text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192}, + "text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192}, + } + + rerank_model_list = ["rerank-model-1", "rerank-model-2"] + + __provider_id__ = "test-provider" + + def get_api_key(self) -> str: + raise NotImplementedError("This method should be mocked in tests") + + def get_base_url(self) -> str: + raise NotImplementedError("This method should be mocked in tests") + + @pytest.fixture def mixin(): """Create a test instance of OpenAIMixin with mocked model_store""" @@ -56,6 +90,18 @@ def mixin_with_embeddings(): return OpenAIMixinWithEmbeddingsImpl() +@pytest.fixture +def mixin_with_rerank(): + """Create a test instance of OpenAIMixin with rerank model list""" + return OpenAIMixinWithRerankImpl() + + +@pytest.fixture +def mixin_with_embeddings_and_rerank(): + """Create a test instance of OpenAIMixin with both embedding model metadata and rerank model list""" + return OpenAIMixinWithEmbeddingsAndRerankImpl() + + @pytest.fixture def mock_models(): """Create multiple mock OpenAI model objects""" @@ -107,6 +153,19 @@ def mock_client_context(): return _mock_client_context +def _assert_models_match_expected(actual_models, expected_models): + """Verify the models match expected attributes. + + Args: + actual_models: List of models to verify + expected_models: Mapping of model identifier to expected attribute values + """ + for identifier, expected_attrs in expected_models.items(): + model = next(m for m in actual_models if m.identifier == identifier) + for attr_name, expected_value in expected_attrs.items(): + assert getattr(model, attr_name) == expected_value + + class TestOpenAIMixinListModels: """Test cases for the list_models method""" @@ -300,21 +359,113 @@ class TestOpenAIMixinEmbeddingModelMetadata: assert result is not None assert len(result) == 2 - # Find the models in the result - embedding_model = next(m for m in result if m.identifier == "text-embedding-3-small") - llm_model = next(m for m in result if m.identifier == "gpt-4") + expected_models = { + "text-embedding-3-small": { + "model_type": ModelType.embedding, + "metadata": {"embedding_dimension": 1536, "context_length": 8192}, + "provider_id": "test-provider", + "provider_resource_id": "text-embedding-3-small", + }, + "gpt-4": { + "model_type": ModelType.llm, + "metadata": {}, + "provider_id": "test-provider", + "provider_resource_id": "gpt-4", + }, + } - # Check embedding model - assert embedding_model.model_type == ModelType.embedding - assert embedding_model.metadata == {"embedding_dimension": 1536, "context_length": 8192} - assert embedding_model.provider_id == "test-provider" - assert embedding_model.provider_resource_id == "text-embedding-3-small" + _assert_models_match_expected(result, expected_models) - # Check LLM model - assert llm_model.model_type == ModelType.llm - assert llm_model.metadata == {} # No metadata for LLMs - assert llm_model.provider_id == "test-provider" - assert llm_model.provider_resource_id == "gpt-4" + +class TestOpenAIMixinRerankModelList: + """Test cases for rerank_model_list attribute functionality""" + + async def test_rerank_model_identified(self, mixin_with_rerank, mock_client_context): + """Test that models in rerank_model_list are correctly identified as rerank models""" + # Create mock models: 1 rerank model and 1 LLM + mock_rerank_model = MagicMock(id="rerank-model-1") + mock_llm_model = MagicMock(id="gpt-4") + mock_models = [mock_rerank_model, mock_llm_model] + + mock_client = MagicMock() + + async def mock_models_list(): + for model in mock_models: + yield model + + mock_client.models.list.return_value = mock_models_list() + + with mock_client_context(mixin_with_rerank, mock_client): + result = await mixin_with_rerank.list_models() + + assert result is not None + assert len(result) == 2 + + expected_models = { + "rerank-model-1": { + "model_type": ModelType.rerank, + "metadata": {}, + "provider_id": "test-provider", + "provider_resource_id": "rerank-model-1", + }, + "gpt-4": { + "model_type": ModelType.llm, + "metadata": {}, + "provider_id": "test-provider", + "provider_resource_id": "gpt-4", + }, + } + + _assert_models_match_expected(result, expected_models) + + +class TestOpenAIMixinMixedModelTypes: + """Test cases for mixed model types (LLM, embedding, rerank)""" + + async def test_mixed_model_types_identification(self, mixin_with_embeddings_and_rerank, mock_client_context): + """Test that LLM, embedding, and rerank models are correctly identified with proper types and metadata""" + # Create mock models: 1 embedding, 1 rerank, 1 LLM + mock_embedding_model = MagicMock(id="text-embedding-3-small") + mock_rerank_model = MagicMock(id="rerank-model-1") + mock_llm_model = MagicMock(id="gpt-4") + mock_models = [mock_embedding_model, mock_rerank_model, mock_llm_model] + + mock_client = MagicMock() + + async def mock_models_list(): + for model in mock_models: + yield model + + mock_client.models.list.return_value = mock_models_list() + + with mock_client_context(mixin_with_embeddings_and_rerank, mock_client): + result = await mixin_with_embeddings_and_rerank.list_models() + + assert result is not None + assert len(result) == 3 + + expected_models = { + "text-embedding-3-small": { + "model_type": ModelType.embedding, + "metadata": {"embedding_dimension": 1536, "context_length": 8192}, + "provider_id": "test-provider", + "provider_resource_id": "text-embedding-3-small", + }, + "rerank-model-1": { + "model_type": ModelType.rerank, + "metadata": {}, + "provider_id": "test-provider", + "provider_resource_id": "rerank-model-1", + }, + "gpt-4": { + "model_type": ModelType.llm, + "metadata": {}, + "provider_id": "test-provider", + "provider_resource_id": "gpt-4", + }, + } + + _assert_models_match_expected(result, expected_models) class TestOpenAIMixinAllowedModels: