Add rerank models to the dynamic model list; Fix integration tests

Jiayi 2025-09-28 14:45:16 -07:00
parent 3538477070
commit 816b68fdc7
8 changed files with 247 additions and 25 deletions


@ -18,14 +18,14 @@ title: Batches
## Overview
The Batches API enables efficient processing of multiple requests in a single operation,
particularly useful for processing large datasets, batch evaluation workflows, and
cost-effective inference at scale.

The API is designed to allow use of openai client libraries for seamless integration.

This API provides the following extensions:
- idempotent batch creation

Note: This API is currently under active development and may undergo changes.
This section contains documentation for all available providers for the **batches** API.
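As a rough illustration of the OpenAI-client integration mentioned above, the sketch below creates a batch through a Llama Stack server. The base URL, the dummy API key, and the `idempotency_key` extension parameter are assumptions for illustration; this diff does not specify how idempotent creation is requested.

```python
from openai import OpenAI

# Point the stock OpenAI client at an assumed local Llama Stack server.
client = OpenAI(base_url="http://localhost:8321/v1", api_key="not-needed")

batch = client.batches.create(
    input_file_id="file-abc123",      # previously uploaded JSONL of requests (placeholder ID)
    endpoint="/v1/chat/completions",  # endpoint each batched request targets
    completion_window="24h",
    # Hypothetical extension parameter; the exact idempotency mechanism is not shown in this diff.
    extra_body={"idempotency_key": "nightly-eval-2025-09-28"},
)
print(batch.id, batch.status)
```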


@ -5,6 +5,7 @@ description: "Llama Stack Inference API for generating completions, chat complet
- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic search.
- Rerank models: these models rerank the documents by relevance."
sidebar_label: Inference
title: Inference
---
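For context on the rerank model type introduced in this description, the call shape used elsewhere in this commit looks roughly like the sketch below. The server URL and choice of model are assumptions; the `rerank(model=..., query=..., items=...)` signature and the list-shaped response mirror the integration tests and docs changes in this diff.

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local server

response = client.inference.rerank(
    model="nvidia/llama-3.2-nv-rerankqa-1b-v2",  # one of the NVIDIA rerank models added below
    query="Which planet is known as the Red Planet?",
    items=["Venus is very hot.", "Mars is often called the Red Planet.", "Jupiter is the largest planet."],
)
for result in response:
    print(result.index, result.relevance_score)
```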


@ -204,6 +204,6 @@ rerank_response = client.inference.rerank(
    ],
)

-for i, result in enumerate(rerank_response.data):
-    print(f"{i+1}. [Index: {result.index}, Score: {result.relevance_score:.3f}]")
+for i, result in enumerate(rerank_response):
+    print(f"{i+1}. [Index: {result.index}, " f"Score: {(result.relevance_score):.3f}]")
```


@ -20,6 +20,7 @@ from llama_stack.apis.inference.inference import (
    OpenAIChatCompletionContentPartImageParam,
    OpenAIChatCompletionContentPartTextParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -51,6 +52,18 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
"snowflake/arctic-embed-l": {"embedding_dimension": 512, "context_length": 1024},
}
rerank_model_list = [
"nv-rerank-qa-mistral-4b:1",
"nvidia/nv-rerankqa-mistral-4b-v3",
"nvidia/llama-3.2-nv-rerankqa-1b-v2",
]
_rerank_model_endpoints = {
"nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking",
"nvidia/nv-rerankqa-mistral-4b-v3": "https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking",
"nvidia/llama-3.2-nv-rerankqa-1b-v2": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking",
}
def __init__(self, config: NVIDIAConfig) -> None:
logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")
@ -69,6 +82,8 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
# "Consider removing the api_key from the configuration."
# )
super().__init__()
self._config = config
def get_api_key(self) -> str:
@ -87,6 +102,30 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference):
"""
return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
async def list_models(self) -> list[Model] | None:
"""
List available NVIDIA models by combining:
1. Dynamic models from https://integrate.api.nvidia.com/v1/models
2. Static rerank models (which use different API endpoints)
"""
models = await super().list_models() or []
existing_ids = {m.identifier for m in models}
for model_id, _ in self._rerank_model_endpoints.items():
if self.allowed_models and model_id not in self.allowed_models:
continue
if model_id not in existing_ids:
model = Model(
provider_id=self.__provider_id__, # type: ignore[attr-defined]
provider_resource_id=model_id,
identifier=model_id,
model_type=ModelType.rerank,
)
models.append(model)
self._model_cache[model_id] = model
return models
async def rerank(
self,
model: str,
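As a quick orientation to the effect of this override (not part of the diff itself): once a stack is configured with the NVIDIA provider, the static rerank identifiers registered above should appear in the model list alongside the dynamically discovered models. The client construction and server URL below are assumptions.

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local server

available = {m.identifier for m in client.models.list()}
# Identifier taken from _rerank_model_endpoints above.
assert "nvidia/llama-3.2-nv-rerankqa-1b-v2" in available
```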


@ -63,6 +63,10 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
    # Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}}
    embedding_model_metadata: dict[str, dict[str, int]] = {}

    # List of rerank model IDs for this provider
    # Can be set by subclasses or instances to provide rerank models
    rerank_model_list: list[str] = []

    # Cache of available models keyed by model ID
    # This is set in list_models() and used in check_model_availability()
    _model_cache: dict[str, Model] = {}
@ -400,6 +404,14 @@ class OpenAIMixin(ModelsProtocolPrivate, NeedsRequestProviderData, ABC):
                    model_type=ModelType.embedding,
                    metadata=metadata,
                )
            elif m.id in self.rerank_model_list:
                # This is a rerank model
                model = Model(
                    provider_id=self.__provider_id__,  # type: ignore[attr-defined]
                    provider_resource_id=m.id,
                    identifier=m.id,
                    model_type=ModelType.rerank,
                )
            else:
                # This is an LLM
                model = Model(
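To make the branch order above explicit: embedding metadata takes precedence, then membership in `rerank_model_list`, and anything else is treated as an LLM. A standalone sketch of that precedence follows; this helper is illustrative and not part of the mixin.

```python
from llama_stack.apis.models import ModelType

def classify_model_type(
    model_id: str,
    embedding_model_metadata: dict[str, dict[str, int]],
    rerank_model_list: list[str],
) -> ModelType:
    """Mirror the if/elif/else ordering used in list_models above."""
    if model_id in embedding_model_metadata:
        return ModelType.embedding
    if model_id in rerank_model_list:
        return ModelType.rerank
    return ModelType.llm
```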


@ -6,7 +6,7 @@
import pytest
from llama_stack_client import BadRequestError as LlamaStackBadRequestError
-from llama_stack_client.types import RerankResponse
+from llama_stack_client.types import InferenceRerankResponse
from llama_stack_client.types.shared.interleaved_content import (
    ImageContentItem,
    ImageContentItemImage,
@ -30,12 +30,12 @@ SUPPORTED_PROVIDERS = {"remote::nvidia"}
PROVIDERS_SUPPORTING_MEDIA = {} # Providers that support media input for rerank models
-def _validate_rerank_response(response: RerankResponse, items: list) -> None:
+def _validate_rerank_response(response: InferenceRerankResponse, items: list) -> None:
    """
    Validate that a rerank response has the correct structure and ordering.

    Args:
-        response: The RerankResponse to validate
+        response: The InferenceRerankResponse to validate
        items: The original items list that was ranked

    Raises:
@ -43,7 +43,7 @@ def _validate_rerank_response(response: RerankResponse, items: list) -> None:
"""
seen = set()
last_score = float("inf")
for d in response.data:
for d in response:
assert 0 <= d.index < len(items), f"Index {d.index} out of bounds for {len(items)} items"
assert d.index not in seen, f"Duplicate index {d.index} found"
seen.add(d.index)
@ -52,22 +52,22 @@ def _validate_rerank_response(response: RerankResponse, items: list) -> None:
        last_score = d.relevance_score


-def _validate_semantic_ranking(response: RerankResponse, items: list, expected_first_item: str) -> None:
+def _validate_semantic_ranking(response: InferenceRerankResponse, items: list, expected_first_item: str) -> None:
    """
    Validate that the expected most relevant item ranks first.

    Args:
-        response: The RerankResponse to validate
+        response: The InferenceRerankResponse to validate
        items: The original items list that was ranked
        expected_first_item: The expected first item in the ranking

    Raises:
        AssertionError: If any validation fails
    """
-    if not response.data:
+    if not response:
        raise AssertionError("No ranking data returned in response")

-    actual_first_index = response.data[0].index
+    actual_first_index = response[0].index
    actual_first_item = items[actual_first_index]
    assert actual_first_item == expected_first_item, (
        f"Expected '{expected_first_item}' to rank first, but '{actual_first_item}' ranked first instead."
@ -94,8 +94,9 @@ def test_rerank_text(client_with_models, rerank_model_id, query, items, inferenc
        pytest.xfail(f"{inference_provider_type} doesn't support rerank models yet. ")

    response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)
-    assert isinstance(response, RerankResponse)
-    assert len(response.data) <= len(items)
+    assert isinstance(response, list)
+    # TODO: Add type validation for response items once InferenceRerankResponseItem is exported from llama stack client.
+    assert len(response) <= len(items)
    _validate_rerank_response(response, items)
@ -129,8 +130,8 @@ def test_rerank_image(client_with_models, rerank_model_id, query, items, inferen
    else:
        response = client_with_models.inference.rerank(model=rerank_model_id, query=query, items=items)

-        assert isinstance(response, RerankResponse)
-        assert len(response.data) <= len(items)
+        assert isinstance(response, list)
+        assert len(response) <= len(items)
        _validate_rerank_response(response, items)
@ -148,8 +149,8 @@ def test_rerank_max_results(client_with_models, rerank_model_id, inference_provi
        max_num_results=max_num_results,
    )

-    assert isinstance(response, RerankResponse)
-    assert len(response.data) == max_num_results
+    assert isinstance(response, list)
+    assert len(response) == max_num_results
    _validate_rerank_response(response, items)
@ -165,8 +166,8 @@ def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_i
        max_num_results=10,  # Larger than items length
    )

-    assert isinstance(response, RerankResponse)
-    assert len(response.data) <= len(items)  # Should return at most len(items)
+    assert isinstance(response, list)
+    assert len(response) <= len(items)  # Should return at most len(items)
@pytest.mark.parametrize(


@ -4,11 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from unittest.mock import AsyncMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch
import aiohttp
import pytest
from llama_stack.apis.models import ModelType
from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter
@ -170,3 +171,35 @@ async def test_client_error():
with patch("aiohttp.ClientSession", return_value=mock_session):
with pytest.raises(ConnectionError, match="Failed to connect.*Network error"):
await adapter.rerank(model="test-model", query="q", items=["a"])
async def test_list_models_adds_rerank_models():
"""Test that list_models adds rerank models to the dynamic model list."""
adapter = create_adapter()
adapter.__provider_id__ = "nvidia"
# Mock the list_models from the superclass to return some dynamic models
base_models = [
MagicMock(identifier="llm-1", model_type=ModelType.llm),
MagicMock(identifier="embedding-1", model_type=ModelType.embedding),
]
with patch.object(NVIDIAInferenceAdapter.__bases__[0], "list_models", return_value=base_models):
result = await adapter.list_models()
assert result is not None
# Check that the rerank models are added
model_ids = [m.identifier for m in result]
assert "nv-rerank-qa-mistral-4b:1" in model_ids
assert "nvidia/nv-rerankqa-mistral-4b-v3" in model_ids
assert "nvidia/llama-3.2-nv-rerankqa-1b-v2" in model_ids
rerank_models = [m for m in result if m.model_type == ModelType.rerank]
assert len(rerank_models) == 3
for rerank_model in rerank_models:
assert rerank_model.provider_id == "nvidia"
assert rerank_model.metadata == {}
assert rerank_model.identifier in adapter._model_cache


@ -35,6 +35,40 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
    }


class OpenAIMixinWithRerankImpl(OpenAIMixin):
    """Test implementation with rerank model list"""

    rerank_model_list = ["rerank-model-1", "rerank-model-2"]

    def __init__(self):
        self.__provider_id__ = "test-provider"

    def get_api_key(self) -> str:
        raise NotImplementedError("This method should be mocked in tests")

    def get_base_url(self) -> str:
        raise NotImplementedError("This method should be mocked in tests")


class OpenAIMixinWithEmbeddingsAndRerankImpl(OpenAIMixin):
    """Test implementation with both embedding model metadata and rerank model list"""

    embedding_model_metadata = {
        "text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
        "text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192},
    }
    rerank_model_list = ["rerank-model-1", "rerank-model-2"]
    __provider_id__ = "test-provider"

    def get_api_key(self) -> str:
        raise NotImplementedError("This method should be mocked in tests")

    def get_base_url(self) -> str:
        raise NotImplementedError("This method should be mocked in tests")
@pytest.fixture
def mixin():
"""Create a test instance of OpenAIMixin with mocked model_store"""
@ -56,6 +90,18 @@ def mixin_with_embeddings():
    return OpenAIMixinWithEmbeddingsImpl()


@pytest.fixture
def mixin_with_rerank():
    """Create a test instance of OpenAIMixin with rerank model list"""
    return OpenAIMixinWithRerankImpl()


@pytest.fixture
def mixin_with_embeddings_and_rerank():
    """Create a test instance of OpenAIMixin with both embedding model metadata and rerank model list"""
    return OpenAIMixinWithEmbeddingsAndRerankImpl()
@pytest.fixture
def mock_models():
"""Create multiple mock OpenAI model objects"""
@ -317,6 +363,96 @@ class TestOpenAIMixinEmbeddingModelMetadata:
assert llm_model.provider_resource_id == "gpt-4"
class TestOpenAIMixinRerankModelList:
"""Test cases for rerank_model_list attribute functionality"""
async def test_rerank_model_identified(self, mixin_with_rerank, mock_client_context):
"""Test that models in rerank_model_list are correctly identified as rerank models"""
# Create mock models: 1 rerank model and 1 LLM
mock_rerank_model = MagicMock(id="rerank-model-1")
mock_llm_model = MagicMock(id="gpt-4")
mock_models = [mock_rerank_model, mock_llm_model]
mock_client = MagicMock()
async def mock_models_list():
for model in mock_models:
yield model
mock_client.models.list.return_value = mock_models_list()
with mock_client_context(mixin_with_rerank, mock_client):
result = await mixin_with_rerank.list_models()
assert result is not None
assert len(result) == 2
# Find the models in the result
rerank_model = next(m for m in result if m.identifier == "rerank-model-1")
llm_model = next(m for m in result if m.identifier == "gpt-4")
# Check rerank model
assert rerank_model.model_type == ModelType.rerank
assert rerank_model.metadata == {} # No metadata for rerank models
assert rerank_model.provider_id == "test-provider"
assert rerank_model.provider_resource_id == "rerank-model-1"
# Check LLM model
assert llm_model.model_type == ModelType.llm
assert llm_model.metadata == {} # No metadata for LLMs
assert llm_model.provider_id == "test-provider"
assert llm_model.provider_resource_id == "gpt-4"
class TestOpenAIMixinMixedModelTypes:
    """Test cases for mixed model types (LLM, embedding, rerank)"""

    async def test_mixed_model_types_identification(self, mixin_with_embeddings_and_rerank, mock_client_context):
        """Test that LLM, embedding, and rerank models are correctly identified with proper types and metadata"""
        # Create mock models: 1 embedding, 1 rerank, 1 LLM
        mock_embedding_model = MagicMock(id="text-embedding-3-small")
        mock_rerank_model = MagicMock(id="rerank-model-1")
        mock_llm_model = MagicMock(id="gpt-4")
        mock_models = [mock_embedding_model, mock_rerank_model, mock_llm_model]

        mock_client = MagicMock()

        async def mock_models_list():
            for model in mock_models:
                yield model

        mock_client.models.list.return_value = mock_models_list()

        with mock_client_context(mixin_with_embeddings_and_rerank, mock_client):
            result = await mixin_with_embeddings_and_rerank.list_models()

            assert result is not None
            assert len(result) == 3

            # Find the models in the result
            embedding_model = next(m for m in result if m.identifier == "text-embedding-3-small")
            rerank_model = next(m for m in result if m.identifier == "rerank-model-1")
            llm_model = next(m for m in result if m.identifier == "gpt-4")

            # Check embedding model
            assert embedding_model.model_type == ModelType.embedding
            assert embedding_model.metadata == {"embedding_dimension": 1536, "context_length": 8192}
            assert embedding_model.provider_id == "test-provider"
            assert embedding_model.provider_resource_id == "text-embedding-3-small"

            # Check rerank model
            assert rerank_model.model_type == ModelType.rerank
            assert rerank_model.metadata == {}  # No metadata for rerank models
            assert rerank_model.provider_id == "test-provider"
            assert rerank_model.provider_resource_id == "rerank-model-1"

            # Check LLM model
            assert llm_model.model_type == ModelType.llm
            assert llm_model.metadata == {}  # No metadata for LLMs
            assert llm_model.provider_id == "test-provider"
            assert llm_model.provider_resource_id == "gpt-4"


class TestOpenAIMixinAllowedModels:
    """Test cases for allowed_models filtering functionality"""