From 816b68fdc7cc83288a4548f3c73c6285fe5c86d9 Mon Sep 17 00:00:00 2001 From: Jiayi Date: Sun, 28 Sep 2025 14:45:16 -0700 Subject: [PATCH] Add rerank models to the dynamic model list; Fix integration tests --- docs/docs/providers/batches/index.mdx | 12 +- docs/docs/providers/inference/index.mdx | 1 + .../remote/inference/nvidia/NVIDIA.md | 4 +- .../remote/inference/nvidia/nvidia.py | 39 +++++ .../providers/utils/inference/openai_mixin.py | 12 ++ tests/integration/inference/test_rerank.py | 33 ++--- .../providers/nvidia/test_rerank_inference.py | 35 ++++- .../utils/inference/test_openai_mixin.py | 136 ++++++++++++++++++ 8 files changed, 247 insertions(+), 25 deletions(-) diff --git a/docs/docs/providers/batches/index.mdx b/docs/docs/providers/batches/index.mdx index 2c64b277f..85213ab17 100644 --- a/docs/docs/providers/batches/index.mdx +++ b/docs/docs/providers/batches/index.mdx @@ -18,14 +18,14 @@ title: Batches ## Overview The Batches API enables efficient processing of multiple requests in a single operation, - particularly useful for processing large datasets, batch evaluation workflows, and - cost-effective inference at scale. +particularly useful for processing large datasets, batch evaluation workflows, and +cost-effective inference at scale. - The API is designed to allow use of openai client libraries for seamless integration. +The API is designed to allow use of openai client libraries for seamless integration. - This API provides the following extensions: - - idempotent batch creation +This API provides the following extensions: + - idempotent batch creation - Note: This API is currently under active development and may undergo changes. +Note: This API is currently under active development and may undergo changes. This section contains documentation for all available providers for the **batches** API. diff --git a/docs/docs/providers/inference/index.mdx b/docs/docs/providers/inference/index.mdx index 98ba10cc7..065f620df 100644 --- a/docs/docs/providers/inference/index.mdx +++ b/docs/docs/providers/inference/index.mdx @@ -5,6 +5,7 @@ description: "Llama Stack Inference API for generating completions, chat complet - LLM models: these models generate \"raw\" and \"chat\" (conversational) completions. - Embedding models: these models generate embeddings to be used for semantic search. - Rerank models: these models rerank the documents by relevance." + sidebar_label: Inference title: Inference --- diff --git a/llama_stack/providers/remote/inference/nvidia/NVIDIA.md b/llama_stack/providers/remote/inference/nvidia/NVIDIA.md index c683c7a68..dcc9d3909 100644 --- a/llama_stack/providers/remote/inference/nvidia/NVIDIA.md +++ b/llama_stack/providers/remote/inference/nvidia/NVIDIA.md @@ -204,6 +204,6 @@ rerank_response = client.inference.rerank( ], ) -for i, result in enumerate(rerank_response.data): - print(f"{i+1}. [Index: {result.index}, Score: {result.relevance_score:.3f}]") +for i, result in enumerate(rerank_response): + print(f"{i+1}. [Index: {result.index}, " f"Score: {(result.relevance_score):.3f}]") ``` \ No newline at end of file diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index f629d8c19..ae9245bfe 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -20,6 +20,7 @@ from llama_stack.apis.inference.inference import ( OpenAIChatCompletionContentPartImageParam, OpenAIChatCompletionContentPartTextParam, ) +from llama_stack.apis.models import Model, ModelType from llama_stack.log import get_logger from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin @@ -51,6 +52,18 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference): "snowflake/arctic-embed-l": {"embedding_dimension": 512, "context_length": 1024}, } + rerank_model_list = [ + "nv-rerank-qa-mistral-4b:1", + "nvidia/nv-rerankqa-mistral-4b-v3", + "nvidia/llama-3.2-nv-rerankqa-1b-v2", + ] + + _rerank_model_endpoints = { + "nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking", + "nvidia/nv-rerankqa-mistral-4b-v3": "https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking", + "nvidia/llama-3.2-nv-rerankqa-1b-v2": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking", + } + def __init__(self, config: NVIDIAConfig) -> None: logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...") @@ -69,6 +82,8 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference): # "Consider removing the api_key from the configuration." # ) + super().__init__() + self._config = config def get_api_key(self) -> str: @@ -87,6 +102,30 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference): """ return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url + async def list_models(self) -> list[Model] | None: + """ + List available NVIDIA models by combining: + 1. Dynamic models from https://integrate.api.nvidia.com/v1/models + 2. Static rerank models (which use different API endpoints) + """ + models = await super().list_models() or [] + + existing_ids = {m.identifier for m in models} + for model_id, _ in self._rerank_model_endpoints.items(): + if self.allowed_models and model_id not in self.allowed_models: + continue + if model_id not in existing_ids: + model = Model( + provider_id=self.__provider_id__, # type: ignore[attr-defined] + provider_resource_id=model_id, + identifier=model_id, + model_type=ModelType.rerank, + ) + models.append(model) + self._model_cache[model_id] = model + + return models + async def rerank( self, model: str, diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py index 4354b067e..da56374c5 100644 --- a/llama_stack/providers/utils/inference/openai_mixin.py +++ b/llama_stack/providers/utils/inference/openai_mixin.py @@ -63,6 +63,10 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC): # Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}} embedding_model_metadata: dict[str, dict[str, int]] = {} + # List of rerank model IDs for this provider + # Can be set by subclasses or instances to provide rerank models + rerank_model_list: list[str] = [] + # Cache of available models keyed by model ID # This is set in list_models() and used in check_model_availability() _model_cache: dict[str, Model] = {} @@ -400,6 +404,14 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC): model_type=ModelType.embedding, metadata=metadata, ) + elif m.id in self.rerank_model_list: + # This is a rerank model + model = Model( + provider_id=self.__provider_id__, # type: ignore[attr-defined] + provider_resource_id=m.id, + identifier=m.id, + model_type=ModelType.rerank, + ) else: # This is an LLM model = Model( diff --git a/tests/integration/inference/test_rerank.py b/tests/integration/inference/test_rerank.py index f1b9311a4..ea17a54cb 100644 --- a/tests/integration/inference/test_rerank.py +++ b/tests/integration/inference/test_rerank.py @@ -6,7 +6,7 @@ import pytest from llama_stack_client import BadRequestError as LlamaStackBadRequestError -from llama_stack_client.types import RerankResponse +from llama_stack_client.types import InferenceRerankResponse from llama_stack_client.types.shared.interleaved_content import ( ImageContentItem, ImageContentItemImage, @@ -30,12 +30,12 @@ SUPPORTED_PROVIDERS = {"remote::nvidia"} PROVIDERS_SUPPORTING_MEDIA = {} # Providers that support media input for rerank models -def _validate_rerank_response(response: RerankResponse, items: list) -> None: +def _validate_rerank_response(response: InferenceRerankResponse, items: list) -> None: """ Validate that a rerank response has the correct structure and ordering. Args: - response: The RerankResponse to validate + response: The InferenceRerankResponse to validate items: The original items list that was ranked Raises: @@ -43,7 +43,7 @@ def _validate_rerank_response(response: RerankResponse, items: list) -> None: """ seen = set() last_score = float("inf") - for d in response.data: + for d in response: assert 0 <= d.index < len(items), f"Index {d.index} out of bounds for {len(items)} items" assert d.index not in seen, f"Duplicate index {d.index} found" seen.add(d.index) @@ -52,22 +52,22 @@ def _validate_rerank_response(response: RerankResponse, items: list) -> None: last_score = d.relevance_score -def _validate_semantic_ranking(response: RerankResponse, items: list, expected_first_item: str) -> None: +def _validate_semantic_ranking(response: InferenceRerankResponse, items: list, expected_first_item: str) -> None: """ Validate that the expected most relevant item ranks first. Args: - response: The RerankResponse to validate + response: The InferenceRerankResponse to validate items: The original items list that was ranked expected_first_item: The expected first item in the ranking Raises: AssertionError: If any validation fails """ - if not response.data: + if not response: raise AssertionError("No ranking data returned in response") - actual_first_index = response.data[0].index + actual_first_index = response[0].index actual_first_item = items[actual_first_index] assert actual_first_item == expected_first_item, ( f"Expected '{expected_first_item}' to rank first, but '{actual_first_item}' ranked first instead." @@ -94,8 +94,9 @@ def test_rerank_text(client_with_models, rerank_model_id, query, items, inferenc pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ") response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items) - assert isinstance(response, RerankResponse) - assert len(response.data) <= len(items) + assert isinstance(response, list) + # TODO: Add type validation for response items once InferenceRerankResponseItem is exported from llama stack client. + assert len(response) <= len(items) _validate_rerank_response(response, items) @@ -129,8 +130,8 @@ def test_rerank_image(client_with_models, rerank_model_id, query, items, inferen else: response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items) - assert isinstance(response, RerankResponse) - assert len(response.data) <= len(items) + assert isinstance(response, list) + assert len(response) <= len(items) _validate_rerank_response(response, items) @@ -148,8 +149,8 @@ def test_rerank_max_results(client_with_models, rerank_model_id, inference_provi max_num_results=max_num_results, ) - assert isinstance(response, RerankResponse) - assert len(response.data) == max_num_results + assert isinstance(response, list) + assert len(response) == max_num_results _validate_rerank_response(response, items) @@ -165,8 +166,8 @@ def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_i max_num_results=10, # Larger than items length ) - assert isinstance(response, RerankResponse) - assert len(response.data) <= len(items) # Should return at most len(items) + assert isinstance(response, list) + assert len(response) <= len(items) # Should return at most len(items) @pytest.mark.parametrize( diff --git a/tests/unit/providers/nvidia/test_rerank_inference.py b/tests/unit/providers/nvidia/test_rerank_inference.py index 687ffd502..f34518609 100644 --- a/tests/unit/providers/nvidia/test_rerank_inference.py +++ b/tests/unit/providers/nvidia/test_rerank_inference.py @@ -4,11 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import aiohttp import pytest +from llama_stack.apis.models import ModelType from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter @@ -170,3 +171,35 @@ async def test_client_error(): with patch("aiohttp.ClientSession", return_value=mock_session): with pytest.raises(ConnectionError, match="Failed to connect.*Network error"): await adapter.rerank(model="test-model", query="q", items=["a"]) + + +async def test_list_models_adds_rerank_models(): + """Test that list_models adds rerank models to the dynamic model list.""" + adapter = create_adapter() + adapter.__provider_id__ = "nvidia" + + # Mock the list_models from the superclass to return some dynamic models + base_models = [ + MagicMock(identifier="llm-1", model_type=ModelType.llm), + MagicMock(identifier="embedding-1", model_type=ModelType.embedding), + ] + + with patch.object(NVIDIAInferenceAdapter.__bases__[0], "list_models", return_value=base_models): + result = await adapter.list_models() + + assert result is not None + + # Check that the rerank models are added + model_ids = [m.identifier for m in result] + assert "nv-rerank-qa-mistral-4b:1" in model_ids + assert "nvidia/nv-rerankqa-mistral-4b-v3" in model_ids + assert "nvidia/llama-3.2-nv-rerankqa-1b-v2" in model_ids + + rerank_models = [m for m in result if m.model_type == ModelType.rerank] + + assert len(rerank_models) == 3 + + for rerank_model in rerank_models: + assert rerank_model.provider_id == "nvidia" + assert rerank_model.metadata == {} + assert rerank_model.identifier in adapter._model_cache diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py index 4856f510b..ae723dcc2 100644 --- a/tests/unit/providers/utils/inference/test_openai_mixin.py +++ b/tests/unit/providers/utils/inference/test_openai_mixin.py @@ -35,6 +35,40 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl): } +class OpenAIMixinWithRerankImpl(OpenAIMixin): + """Test implementation with rerank model list""" + + rerank_model_list = ["rerank-model-1", "rerank-model-2"] + + def __init__(self): + self.__provider_id__ = "test-provider" + + def get_api_key(self) -> str: + raise NotImplementedError("This method should be mocked in tests") + + def get_base_url(self) -> str: + raise NotImplementedError("This method should be mocked in tests") + + +class OpenAIMixinWithEmbeddingsAndRerankImpl(OpenAIMixin): + """Test implementation with both embedding model metadata and rerank model list""" + + embedding_model_metadata = { + "text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192}, + "text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192}, + } + + rerank_model_list = ["rerank-model-1", "rerank-model-2"] + + __provider_id__ = "test-provider" + + def get_api_key(self) -> str: + raise NotImplementedError("This method should be mocked in tests") + + def get_base_url(self) -> str: + raise NotImplementedError("This method should be mocked in tests") + + @pytest.fixture def mixin(): """Create a test instance of OpenAIMixin with mocked model_store""" @@ -56,6 +90,18 @@ def mixin_with_embeddings(): return OpenAIMixinWithEmbeddingsImpl() +@pytest.fixture +def mixin_with_rerank(): + """Create a test instance of OpenAIMixin with rerank model list""" + return OpenAIMixinWithRerankImpl() + + +@pytest.fixture +def mixin_with_embeddings_and_rerank(): + """Create a test instance of OpenAIMixin with both embedding model metadata and rerank model list""" + return OpenAIMixinWithEmbeddingsAndRerankImpl() + + @pytest.fixture def mock_models(): """Create multiple mock OpenAI model objects""" @@ -317,6 +363,96 @@ class TestOpenAIMixinEmbeddingModelMetadata: assert llm_model.provider_resource_id == "gpt-4" +class TestOpenAIMixinRerankModelList: + """Test cases for rerank_model_list attribute functionality""" + + async def test_rerank_model_identified(self, mixin_with_rerank, mock_client_context): + """Test that models in rerank_model_list are correctly identified as rerank models""" + # Create mock models: 1 rerank model and 1 LLM + mock_rerank_model = MagicMock(id="rerank-model-1") + mock_llm_model = MagicMock(id="gpt-4") + mock_models = [mock_rerank_model, mock_llm_model] + + mock_client = MagicMock() + + async def mock_models_list(): + for model in mock_models: + yield model + + mock_client.models.list.return_value = mock_models_list() + + with mock_client_context(mixin_with_rerank, mock_client): + result = await mixin_with_rerank.list_models() + + assert result is not None + assert len(result) == 2 + + # Find the models in the result + rerank_model = next(m for m in result if m.identifier == "rerank-model-1") + llm_model = next(m for m in result if m.identifier == "gpt-4") + + # Check rerank model + assert rerank_model.model_type == ModelType.rerank + assert rerank_model.metadata == {} # No metadata for rerank models + assert rerank_model.provider_id == "test-provider" + assert rerank_model.provider_resource_id == "rerank-model-1" + + # Check LLM model + assert llm_model.model_type == ModelType.llm + assert llm_model.metadata == {} # No metadata for LLMs + assert llm_model.provider_id == "test-provider" + assert llm_model.provider_resource_id == "gpt-4" + + +class TestOpenAIMixinMixedModelTypes: + """Test cases for mixed model types (LLM, embedding, rerank)""" + + async def test_mixed_model_types_identification(self, mixin_with_embeddings_and_rerank, mock_client_context): + """Test that LLM, embedding, and rerank models are correctly identified with proper types and metadata""" + # Create mock models: 1 embedding, 1 rerank, 1 LLM + mock_embedding_model = MagicMock(id="text-embedding-3-small") + mock_rerank_model = MagicMock(id="rerank-model-1") + mock_llm_model = MagicMock(id="gpt-4") + mock_models = [mock_embedding_model, mock_rerank_model, mock_llm_model] + + mock_client = MagicMock() + + async def mock_models_list(): + for model in mock_models: + yield model + + mock_client.models.list.return_value = mock_models_list() + + with mock_client_context(mixin_with_embeddings_and_rerank, mock_client): + result = await mixin_with_embeddings_and_rerank.list_models() + + assert result is not None + assert len(result) == 3 + + # Find the models in the result + embedding_model = next(m for m in result if m.identifier == "text-embedding-3-small") + rerank_model = next(m for m in result if m.identifier == "rerank-model-1") + llm_model = next(m for m in result if m.identifier == "gpt-4") + + # Check embedding model + assert embedding_model.model_type == ModelType.embedding + assert embedding_model.metadata == {"embedding_dimension": 1536, "context_length": 8192} + assert embedding_model.provider_id == "test-provider" + assert embedding_model.provider_resource_id == "text-embedding-3-small" + + # Check rerank model + assert rerank_model.model_type == ModelType.rerank + assert rerank_model.metadata == {} # No metadata for rerank models + assert rerank_model.provider_id == "test-provider" + assert rerank_model.provider_resource_id == "rerank-model-1" + + # Check LLM model + assert llm_model.model_type == ModelType.llm + assert llm_model.metadata == {} # No metadata for LLMs + assert llm_model.provider_id == "test-provider" + assert llm_model.provider_resource_id == "gpt-4" + + class TestOpenAIMixinAllowedModels: """Test cases for allowed_models filtering functionality"""