From 51c923f09616f3cc1905219c22e7ffcdfecadcbe Mon Sep 17 00:00:00 2001 From: Jiayi Date: Thu, 16 Oct 2025 17:27:38 -0700 Subject: [PATCH] Add rerank models and rerank API change --- docs/docs/providers/inference/index.mdx | 8 +- docs/static/deprecated-llama-stack-spec.html | 2 +- docs/static/deprecated-llama-stack-spec.yaml | 7 +- docs/static/llama-stack-spec.html | 5 +- docs/static/llama-stack-spec.yaml | 8 +- docs/static/stainless-llama-stack-spec.html | 5 +- docs/static/stainless-llama-stack-spec.yaml | 8 +- llama_stack/apis/inference/inference.py | 3 +- llama_stack/apis/models/models.py | 2 + llama_stack/core/routers/inference.py | 22 +++ .../providers/utils/inference/openai_mixin.py | 11 ++ .../utils/inference/test_openai_mixin.py | 162 ++++++++++++++++-- 12 files changed, 215 insertions(+), 28 deletions(-) diff --git a/docs/docs/providers/inference/index.mdx b/docs/docs/providers/inference/index.mdx index c2bf69962..bc31caf5f 100644 --- a/docs/docs/providers/inference/index.mdx +++ b/docs/docs/providers/inference/index.mdx @@ -3,9 +3,10 @@ description: "Inference Llama Stack Inference API for generating completions, chat completions, and embeddings. - This API provides the raw interface to the underlying models. Two kinds of models are supported: + This API provides the raw interface to the underlying models. Three kinds of models are supported: - LLM models: these models generate \"raw\" and \"chat\" (conversational) completions. - - Embedding models: these models generate embeddings to be used for semantic search." + - Embedding models: these models generate embeddings to be used for semantic search. + - Rerank models (Experimental): these models reorder the documents based on their relevance to a query." sidebar_label: Inference title: Inference --- @@ -18,8 +19,9 @@ Inference Llama Stack Inference API for generating completions, chat completions, and embeddings. - This API provides the raw interface to the underlying models. Two kinds of models are supported: + This API provides the raw interface to the underlying models. Three kinds of models are supported: - LLM models: these models generate "raw" and "chat" (conversational) completions. - Embedding models: these models generate embeddings to be used for semantic search. + - Rerank models (Experimental): these models reorder the documents based on their relevance to a query. This section contains documentation for all available providers for the **inference** API. diff --git a/docs/static/deprecated-llama-stack-spec.html b/docs/static/deprecated-llama-stack-spec.html index 60a8b9fbd..ef8ec0464 100644 --- a/docs/static/deprecated-llama-stack-spec.html +++ b/docs/static/deprecated-llama-stack-spec.html @@ -13459,7 +13459,7 @@ }, { "name": "Inference", - "description": "Llama Stack Inference API for generating completions, chat completions, and embeddings.\n\nThis API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.", + "description": "Llama Stack Inference API for generating completions, chat completions, and embeddings.\n\nThis API provides the raw interface to the underlying models. 
Three kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.\n- Rerank models (Experimental): these models reorder the documents based on their relevance to a query.", "x-displayName": "Inference" }, { diff --git a/docs/static/deprecated-llama-stack-spec.yaml b/docs/static/deprecated-llama-stack-spec.yaml index aaa6cd413..74cd3ac56 100644 --- a/docs/static/deprecated-llama-stack-spec.yaml +++ b/docs/static/deprecated-llama-stack-spec.yaml @@ -10210,13 +10210,16 @@ tags: embeddings. - This API provides the raw interface to the underlying models. Two kinds of models - are supported: + This API provides the raw interface to the underlying models. Three kinds of + models are supported: - LLM models: these models generate "raw" and "chat" (conversational) completions. - Embedding models: these models generate embeddings to be used for semantic search. + + - Rerank models (Experimental): these models reorder the documents based on + their relevance to a query. x-displayName: Inference - name: Models description: '' diff --git a/docs/static/llama-stack-spec.html b/docs/static/llama-stack-spec.html index 413e4f23e..4a011de66 100644 --- a/docs/static/llama-stack-spec.html +++ b/docs/static/llama-stack-spec.html @@ -6859,7 +6859,8 @@ "type": "string", "enum": [ "llm", - "embedding" + "embedding", + "rerank" ], "title": "ModelType", "description": "Enumeration of supported model types in Llama Stack." @@ -13261,7 +13262,7 @@ }, { "name": "Inference", - "description": "Llama Stack Inference API for generating completions, chat completions, and embeddings.\n\nThis API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.", + "description": "Llama Stack Inference API for generating completions, chat completions, and embeddings.\n\nThis API provides the raw interface to the underlying models. Three kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.\n- Rerank models (Experimental): these models reorder the documents based on their relevance to a query.", "x-displayName": "Inference" }, { diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml index 93e51de6a..2819cde1f 100644 --- a/docs/static/llama-stack-spec.yaml +++ b/docs/static/llama-stack-spec.yaml @@ -5269,6 +5269,7 @@ components: enum: - llm - embedding + - rerank title: ModelType description: >- Enumeration of supported model types in Llama Stack. @@ -10182,13 +10183,16 @@ tags: embeddings. - This API provides the raw interface to the underlying models. Two kinds of models - are supported: + This API provides the raw interface to the underlying models. Three kinds of + models are supported: - LLM models: these models generate "raw" and "chat" (conversational) completions. - Embedding models: these models generate embeddings to be used for semantic search. + + - Rerank models (Experimental): these models reorder the documents based on + their relevance to a query. 
x-displayName: Inference - name: Inspect description: >- diff --git a/docs/static/stainless-llama-stack-spec.html b/docs/static/stainless-llama-stack-spec.html index 858f20725..29edfa6b4 100644 --- a/docs/static/stainless-llama-stack-spec.html +++ b/docs/static/stainless-llama-stack-spec.html @@ -8531,7 +8531,8 @@ "type": "string", "enum": [ "llm", - "embedding" + "embedding", + "rerank" ], "title": "ModelType", "description": "Enumeration of supported model types in Llama Stack." @@ -17951,7 +17952,7 @@ }, { "name": "Inference", - "description": "Llama Stack Inference API for generating completions, chat completions, and embeddings.\n\nThis API provides the raw interface to the underlying models. Two kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.", + "description": "Llama Stack Inference API for generating completions, chat completions, and embeddings.\n\nThis API provides the raw interface to the underlying models. Three kinds of models are supported:\n- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions.\n- Embedding models: these models generate embeddings to be used for semantic search.\n- Rerank models (Experimental): these models reorder the documents based on their relevance to a query.", "x-displayName": "Inference" }, { diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index 886549dbc..1705ecce0 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -6482,6 +6482,7 @@ components: enum: - llm - embedding + - rerank title: ModelType description: >- Enumeration of supported model types in Llama Stack. @@ -13577,13 +13578,16 @@ tags: embeddings. - This API provides the raw interface to the underlying models. Two kinds of models - are supported: + This API provides the raw interface to the underlying models. Three kinds of + models are supported: - LLM models: these models generate "raw" and "chat" (conversational) completions. - Embedding models: these models generate embeddings to be used for semantic search. + + - Rerank models (Experimental): these models reorder the documents based on + their relevance to a query. x-displayName: Inference - name: Inspect description: >- diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 027246470..6dc16305c 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -1234,9 +1234,10 @@ class Inference(InferenceProvider): Llama Stack Inference API for generating completions, chat completions, and embeddings. - This API provides the raw interface to the underlying models. Two kinds of models are supported: + This API provides the raw interface to the underlying models. Three kinds of models are supported: - LLM models: these models generate "raw" and "chat" (conversational) completions. - Embedding models: these models generate embeddings to be used for semantic search. + - Rerank models (Experimental): these models reorder the documents based on their relevance to a query. 
""" @webmethod(route="/openai/v1/chat/completions", method="GET", level=LLAMA_STACK_API_V1, deprecated=True) diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index 10949cb95..5486e3bfd 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -27,10 +27,12 @@ class ModelType(StrEnum): """Enumeration of supported model types in Llama Stack. :cvar llm: Large language model for text generation and completion :cvar embedding: Embedding model for converting text to vector representations + :cvar rerank: Reranking model for reordering documents based on their relevance to a query """ llm = "llm" embedding = "embedding" + rerank = "rerank" @json_schema_type diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py index b20ad44ca..09241d836 100644 --- a/llama_stack/core/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -44,9 +44,14 @@ from llama_stack.apis.inference import ( OpenAIEmbeddingsResponse, OpenAIMessageParam, Order, + RerankResponse, StopReason, ToolPromptFormat, ) +from llama_stack.apis.inference.inference import ( + OpenAIChatCompletionContentPartImageParam, + OpenAIChatCompletionContentPartTextParam, +) from llama_stack.apis.models import Model, ModelType from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry from llama_stack.log import get_logger @@ -182,6 +187,23 @@ class InferenceRouter(Inference): raise ModelTypeError(model_id, model.model_type, expected_model_type) return model + async def rerank( + self, + model: str, + query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam, + items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam], + max_num_results: int | None = None, + ) -> RerankResponse: + logger.debug(f"InferenceRouter.rerank: {model}") + model_obj = await self._get_model(model, ModelType.rerank) + provider = await self.routing_table.get_provider_impl(model_obj.identifier) + return await provider.rerank( + model=model_obj.identifier, + query=query, + items=items, + max_num_results=max_num_results, + ) + async def openai_completion( self, params: Annotated[OpenAICompletionRequestWithExtraBody, Body(...)], diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py index a9ccc8091..adbe4dcb0 100644 --- a/llama_stack/providers/utils/inference/openai_mixin.py +++ b/llama_stack/providers/utils/inference/openai_mixin.py @@ -78,6 +78,10 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): # Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}} embedding_model_metadata: dict[str, dict[str, int]] = {} + # List of rerank model IDs for this provider + # Can be set by subclasses or instances to provide rerank models + rerank_model_list: list[str] = [] + # Cache of available models keyed by model ID # This is set in list_models() and used in check_model_availability() _model_cache: dict[str, Model] = {} @@ -424,6 +428,13 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): model_type=ModelType.embedding, metadata=metadata, ) + elif provider_model_id in self.rerank_model_list: + model = Model( + provider_id=self.__provider_id__, # type: ignore[attr-defined] + provider_resource_id=provider_model_id, + identifier=provider_model_id, + model_type=ModelType.rerank, + ) else: model = Model( provider_id=self.__provider_id__, # type: ignore[attr-defined] diff 
--git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py index 61a1f8f61..277982af6 100644 --- a/tests/unit/providers/utils/inference/test_openai_mixin.py +++ b/tests/unit/providers/utils/inference/test_openai_mixin.py @@ -38,6 +38,23 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl): } +class OpenAIMixinWithRerankImpl(OpenAIMixinImpl): + """Test implementation with rerank model list""" + + rerank_model_list: list[str] = ["rerank-model-1", "rerank-model-2"] + + +class OpenAIMixinWithEmbeddingsAndRerankImpl(OpenAIMixinImpl): + """Test implementation with both embedding model metadata and rerank model list""" + + embedding_model_metadata: dict[str, dict[str, int]] = { + "text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192}, + "text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192}, + } + + rerank_model_list: list[str] = ["rerank-model-1", "rerank-model-2"] + + @pytest.fixture def mixin(): """Create a test instance of OpenAIMixin with mocked model_store""" @@ -62,6 +79,20 @@ def mixin_with_embeddings(): return OpenAIMixinWithEmbeddingsImpl(config=config) +@pytest.fixture +def mixin_with_rerank(): + """Create a test instance of OpenAIMixin with rerank model list""" + config = RemoteInferenceProviderConfig() + return OpenAIMixinWithRerankImpl(config=config) + + +@pytest.fixture +def mixin_with_embeddings_and_rerank(): + """Create a test instance of OpenAIMixin with both embedding model metadata and rerank model list""" + config = RemoteInferenceProviderConfig() + return OpenAIMixinWithEmbeddingsAndRerankImpl(config=config) + + @pytest.fixture def mock_models(): """Create multiple mock OpenAI model objects""" @@ -113,6 +144,19 @@ def mock_client_context(): return _mock_client_context +def _assert_models_match_expected(actual_models, expected_models): + """Verify the models match expected attributes. 
+ + Args: + actual_models: List of models to verify + expected_models: Mapping of model identifier to expected attribute values + """ + for identifier, expected_attrs in expected_models.items(): + model = next(m for m in actual_models if m.identifier == identifier) + for attr_name, expected_value in expected_attrs.items(): + assert getattr(model, attr_name) == expected_value + + class TestOpenAIMixinListModels: """Test cases for the list_models method""" @@ -342,21 +386,113 @@ class TestOpenAIMixinEmbeddingModelMetadata: assert result is not None assert len(result) == 2 - # Find the models in the result - embedding_model = next(m for m in result if m.identifier == "text-embedding-3-small") - llm_model = next(m for m in result if m.identifier == "gpt-4") + expected_models = { + "text-embedding-3-small": { + "model_type": ModelType.embedding, + "metadata": {"embedding_dimension": 1536, "context_length": 8192}, + "provider_id": "test-provider", + "provider_resource_id": "text-embedding-3-small", + }, + "gpt-4": { + "model_type": ModelType.llm, + "metadata": {}, + "provider_id": "test-provider", + "provider_resource_id": "gpt-4", + }, + } - # Check embedding model - assert embedding_model.model_type == ModelType.embedding - assert embedding_model.metadata == {"embedding_dimension": 1536, "context_length": 8192} - assert embedding_model.provider_id == "test-provider" - assert embedding_model.provider_resource_id == "text-embedding-3-small" + _assert_models_match_expected(result, expected_models) - # Check LLM model - assert llm_model.model_type == ModelType.llm - assert llm_model.metadata == {} # No metadata for LLMs - assert llm_model.provider_id == "test-provider" - assert llm_model.provider_resource_id == "gpt-4" + +class TestOpenAIMixinRerankModelList: + """Test cases for rerank_model_list attribute functionality""" + + async def test_rerank_model_identified(self, mixin_with_rerank, mock_client_context): + """Test that models in rerank_model_list are correctly identified as rerank models""" + # Create mock models: 1 rerank model and 1 LLM + mock_rerank_model = MagicMock(id="rerank-model-1") + mock_llm_model = MagicMock(id="gpt-4") + mock_models = [mock_rerank_model, mock_llm_model] + + mock_client = MagicMock() + + async def mock_models_list(): + for model in mock_models: + yield model + + mock_client.models.list.return_value = mock_models_list() + + with mock_client_context(mixin_with_rerank, mock_client): + result = await mixin_with_rerank.list_models() + + assert result is not None + assert len(result) == 2 + + expected_models = { + "rerank-model-1": { + "model_type": ModelType.rerank, + "metadata": {}, + "provider_id": "test-provider", + "provider_resource_id": "rerank-model-1", + }, + "gpt-4": { + "model_type": ModelType.llm, + "metadata": {}, + "provider_id": "test-provider", + "provider_resource_id": "gpt-4", + }, + } + + _assert_models_match_expected(result, expected_models) + + +class TestOpenAIMixinMixedModelTypes: + """Test cases for mixed model types (LLM, embedding, rerank)""" + + async def test_mixed_model_types_identification(self, mixin_with_embeddings_and_rerank, mock_client_context): + """Test that LLM, embedding, and rerank models are correctly identified with proper types and metadata""" + # Create mock models: 1 embedding, 1 rerank, 1 LLM + mock_embedding_model = MagicMock(id="text-embedding-3-small") + mock_rerank_model = MagicMock(id="rerank-model-1") + mock_llm_model = MagicMock(id="gpt-4") + mock_models = [mock_embedding_model, mock_rerank_model, mock_llm_model] + + 
mock_client = MagicMock() + + async def mock_models_list(): + for model in mock_models: + yield model + + mock_client.models.list.return_value = mock_models_list() + + with mock_client_context(mixin_with_embeddings_and_rerank, mock_client): + result = await mixin_with_embeddings_and_rerank.list_models() + + assert result is not None + assert len(result) == 3 + + expected_models = { + "text-embedding-3-small": { + "model_type": ModelType.embedding, + "metadata": {"embedding_dimension": 1536, "context_length": 8192}, + "provider_id": "test-provider", + "provider_resource_id": "text-embedding-3-small", + }, + "rerank-model-1": { + "model_type": ModelType.rerank, + "metadata": {}, + "provider_id": "test-provider", + "provider_resource_id": "rerank-model-1", + }, + "gpt-4": { + "model_type": ModelType.llm, + "metadata": {}, + "provider_id": "test-provider", + "provider_resource_id": "gpt-4", + }, + } + + _assert_models_match_expected(result, expected_models) class TestOpenAIMixinAllowedModels:
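Usage sketch (provider side): with this patch, a provider built on OpenAIMixin opts into rerank classification by listing its rerank model IDs in rerank_model_list; list_models() then returns those IDs with ModelType.rerank, embedding_model_metadata keeps precedence, and any other ID falls through to ModelType.llm. A minimal sketch under those assumptions — the adapter class and model IDs below are hypothetical, and a real adapter must still implement the mixin's abstract hooks:

from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


class AcmeInferenceAdapter(OpenAIMixin):
    """Hypothetical adapter, shown only to illustrate the new attribute."""

    # IDs listed here are classified as ModelType.rerank by list_models().
    rerank_model_list: list[str] = ["acme-rerank-v1"]

    # The embedding branch is checked first, so an ID present in both
    # mappings would still be classified as an embedding model.
    embedding_model_metadata: dict[str, dict[str, int]] = {
        "acme-embed-v1": {"embedding_dimension": 1024, "context_length": 8192},
    }

This mirrors how the unit tests above exercise the feature: the mixin is a pydantic BaseModel, so subclasses (or instances) simply override the rerank_model_list field.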
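Usage sketch (request path): the InferenceRouter.rerank method added in this patch resolves the model, checks that it was registered with ModelType.rerank (raising ModelTypeError otherwise via _get_model), and forwards query, items, and max_num_results to the provider implementation. A rough caller-side sketch, assuming an already-wired router and a hypothetical registered rerank model ID:

async def rerank_demo(router) -> None:
    # router is an InferenceRouter; "acme-rerank-v1" is assumed to be
    # registered with model_type=ModelType.rerank, otherwise _get_model()
    # raises ModelTypeError before any provider call is made.
    response = await router.rerank(
        model="acme-rerank-v1",
        query="What is the capital of France?",
        items=[
            "Paris is the capital of France.",
            "Berlin is the capital of Germany.",
            "The Eiffel Tower is in Paris.",
        ],
        max_num_results=2,  # optional cap on how many results come back
    )
    # response is a RerankResponse (defined in llama_stack.apis.inference);
    # it carries the provider's relevance-ordered results.
    print(response)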